From 7c39b06640d4c876fc728e9eaaa7156bbab12fae Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Fri, 27 Sep 2024 09:34:56 +0200 Subject: [PATCH] update for pylint --- scrapegraphai/nodes/generate_code_node.py | 165 ++++++++++++--------- scrapegraphai/nodes/prompt_refiner_node.py | 9 +- 2 files changed, 99 insertions(+), 75 deletions(-) diff --git a/scrapegraphai/nodes/generate_code_node.py b/scrapegraphai/nodes/generate_code_node.py index 1174a4aa..35de0da0 100644 --- a/scrapegraphai/nodes/generate_code_node.py +++ b/scrapegraphai/nodes/generate_code_node.py @@ -2,20 +2,20 @@ GenerateCodeNode Module """ from typing import Any, Dict, List, Optional -from langchain.prompts import PromptTemplate -from langchain.output_parsers import ResponseSchema, StructuredOutputParser -from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass -from langchain_community.chat_models import ChatOllama import ast import sys from io import StringIO -from bs4 import BeautifulSoup import re -from tqdm import tqdm -from .base_node import BaseNode +import json from pydantic import ValidationError +from langchain.prompts import PromptTemplate +from langchain.output_parsers import ResponseSchema, StructuredOutputParser +from langchain_core.output_parsers import StrOutputParser +from langchain_community.chat_models import ChatOllama +from bs4 import BeautifulSoup +from ..prompts import ( + TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON +) from ..utils import (transform_schema, extract_code, syntax_focused_analysis, syntax_focused_code_generation, @@ -23,15 +23,14 @@ from ..utils import (transform_schema, validation_focused_analysis, validation_focused_code_generation, semantic_focused_analysis, semantic_focused_code_generation, are_content_equal) +from .base_node import BaseNode from jsonschema import validate, ValidationError -import json -from ..prompts import ( - TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON -) + class GenerateCodeNode(BaseNode): """ - A node that generates Python code for a function that extracts data from HTML based on a output schema. + A node that generates Python code for a function that extracts data + from HTML based on a output schema. Attributes: llm_model: An instance of a language model client, configured for generating answers. @@ -72,7 +71,7 @@ class GenerateCodeNode(BaseNode): ) self.additional_info = node_config.get("additional_info") - + self.max_iterations = node_config.get("max_iterations", { "overall": 10, "syntax": 3, @@ -80,7 +79,7 @@ class GenerateCodeNode(BaseNode): "validation": 3, "semantic": 3 }) - + self.output_schema = node_config.get("schema") def execute(self, state: dict) -> dict: @@ -97,25 +96,26 @@ class GenerateCodeNode(BaseNode): Raises: KeyError: If the input keys are not found in the state, indicating that the necessary information for generating an answer is missing. - RuntimeError: If the maximum number of iterations is reached without obtaining the desired code. + RuntimeError: If the maximum number of iterations is + reached without obtaining the desired code. """ - + self.logger.info(f"--- Executing {self.node_name} Node ---") input_keys = self.get_input_keys(state) - + input_data = [state[key] for key in input_keys] - + user_prompt = input_data[0] refined_prompt = input_data[1] html_info = input_data[2] reduced_html = input_data[3] - answer = input_data[4] - + answer = input_data[4] + self.raw_html = state['original_html'][0].page_content - + simplefied_schema = str(transform_schema(self.output_schema.schema())) - + reasoning_state = { "user_input": user_prompt, "json_schema": simplefied_schema, @@ -133,89 +133,101 @@ class GenerateCodeNode(BaseNode): }, "iteration": 0 } - - + final_state = self.overall_reasoning_loop(reasoning_state) - + state.update({self.output[0]: final_state["generated_code"]}) return state - + def overall_reasoning_loop(self, state: dict) -> dict: + """ + overrall_reasoning_loop + """ self.logger.info(f"--- (Generating Code) ---") state["generated_code"] = self.generate_initial_code(state) state["generated_code"] = extract_code(state["generated_code"]) - + while state["iteration"] < self.max_iterations["overall"]: state["iteration"] += 1 if self.verbose: self.logger.info(f"--- Iteration {state['iteration']} ---") - + self.logger.info(f"--- (Checking Code Syntax) ---") state = self.syntax_reasoning_loop(state) if state["errors"]["syntax"]: continue - + self.logger.info(f"--- (Executing the Generated Code) ---") state = self.execution_reasoning_loop(state) if state["errors"]["execution"]: continue - + self.logger.info(f"--- (Validate the Code Output Schema) ---") state = self.validation_reasoning_loop(state) if state["errors"]["validation"]: continue - + self.logger.info(f"--- (Checking if the informations exctrcated are the ones Requested) ---") state = self.semantic_comparison_loop(state) if state["errors"]["semantic"]: - continue + continue break - + if state["iteration"] == self.max_iterations["overall"] and (state["errors"]["syntax"] or state["errors"]["execution"] or state["errors"]["validation"] or state["errors"]["semantic"]): raise RuntimeError("Max iterations reached without obtaining the desired code.") - + self.logger.info(f"--- (Code Generated Correctly) ---") - + return state - + def syntax_reasoning_loop(self, state: dict) -> dict: + """ + syntax reasoning loop + """ for _ in range(self.max_iterations["syntax"]): syntax_valid, syntax_message = self.syntax_check(state["generated_code"]) if syntax_valid: state["errors"]["syntax"] = [] return state - + state["errors"]["syntax"] = [syntax_message] self.logger.info(f"--- (Synax Error Found: {syntax_message}) ---") analysis = syntax_focused_analysis(state, self.llm_model) - self.logger.info(f"--- (Regenerating Code to fix the Error) ---") - state["generated_code"] = syntax_focused_code_generation(state, analysis, self.llm_model) + self.logger.info(f"""--- (Regenerating Code + to fix the Error) ---""") + state["generated_code"] = syntax_focused_code_generation(state, + analysis, self.llm_model) state["generated_code"] = extract_code(state["generated_code"]) return state - + def execution_reasoning_loop(self, state: dict) -> dict: + """ + execution of the reasoning loop + """ for _ in range(self.max_iterations["execution"]): execution_success, execution_result = self.create_sandbox_and_execute(state["generated_code"]) if execution_success: state["execution_result"] = execution_result state["errors"]["execution"] = [] return state - + state["errors"]["execution"] = [execution_result] self.logger.info(f"--- (Code Execution Error: {execution_result}) ---") analysis = execution_focused_analysis(state, self.llm_model) self.logger.info(f"--- (Regenerating Code to fix the Error) ---") - state["generated_code"] = execution_focused_code_generation(state, analysis, self.llm_model) + state["generated_code"] = execution_focused_code_generation(state, + analysis, self.llm_model) state["generated_code"] = extract_code(state["generated_code"]) return state - + def validation_reasoning_loop(self, state: dict) -> dict: for _ in range(self.max_iterations["validation"]): - validation, errors = self.validate_dict(state["execution_result"], self.output_schema.schema()) + validation, errors = self.validate_dict(state["execution_result"], + self.output_schema.schema()) if validation: state["errors"]["validation"] = [] return state - + state["errors"]["validation"] = errors self.logger.info(f"--- (Code Output not compliant to the deisred Output Schema) ---") analysis = validation_focused_analysis(state, self.llm_model) @@ -223,14 +235,14 @@ class GenerateCodeNode(BaseNode): state["generated_code"] = validation_focused_code_generation(state, analysis, self.llm_model) state["generated_code"] = extract_code(state["generated_code"]) return state - + def semantic_comparison_loop(self, state: dict) -> dict: for _ in range(self.max_iterations["semantic"]): comparison_result = self.semantic_comparison(state["execution_result"], state["reference_answer"]) if comparison_result["are_semantically_equivalent"]: state["errors"]["semantic"] = [] return state - + state["errors"]["semantic"] = comparison_result["differences"] self.logger.info(f"--- (The informations exctrcated are not the all ones requested) ---") analysis = semantic_focused_analysis(state, comparison_result, self.llm_model) @@ -238,8 +250,11 @@ class GenerateCodeNode(BaseNode): state["generated_code"] = semantic_focused_code_generation(state, analysis, self.llm_model) state["generated_code"] = extract_code(state["generated_code"]) return state - + def generate_initial_code(self, state: dict) -> str: + """ + function for generating the initial code + """ prompt = PromptTemplate( template=TEMPLATE_INIT_CODE_GENERATION, partial_variables={ @@ -255,22 +270,29 @@ class GenerateCodeNode(BaseNode): chain = prompt | self.llm_model | output_parser generated_code = chain.invoke({}) return generated_code - + def semantic_comparison(self, generated_result: Any, reference_result: Any) -> Dict[str, Any]: + """ + semtantic comparison formula + """ reference_result_dict = self.output_schema(**reference_result).dict() - - # Check if generated result and reference result are actually equal if are_content_equal(generated_result, reference_result_dict): return { "are_semantically_equivalent": True, "differences": [], "explanation": "The generated result and reference result are exactly equal." } - + response_schemas = [ - ResponseSchema(name="are_semantically_equivalent", description="Boolean indicating if the results are semantically equivalent"), - ResponseSchema(name="differences", description="List of semantic differences between the results, if any"), - ResponseSchema(name="explanation", description="Detailed explanation of the comparison and reasoning") + ResponseSchema(name="are_semantically_equivalent", + description="""Boolean indicating if the + results are semantically equivalent"""), + ResponseSchema(name="differences", + description="""List of semantic differences + between the results, if any"""), + ResponseSchema(name="explanation", + description="""Detailed explanation of the + comparison and reasoning""") ] output_parser = StructuredOutputParser.from_response_schemas(response_schemas) @@ -285,8 +307,11 @@ class GenerateCodeNode(BaseNode): "generated_result": json.dumps(generated_result, indent=2), "reference_result": json.dumps(reference_result_dict, indent=2) }) - + def syntax_check(self, code): + """ + syntax checker + """ try: ast.parse(code) return True, "Syntax is correct." @@ -294,36 +319,40 @@ class GenerateCodeNode(BaseNode): return False, f"Syntax error: {str(e)}" def create_sandbox_and_execute(self, function_code): - # Create a sandbox environment + """ + Create a sandbox environment + """ sandbox_globals = { 'BeautifulSoup': BeautifulSoup, 're': re, '__builtins__': __builtins__, } - + old_stdout = sys.stdout sys.stdout = StringIO() - + try: exec(function_code, sandbox_globals) - + extract_data = sandbox_globals.get('extract_data') - + if not extract_data: raise NameError("Function 'extract_data' not found in the generated code.") - - result = extract_data(self.raw_html) - + + result = extract_data(self.raw_html) return True, result except Exception as e: return False, f"Error during execution: {str(e)}" finally: sys.stdout = old_stdout - + def validate_dict(self, data: dict, schema): + """ + validate_dict method + """ try: validate(instance=data, schema=schema) return True, None except ValidationError as e: errors = e.errors() - return False, errors \ No newline at end of file + return False, errors diff --git a/scrapegraphai/nodes/prompt_refiner_node.py b/scrapegraphai/nodes/prompt_refiner_node.py index dfb62eb6..66c960ff 100644 --- a/scrapegraphai/nodes/prompt_refiner_node.py +++ b/scrapegraphai/nodes/prompt_refiner_node.py @@ -4,12 +4,7 @@ PromptRefinerNode Module from typing import List, Optional from langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass -from langchain_openai import ChatOpenAI, AzureChatOpenAI -from langchain_mistralai import ChatMistralAI from langchain_community.chat_models import ChatOllama -from tqdm import tqdm from .base_node import BaseNode from ..utils import transform_schema from ..prompts import ( @@ -61,7 +56,7 @@ class PromptRefinerNode(BaseNode): ) self.additional_info = node_config.get("additional_info") - + self.output_schema = node_config.get("schema") def execute(self, state: dict) -> dict: @@ -85,7 +80,7 @@ class PromptRefinerNode(BaseNode): user_prompt = state['user_prompt'] self.simplefied_schema = transform_schema(self.output_schema.schema()) - + if self.additional_info is not None: prompt = PromptTemplate( template=TEMPLATE_REFINER_WITH_CONTEXT,