Merge pull request #6 from VinciGit00/patch-3

Patch 3
This commit is contained in:
Matteo Vedovati 2024-09-24 17:32:46 +02:00 committed by GitHub
commit 9927397c10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -73,7 +73,7 @@ class GenerateCodeNode(BaseNode):
"semantic": 3
})
self.output_schema = node_config.get("schema") # get JSON output schema
self.output_schema = node_config.get("schema")
def execute(self, state: dict) -> dict:
"""
@ -98,15 +98,15 @@ class GenerateCodeNode(BaseNode):
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0] # get user prompt
refined_prompt = input_data[1] # get refined prompt
html_info = input_data[2] # get html analysis
reduced_html = input_data[3] # get html code
answer = input_data[4] # get answer generated from the generate answer node for verification
user_prompt = input_data[0]
refined_prompt = input_data[1]
html_info = input_data[2]
reduced_html = input_data[3]
answer = input_data[4]
self.raw_html = state['original_html'][0].page_content
simplefied_schema = str(transform_schema(self.output_schema.schema())) # get JSON output schema
simplefied_schema = str(transform_schema(self.output_schema.schema()))
reasoning_state = {
"user_input": user_prompt,
@ -160,9 +160,7 @@ class GenerateCodeNode(BaseNode):
self.logger.info(f"--- (Checking if the informations exctrcated are the ones Requested) ---")
state = self.semantic_comparison_loop(state)
if state["errors"]["semantic"]:
continue
# If we've made it here, the code is valid and produces the correct output
continue
break
if state["iteration"] == self.max_iterations["overall"] and (state["errors"]["syntax"] or state["errors"]["execution"] or state["errors"]["validation"] or state["errors"]["semantic"]):
@ -488,7 +486,6 @@ class GenerateCodeNode(BaseNode):
Human: Are the generated result and reference result semantically equivalent? If not, what are the key differences?
Assistant: Let's analyze the two results carefully:
"""
prompt = PromptTemplate(
@ -576,28 +573,23 @@ class GenerateCodeNode(BaseNode):
'__builtins__': __builtins__,
}
# Capture stdout
old_stdout = sys.stdout
sys.stdout = StringIO()
try:
# Execute the function code in the sandbox
exec(function_code, sandbox_globals)
# Get the extract_data function from the sandbox
extract_data = sandbox_globals.get('extract_data')
if not extract_data:
raise NameError("Function 'extract_data' not found in the generated code.")
# Execute the extract_data function with the provided HTML
result = extract_data(self.raw_html)
return True, result
except Exception as e:
return False, f"Error during execution: {str(e)}"
finally:
# Restore stdout
sys.stdout = old_stdout
def validate_dict(self, data: dict, schema):
@ -609,17 +601,13 @@ class GenerateCodeNode(BaseNode):
return False, errors
def extract_code(self, code: str) -> str:
    """
    Extract the source code contained in a Markdown fenced code block.

    The first fenced block that is either untagged or tagged as Python
    (``python``, ``python3`` or ``py``, in any letter case) is used.
    Spaces or tabs around the tag and Windows (CRLF) line endings after
    the opening fence are tolerated.

    Args:
        code (str): Text that may wrap the code in a fenced block.

    Returns:
        str: The content between the opening and closing fences, or the
        input string unchanged when no such block is found.
    """
    # Non-greedy body so that only the first block is captured when the
    # text contains several fenced blocks.
    pattern = r'```[ \t]*(?:python3?|py)?[ \t]*\r?\n(.*?)```'
    match = re.search(pattern, code, re.DOTALL | re.IGNORECASE)
    return match.group(1) if match else code
def normalize_string(s: str) -> str:
    """
    Normalize a string for content comparison.

    The string is lower-cased, ASCII punctuation is removed, and every
    run of whitespace is collapsed to a single space with no leading or
    trailing whitespace.

    Args:
        s (str): The string to normalize.

    Returns:
        str: The normalized string.
    """
    without_punctuation = s.lower().translate(
        str.maketrans('', '', string.punctuation)
    )
    # Collapse whitespace only after punctuation is gone; otherwise inputs
    # such as "price !" would keep a trailing space.
    return ' '.join(without_punctuation.split())
def normalize_string(s: str) -> str:
@ -641,12 +629,10 @@ def normalize_dict(d: dict) -> dict:
normalized[key] = normalize_dict(value)
elif isinstance(value, list):
if all(isinstance(v, dict) for v in value):
# For lists of dicts, normalize each dict and sort based on their string representation
normalized[key] = sorted(
normalize_dict(v) for v in value
)
else:
# For lists of primitives, sort normally
normalized[key] = sorted(
normalize_dict(v) if isinstance(v, dict)
else normalize_string(v) if isinstance(v, str)