From 04ac7362d4735b77d946721a420f4e2c0539de41 Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra <88108002+VinciGit00@users.noreply.github.com> Date: Tue, 24 Sep 2024 16:57:52 +0200 Subject: [PATCH] i don't like comments --- scrapegraphai/nodes/generate_code_node.py | 30 ++++++----------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/scrapegraphai/nodes/generate_code_node.py b/scrapegraphai/nodes/generate_code_node.py index 4b28bb71..ec0c310a 100644 --- a/scrapegraphai/nodes/generate_code_node.py +++ b/scrapegraphai/nodes/generate_code_node.py @@ -73,7 +73,7 @@ class GenerateCodeNode(BaseNode): "semantic": 3 }) - self.output_schema = node_config.get("schema") # get JSON output schema + self.output_schema = node_config.get("schema") def execute(self, state: dict) -> dict: """ @@ -98,15 +98,15 @@ class GenerateCodeNode(BaseNode): input_data = [state[key] for key in input_keys] - user_prompt = input_data[0] # get user prompt - refined_prompt = input_data[1] # get refined prompt - html_info = input_data[2] # get html analysis - reduced_html = input_data[3] # get html code - answer = input_data[4] # get answer generated from the generate answer node for verification + user_prompt = input_data[0] + refined_prompt = input_data[1] + html_info = input_data[2] + reduced_html = input_data[3] + answer = input_data[4] self.raw_html = state['original_html'][0].page_content - simplefied_schema = str(transform_schema(self.output_schema.schema())) # get JSON output schema + simplefied_schema = str(transform_schema(self.output_schema.schema())) reasoning_state = { "user_input": user_prompt, @@ -160,9 +160,7 @@ class GenerateCodeNode(BaseNode): self.logger.info(f"--- (Checking if the informations exctrcated are the ones Requested) ---") state = self.semantic_comparison_loop(state) if state["errors"]["semantic"]: - continue - - # If we've made it here, the code is valid and produces the correct output + continue break if state["iteration"] == self.max_iterations["overall"] and (state["errors"]["syntax"] or state["errors"]["execution"] or state["errors"]["validation"] or state["errors"]["semantic"]): @@ -488,7 +486,6 @@ class GenerateCodeNode(BaseNode): Human: Are the generated result and reference result semantically equivalent? If not, what are the key differences? Assistant: Let's analyze the two results carefully: - """ prompt = PromptTemplate( @@ -576,28 +573,23 @@ class GenerateCodeNode(BaseNode): '__builtins__': __builtins__, } - # Capture stdout old_stdout = sys.stdout sys.stdout = StringIO() try: - # Execute the function code in the sandbox exec(function_code, sandbox_globals) - # Get the extract_data function from the sandbox extract_data = sandbox_globals.get('extract_data') if not extract_data: raise NameError("Function 'extract_data' not found in the generated code.") - # Execute the extract_data function with the provided HTML result = extract_data(self.raw_html) return True, result except Exception as e: return False, f"Error during execution: {str(e)}" finally: - # Restore stdout sys.stdout = old_stdout def validate_dict(self, data: dict, schema): @@ -609,17 +601,13 @@ class GenerateCodeNode(BaseNode): return False, errors def extract_code(self, code: str) -> str: - # Pattern to match the code inside a code block pattern = r'```(?:python)?\n(.*?)```' - # Search for the code block, if present match = re.search(pattern, code, re.DOTALL) - # If a code block is found, return the code, otherwise return the entire string return match.group(1) if match else code def normalize_string(s: str) -> str: - # Convert to lowercase, remove extra spaces, and strip punctuation return ''.join(c for c in s.lower().strip() if c not in string.punctuation) def normalize_string(s: str) -> str: @@ -641,12 +629,10 @@ def normalize_dict(d: dict) -> dict: normalized[key] = normalize_dict(value) elif isinstance(value, list): if all(isinstance(v, dict) for v in value): - # For lists of dicts, normalize each dict and sort based on their string representation normalized[key] = sorted( normalize_dict(v) for v in value ) else: - # For lists of primitives, sort normally normalized[key] = sorted( normalize_dict(v) if isinstance(v, dict) else normalize_string(v) if isinstance(v, str)