# Mirror of https://github.com/VinciGit00/Scrapegraph-ai.git
# (synced 2026-06-28 21:01:55 +08:00; 362 lines, 14 KiB, Python)
"""
|
|
GenerateCodeNode Module
|
|
"""
|
|
import ast
import json
import re
import sys
from contextlib import redirect_stdout
from io import StringIO
from typing import Any, Dict, List, Optional

from bs4 import BeautifulSoup
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from pydantic import ValidationError
# NOTE: this import must stay AFTER the pydantic one. Its ``ValidationError``
# (jsonschema's) deliberately shadows pydantic's, and it is the exception that
# ``jsonschema.validate`` raises and ``GenerateCodeNode.validate_dict`` catches.
from jsonschema import validate, ValidationError

from ..prompts import TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON
from ..utils import (
    are_content_equal,
    execution_focused_analysis,
    execution_focused_code_generation,
    extract_code,
    semantic_focused_analysis,
    semantic_focused_code_generation,
    syntax_focused_analysis,
    syntax_focused_code_generation,
    transform_schema,
    validation_focused_analysis,
    validation_focused_code_generation,
)
from .base_node import BaseNode


class GenerateCodeNode(BaseNode):
    """
    A node that generates Python code for a function that extracts data
    from HTML based on an output schema.

    The code is refined iteratively. Each overall iteration does the following:
    it checks the syntax, executes the code's ``extract_data`` function
    against the page HTML, validates the result against the output schema,
    and finally compares it with a reference answer. When a stage fails, the
    LLM is asked to regenerate the code with that failure in mind.

    Attributes:
        llm_model: An instance of a language model client, configured for generating answers.
        verbose (bool): A flag indicating whether to log per-iteration progress.
        max_iterations (dict): Retry budgets keyed by "overall", "syntax",
            "execution", "validation" and "semantic".
        output_schema: The schema read from ``node_config["schema"]``.
            Presumably a pydantic model class: it is used via ``.schema()``
            and instantiated with the reference answer.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Additional configuration for the node. Must contain
            "llm_model"; "schema" is required at execution time.
        node_name (str): The unique identifier name for the node, defaulting to "GenerateCode".
    """

    # Retry budgets used for any stage that node_config["max_iterations"] omits.
    _DEFAULT_MAX_ITERATIONS = {
        "overall": 10,
        "syntax": 3,
        "execution": 3,
        "validation": 3,
        "semantic": 3,
    }

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "GenerateCode",
    ):
        super().__init__(node_name, "node", input, output, 2, node_config)

        config = node_config or {}

        # The LLM client is mandatory; a missing entry fails fast with KeyError.
        self.llm_model = config["llm_model"]

        if isinstance(self.llm_model, ChatOllama):
            self.llm_model.format = "json"

        self.verbose = config.get("verbose", False)
        self.force = config.get("force", False)
        self.script_creator = config.get("script_creator", False)
        self.is_md_scraper = config.get("is_md_scraper", False)

        self.additional_info = config.get("additional_info")

        # Merge user-supplied budgets over the defaults so that a partial dict
        # such as {"overall": 5} does not raise KeyError in a later loop.
        self.max_iterations = {
            **self._DEFAULT_MAX_ITERATIONS,
            **(config.get("max_iterations") or {}),
        }

        self.output_schema = config.get("schema")

    def execute(self, state: dict) -> dict:
        """
        Generates Python code for a function that extracts data from HTML based on an output schema.

        Args:
            state (dict): The current state of the graph. The input keys will be used
                to fetch the correct data from the state. The state must also hold
                "original_html", whose first element's ``page_content`` is the HTML
                the generated code is executed against.

        Returns:
            dict: The updated state, with the first output key set to the generated code.

        Raises:
            KeyError: If the input keys are not found in the state, indicating
                that the necessary information for generating an answer is missing.
            ValueError: If no output schema was configured for the node.
            RuntimeError: If the maximum number of iterations is
                reached without obtaining the desired code.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        input_keys = self.get_input_keys(state)
        input_data = [state[key] for key in input_keys]

        # Positional contract with the graph's input expression.
        user_prompt = input_data[0]
        refined_prompt = input_data[1]
        html_info = input_data[2]
        reduced_html = input_data[3]
        answer = input_data[4]

        # create_sandbox_and_execute runs the generated extract_data on this HTML.
        self.raw_html = state["original_html"][0].page_content

        if self.output_schema is None:
            raise ValueError("GenerateCodeNode requires a 'schema' entry in node_config.")

        simplified_schema = str(transform_schema(self.output_schema.schema()))

        reasoning_state = {
            "user_input": user_prompt,
            "json_schema": simplified_schema,
            "initial_analysis": refined_prompt,
            "html_code": reduced_html,
            "html_analysis": html_info,
            "generated_code": "",
            "execution_result": None,
            "reference_answer": answer,
            "errors": {
                "syntax": [],
                "execution": [],
                "validation": [],
                "semantic": [],
            },
            "iteration": 0,
        }

        final_state = self.overall_reasoning_loop(reasoning_state)

        state.update({self.output[0]: final_state["generated_code"]})
        return state

    def overall_reasoning_loop(self, state: dict) -> dict:
        """
        Generate the initial code, then refine it until every check passes.

        Each outer iteration runs the syntax, execution, schema-validation and
        semantic-comparison loops in order. A stage that still reports errors
        restarts the outer iteration, which begins again with the syntax check.

        Args:
            state (dict): The reasoning state built by ``execute``.

        Returns:
            dict: The reasoning state holding the accepted ``generated_code``.

        Raises:
            RuntimeError: If the overall budget is exhausted while some stage
                still reports errors.
        """
        self.logger.info("--- (Generating Code) ---")
        state["generated_code"] = self.generate_initial_code(state)
        state["generated_code"] = extract_code(state["generated_code"])

        while state["iteration"] < self.max_iterations["overall"]:
            state["iteration"] += 1
            if self.verbose:
                self.logger.info(f"--- Iteration {state['iteration']} ---")

            self.logger.info("--- (Checking Code Syntax) ---")
            state = self.syntax_reasoning_loop(state)
            if state["errors"]["syntax"]:
                continue

            self.logger.info("--- (Executing the Generated Code) ---")
            state = self.execution_reasoning_loop(state)
            if state["errors"]["execution"]:
                continue

            self.logger.info("--- (Validate the Code Output Schema) ---")
            state = self.validation_reasoning_loop(state)
            if state["errors"]["validation"]:
                continue

            self.logger.info("--- (Checking if the informations exctrcated are the ones Requested) ---")
            state = self.semantic_comparison_loop(state)
            if state["errors"]["semantic"]:
                continue

            break

        # Every stage clears its own errors on success, so leaving the loop via
        # ``break`` means no errors remain. Leftover errors therefore mean the
        # overall budget ran out.
        if any(state["errors"].values()):
            raise RuntimeError("Max iterations reached without obtaining the desired code.")

        self.logger.info("--- (Code Generated Correctly) ---")

        return state

    def syntax_reasoning_loop(self, state: dict) -> dict:
        """
        Check the code's syntax and ask the LLM to fix it, up to the "syntax" budget.

        On success, ``state["errors"]["syntax"]`` is cleared. Otherwise it holds
        the last error message, and ``generated_code`` holds the latest
        (unchecked) regeneration.
        """
        for _ in range(self.max_iterations["syntax"]):
            syntax_valid, syntax_message = self.syntax_check(state["generated_code"])
            if syntax_valid:
                state["errors"]["syntax"] = []
                return state

            state["errors"]["syntax"] = [syntax_message]
            self.logger.info(f"--- (Synax Error Found: {syntax_message}) ---")
            analysis = syntax_focused_analysis(state, self.llm_model)
            self.logger.info("--- (Regenerating Code to fix the Error) ---")
            state["generated_code"] = syntax_focused_code_generation(state, analysis, self.llm_model)
            state["generated_code"] = extract_code(state["generated_code"])
        return state

    def execution_reasoning_loop(self, state: dict) -> dict:
        """
        Execute the code and ask the LLM to fix runtime errors, up to the "execution" budget.

        On success, the output of ``extract_data`` is stored in
        ``state["execution_result"]`` and the execution errors are cleared.
        """
        for _ in range(self.max_iterations["execution"]):
            execution_success, execution_result = self.create_sandbox_and_execute(state["generated_code"])
            if execution_success:
                state["execution_result"] = execution_result
                state["errors"]["execution"] = []
                return state

            state["errors"]["execution"] = [execution_result]
            self.logger.info(f"--- (Code Execution Error: {execution_result}) ---")
            analysis = execution_focused_analysis(state, self.llm_model)
            self.logger.info("--- (Regenerating Code to fix the Error) ---")
            state["generated_code"] = execution_focused_code_generation(state, analysis, self.llm_model)
            state["generated_code"] = extract_code(state["generated_code"])
        return state

    def validation_reasoning_loop(self, state: dict) -> dict:
        """
        Validate the execution result against the output schema, up to the "validation" budget.

        On failure, ``state["errors"]["validation"]`` holds the list of messages
        returned by ``validate_dict``, and the code is regenerated.
        """
        # The schema does not change between attempts; build it once.
        schema = self.output_schema.schema()
        for _ in range(self.max_iterations["validation"]):
            validation, errors = self.validate_dict(state["execution_result"], schema)
            if validation:
                state["errors"]["validation"] = []
                return state

            state["errors"]["validation"] = errors
            self.logger.info("--- (Code Output not compliant to the deisred Output Schema) ---")
            analysis = validation_focused_analysis(state, self.llm_model)
            self.logger.info("--- (Regenerating Code to make the Output compliant to the deisred Output Schema) ---")
            state["generated_code"] = validation_focused_code_generation(state, analysis, self.llm_model)
            state["generated_code"] = extract_code(state["generated_code"])
        return state

    def semantic_comparison_loop(self, state: dict) -> dict:
        """
        Compare the execution result with the reference answer, up to the "semantic" budget.

        On failure, ``state["errors"]["semantic"]`` holds the (non-empty) list
        of differences reported by ``semantic_comparison``, and the code is
        regenerated.
        """
        for _ in range(self.max_iterations["semantic"]):
            comparison_result = self.semantic_comparison(state["execution_result"],
                                                         state["reference_answer"])
            if comparison_result["are_semantically_equivalent"]:
                state["errors"]["semantic"] = []
                return state

            state["errors"]["semantic"] = comparison_result["differences"]
            self.logger.info("--- (The informations exctrcated are not the all ones requested) ---")
            analysis = semantic_focused_analysis(state, comparison_result, self.llm_model)
            self.logger.info("--- (Regenerating Code to obtain all the infromation requested) ---")
            state["generated_code"] = semantic_focused_code_generation(state, analysis, self.llm_model)
            state["generated_code"] = extract_code(state["generated_code"])
        return state

    def generate_initial_code(self, state: dict) -> str:
        """
        Ask the LLM for a first version of the extraction code.

        Fills TEMPLATE_INIT_CODE_GENERATION from the reasoning state and returns
        the raw model output. The caller extracts the code block from it.
        """
        prompt = PromptTemplate(
            template=TEMPLATE_INIT_CODE_GENERATION,
            partial_variables={
                "user_input": state["user_input"],
                "json_schema": state["json_schema"],
                "initial_analysis": state["initial_analysis"],
                "html_code": state["html_code"],
                "html_analysis": state["html_analysis"],
            })

        output_parser = StrOutputParser()

        chain = prompt | self.llm_model | output_parser
        generated_code = chain.invoke({})
        return generated_code

    def semantic_comparison(self, generated_result: Any, reference_result: Any) -> Dict[str, Any]:
        """
        Decide whether the generated result matches the reference answer.

        The reference answer is first normalized through the output schema.
        Exact equality (via ``are_content_equal``) short-circuits. Otherwise the
        LLM judges the two results semantically.

        Returns:
            dict: ``are_semantically_equivalent`` (bool), ``differences`` (list;
            non-empty whenever the results are not equivalent) and ``explanation``.
        """
        reference_result_dict = self.output_schema(**reference_result).dict()

        if are_content_equal(generated_result, reference_result_dict):
            return {
                "are_semantically_equivalent": True,
                "differences": [],
                "explanation": "The generated result and reference result are exactly equal."
            }

        response_schemas = [
            ResponseSchema(name="are_semantically_equivalent",
                           description="""Boolean indicating if the
                           results are semantically equivalent"""),
            ResponseSchema(name="differences",
                           description="""List of semantic differences
                           between the results, if any"""),
            ResponseSchema(name="explanation",
                           description="""Detailed explanation of the
                           comparison and reasoning""")
        ]

        output_parser = StructuredOutputParser.from_response_schemas(response_schemas)

        prompt = PromptTemplate(
            template=TEMPLATE_SEMANTIC_COMPARISON,
            input_variables=["generated_result", "reference_result"],
            partial_variables={"format_instructions": output_parser.get_format_instructions()}
        )

        chain = prompt | self.llm_model | output_parser
        llm_result = chain.invoke({
            "generated_result": json.dumps(generated_result, indent=2),
            "reference_result": json.dumps(reference_result_dict, indent=2)
        })
        return self._normalize_comparison(llm_result)

    @staticmethod
    def _normalize_comparison(result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Coerce the LLM's structured comparison into reliable types.

        The parsed JSON is not type-checked. As a result, the equivalence flag
        may arrive as a string such as "false", which is truthy. The
        differences may also be a string or empty even when the verdict is
        negative, which would make the reasoning loop accept wrong code.
        This returns a copy in which the flag is a real ``bool`` and
        ``differences`` is a list that is non-empty whenever the flag is
        ``False``.
        """
        normalized = dict(result)

        flag = normalized.get("are_semantically_equivalent", False)
        if isinstance(flag, str):
            flag = flag.strip().lower() in {"true", "yes", "1"}
        normalized["are_semantically_equivalent"] = bool(flag)

        differences = normalized.get("differences") or []
        if isinstance(differences, str):
            differences = [differences]
        else:
            differences = list(differences)

        if not normalized["are_semantically_equivalent"] and not differences:
            differences = [normalized.get("explanation")
                           or "The generated result differs from the reference result."]
        normalized["differences"] = differences
        return normalized

    def syntax_check(self, code):
        """
        Check whether ``code`` parses as Python.

        Returns:
            tuple: ``(True, "Syntax is correct.")`` or ``(False, "Syntax error: ...")``.
        """
        try:
            ast.parse(code)
        except (SyntaxError, ValueError) as e:
            # ast.parse raises ValueError (not SyntaxError) for source containing
            # null bytes on Python < 3.12.
            return False, f"Syntax error: {str(e)}"
        return True, "Syntax is correct."

    def create_sandbox_and_execute(self, function_code):
        """
        Execute generated code and call its ``extract_data`` on ``self.raw_html``.

        Despite the name, this is NOT a security sandbox. The code runs
        in-process via ``exec`` with full builtins, so it can do anything this
        process can. Only ``BeautifulSoup`` and ``re`` are pre-bound as globals.
        Anything the code prints to stdout is captured and discarded.

        Returns:
            tuple: ``(True, result)`` with the value returned by ``extract_data``,
            or ``(False, "Error during execution: ...")`` if executing or calling
            it failed.
        """
        sandbox_globals = {
            'BeautifulSoup': BeautifulSoup,
            're': re,
            '__builtins__': __builtins__,
        }

        try:
            # redirect_stdout restores sys.stdout on exit, even when an error is raised.
            with redirect_stdout(StringIO()):
                # SECURITY: exec of LLM-generated (untrusted) code; run this node
                # only in an environment where that is acceptable.
                exec(function_code, sandbox_globals)

                extract_data = sandbox_globals.get('extract_data')
                if not callable(extract_data):
                    raise NameError("Function 'extract_data' not found in the generated code.")

                result = extract_data(self.raw_html)
        except (Exception, SystemExit) as e:
            # SystemExit is included so a stray exit()/sys.exit() in the generated
            # code is reported as an execution error instead of stopping the host.
            return False, f"Error during execution: {str(e)}"
        return True, result

    def validate_dict(self, data: dict, schema):
        """
        Validate ``data`` against a JSON schema.

        Returns:
            tuple: ``(True, None)`` when valid. Otherwise ``(False, [message])``,
            where ``message`` is jsonschema's description of the reported error.
        """
        try:
            validate(instance=data, schema=schema)
        except ValidationError as e:
            # At module level this name is jsonschema's ValidationError (its import
            # shadows pydantic's). It has no ``errors()`` method, which is pydantic
            # API; the human-readable reason is in ``message``.
            return False, [e.message]
        return True, None
|