mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-28 21:01:55 +08:00
update for pylint
This commit is contained in:
parent
c181fea3bd
commit
7c39b06640
@ -2,20 +2,20 @@
|
||||
GenerateCodeNode Module
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
import ast
|
||||
import sys
|
||||
from io import StringIO
|
||||
from bs4 import BeautifulSoup
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
from .base_node import BaseNode
|
||||
import json
|
||||
from pydantic import ValidationError
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from bs4 import BeautifulSoup
|
||||
from ..prompts import (
|
||||
TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON
|
||||
)
|
||||
from ..utils import (transform_schema,
|
||||
extract_code,
|
||||
syntax_focused_analysis, syntax_focused_code_generation,
|
||||
@ -23,15 +23,14 @@ from ..utils import (transform_schema,
|
||||
validation_focused_analysis, validation_focused_code_generation,
|
||||
semantic_focused_analysis, semantic_focused_code_generation,
|
||||
are_content_equal)
|
||||
from .base_node import BaseNode
|
||||
from jsonschema import validate, ValidationError
|
||||
import json
|
||||
from ..prompts import (
|
||||
TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON
|
||||
)
|
||||
|
||||
|
||||
class GenerateCodeNode(BaseNode):
|
||||
"""
|
||||
A node that generates Python code for a function that extracts data from HTML based on a output schema.
|
||||
A node that generates Python code for a function that extracts data
|
||||
from HTML based on a output schema.
|
||||
|
||||
Attributes:
|
||||
llm_model: An instance of a language model client, configured for generating answers.
|
||||
@ -72,7 +71,7 @@ class GenerateCodeNode(BaseNode):
|
||||
)
|
||||
|
||||
self.additional_info = node_config.get("additional_info")
|
||||
|
||||
|
||||
self.max_iterations = node_config.get("max_iterations", {
|
||||
"overall": 10,
|
||||
"syntax": 3,
|
||||
@ -80,7 +79,7 @@ class GenerateCodeNode(BaseNode):
|
||||
"validation": 3,
|
||||
"semantic": 3
|
||||
})
|
||||
|
||||
|
||||
self.output_schema = node_config.get("schema")
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
@ -97,25 +96,26 @@ class GenerateCodeNode(BaseNode):
|
||||
Raises:
|
||||
KeyError: If the input keys are not found in the state, indicating
|
||||
that the necessary information for generating an answer is missing.
|
||||
RuntimeError: If the maximum number of iterations is reached without obtaining the desired code.
|
||||
RuntimeError: If the maximum number of iterations is
|
||||
reached without obtaining the desired code.
|
||||
"""
|
||||
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
|
||||
user_prompt = input_data[0]
|
||||
refined_prompt = input_data[1]
|
||||
html_info = input_data[2]
|
||||
reduced_html = input_data[3]
|
||||
answer = input_data[4]
|
||||
|
||||
answer = input_data[4]
|
||||
|
||||
self.raw_html = state['original_html'][0].page_content
|
||||
|
||||
|
||||
simplefied_schema = str(transform_schema(self.output_schema.schema()))
|
||||
|
||||
|
||||
reasoning_state = {
|
||||
"user_input": user_prompt,
|
||||
"json_schema": simplefied_schema,
|
||||
@ -133,89 +133,101 @@ class GenerateCodeNode(BaseNode):
|
||||
},
|
||||
"iteration": 0
|
||||
}
|
||||
|
||||
|
||||
|
||||
final_state = self.overall_reasoning_loop(reasoning_state)
|
||||
|
||||
|
||||
state.update({self.output[0]: final_state["generated_code"]})
|
||||
return state
|
||||
|
||||
|
||||
def overall_reasoning_loop(self, state: dict) -> dict:
|
||||
"""
|
||||
overrall_reasoning_loop
|
||||
"""
|
||||
self.logger.info(f"--- (Generating Code) ---")
|
||||
state["generated_code"] = self.generate_initial_code(state)
|
||||
state["generated_code"] = extract_code(state["generated_code"])
|
||||
|
||||
|
||||
while state["iteration"] < self.max_iterations["overall"]:
|
||||
state["iteration"] += 1
|
||||
if self.verbose:
|
||||
self.logger.info(f"--- Iteration {state['iteration']} ---")
|
||||
|
||||
|
||||
self.logger.info(f"--- (Checking Code Syntax) ---")
|
||||
state = self.syntax_reasoning_loop(state)
|
||||
if state["errors"]["syntax"]:
|
||||
continue
|
||||
|
||||
|
||||
self.logger.info(f"--- (Executing the Generated Code) ---")
|
||||
state = self.execution_reasoning_loop(state)
|
||||
if state["errors"]["execution"]:
|
||||
continue
|
||||
|
||||
|
||||
self.logger.info(f"--- (Validate the Code Output Schema) ---")
|
||||
state = self.validation_reasoning_loop(state)
|
||||
if state["errors"]["validation"]:
|
||||
continue
|
||||
|
||||
|
||||
self.logger.info(f"--- (Checking if the informations exctrcated are the ones Requested) ---")
|
||||
state = self.semantic_comparison_loop(state)
|
||||
if state["errors"]["semantic"]:
|
||||
continue
|
||||
continue
|
||||
break
|
||||
|
||||
|
||||
if state["iteration"] == self.max_iterations["overall"] and (state["errors"]["syntax"] or state["errors"]["execution"] or state["errors"]["validation"] or state["errors"]["semantic"]):
|
||||
raise RuntimeError("Max iterations reached without obtaining the desired code.")
|
||||
|
||||
|
||||
self.logger.info(f"--- (Code Generated Correctly) ---")
|
||||
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def syntax_reasoning_loop(self, state: dict) -> dict:
|
||||
"""
|
||||
syntax reasoning loop
|
||||
"""
|
||||
for _ in range(self.max_iterations["syntax"]):
|
||||
syntax_valid, syntax_message = self.syntax_check(state["generated_code"])
|
||||
if syntax_valid:
|
||||
state["errors"]["syntax"] = []
|
||||
return state
|
||||
|
||||
|
||||
state["errors"]["syntax"] = [syntax_message]
|
||||
self.logger.info(f"--- (Synax Error Found: {syntax_message}) ---")
|
||||
analysis = syntax_focused_analysis(state, self.llm_model)
|
||||
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
|
||||
state["generated_code"] = syntax_focused_code_generation(state, analysis, self.llm_model)
|
||||
self.logger.info(f"""--- (Regenerating Code
|
||||
to fix the Error) ---""")
|
||||
state["generated_code"] = syntax_focused_code_generation(state,
|
||||
analysis, self.llm_model)
|
||||
state["generated_code"] = extract_code(state["generated_code"])
|
||||
return state
|
||||
|
||||
|
||||
def execution_reasoning_loop(self, state: dict) -> dict:
|
||||
"""
|
||||
execution of the reasoning loop
|
||||
"""
|
||||
for _ in range(self.max_iterations["execution"]):
|
||||
execution_success, execution_result = self.create_sandbox_and_execute(state["generated_code"])
|
||||
if execution_success:
|
||||
state["execution_result"] = execution_result
|
||||
state["errors"]["execution"] = []
|
||||
return state
|
||||
|
||||
|
||||
state["errors"]["execution"] = [execution_result]
|
||||
self.logger.info(f"--- (Code Execution Error: {execution_result}) ---")
|
||||
analysis = execution_focused_analysis(state, self.llm_model)
|
||||
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
|
||||
state["generated_code"] = execution_focused_code_generation(state, analysis, self.llm_model)
|
||||
state["generated_code"] = execution_focused_code_generation(state,
|
||||
analysis, self.llm_model)
|
||||
state["generated_code"] = extract_code(state["generated_code"])
|
||||
return state
|
||||
|
||||
|
||||
def validation_reasoning_loop(self, state: dict) -> dict:
|
||||
for _ in range(self.max_iterations["validation"]):
|
||||
validation, errors = self.validate_dict(state["execution_result"], self.output_schema.schema())
|
||||
validation, errors = self.validate_dict(state["execution_result"],
|
||||
self.output_schema.schema())
|
||||
if validation:
|
||||
state["errors"]["validation"] = []
|
||||
return state
|
||||
|
||||
|
||||
state["errors"]["validation"] = errors
|
||||
self.logger.info(f"--- (Code Output not compliant to the deisred Output Schema) ---")
|
||||
analysis = validation_focused_analysis(state, self.llm_model)
|
||||
@ -223,14 +235,14 @@ class GenerateCodeNode(BaseNode):
|
||||
state["generated_code"] = validation_focused_code_generation(state, analysis, self.llm_model)
|
||||
state["generated_code"] = extract_code(state["generated_code"])
|
||||
return state
|
||||
|
||||
|
||||
def semantic_comparison_loop(self, state: dict) -> dict:
|
||||
for _ in range(self.max_iterations["semantic"]):
|
||||
comparison_result = self.semantic_comparison(state["execution_result"], state["reference_answer"])
|
||||
if comparison_result["are_semantically_equivalent"]:
|
||||
state["errors"]["semantic"] = []
|
||||
return state
|
||||
|
||||
|
||||
state["errors"]["semantic"] = comparison_result["differences"]
|
||||
self.logger.info(f"--- (The informations exctrcated are not the all ones requested) ---")
|
||||
analysis = semantic_focused_analysis(state, comparison_result, self.llm_model)
|
||||
@ -238,8 +250,11 @@ class GenerateCodeNode(BaseNode):
|
||||
state["generated_code"] = semantic_focused_code_generation(state, analysis, self.llm_model)
|
||||
state["generated_code"] = extract_code(state["generated_code"])
|
||||
return state
|
||||
|
||||
|
||||
def generate_initial_code(self, state: dict) -> str:
|
||||
"""
|
||||
function for generating the initial code
|
||||
"""
|
||||
prompt = PromptTemplate(
|
||||
template=TEMPLATE_INIT_CODE_GENERATION,
|
||||
partial_variables={
|
||||
@ -255,22 +270,29 @@ class GenerateCodeNode(BaseNode):
|
||||
chain = prompt | self.llm_model | output_parser
|
||||
generated_code = chain.invoke({})
|
||||
return generated_code
|
||||
|
||||
|
||||
def semantic_comparison(self, generated_result: Any, reference_result: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
semtantic comparison formula
|
||||
"""
|
||||
reference_result_dict = self.output_schema(**reference_result).dict()
|
||||
|
||||
# Check if generated result and reference result are actually equal
|
||||
if are_content_equal(generated_result, reference_result_dict):
|
||||
return {
|
||||
"are_semantically_equivalent": True,
|
||||
"differences": [],
|
||||
"explanation": "The generated result and reference result are exactly equal."
|
||||
}
|
||||
|
||||
|
||||
response_schemas = [
|
||||
ResponseSchema(name="are_semantically_equivalent", description="Boolean indicating if the results are semantically equivalent"),
|
||||
ResponseSchema(name="differences", description="List of semantic differences between the results, if any"),
|
||||
ResponseSchema(name="explanation", description="Detailed explanation of the comparison and reasoning")
|
||||
ResponseSchema(name="are_semantically_equivalent",
|
||||
description="""Boolean indicating if the
|
||||
results are semantically equivalent"""),
|
||||
ResponseSchema(name="differences",
|
||||
description="""List of semantic differences
|
||||
between the results, if any"""),
|
||||
ResponseSchema(name="explanation",
|
||||
description="""Detailed explanation of the
|
||||
comparison and reasoning""")
|
||||
]
|
||||
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
|
||||
|
||||
@ -285,8 +307,11 @@ class GenerateCodeNode(BaseNode):
|
||||
"generated_result": json.dumps(generated_result, indent=2),
|
||||
"reference_result": json.dumps(reference_result_dict, indent=2)
|
||||
})
|
||||
|
||||
|
||||
def syntax_check(self, code):
|
||||
"""
|
||||
syntax checker
|
||||
"""
|
||||
try:
|
||||
ast.parse(code)
|
||||
return True, "Syntax is correct."
|
||||
@ -294,36 +319,40 @@ class GenerateCodeNode(BaseNode):
|
||||
return False, f"Syntax error: {str(e)}"
|
||||
|
||||
def create_sandbox_and_execute(self, function_code):
|
||||
# Create a sandbox environment
|
||||
"""
|
||||
Create a sandbox environment
|
||||
"""
|
||||
sandbox_globals = {
|
||||
'BeautifulSoup': BeautifulSoup,
|
||||
're': re,
|
||||
'__builtins__': __builtins__,
|
||||
}
|
||||
|
||||
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = StringIO()
|
||||
|
||||
|
||||
try:
|
||||
exec(function_code, sandbox_globals)
|
||||
|
||||
|
||||
extract_data = sandbox_globals.get('extract_data')
|
||||
|
||||
|
||||
if not extract_data:
|
||||
raise NameError("Function 'extract_data' not found in the generated code.")
|
||||
|
||||
result = extract_data(self.raw_html)
|
||||
|
||||
|
||||
result = extract_data(self.raw_html)
|
||||
return True, result
|
||||
except Exception as e:
|
||||
return False, f"Error during execution: {str(e)}"
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
|
||||
|
||||
def validate_dict(self, data: dict, schema):
|
||||
"""
|
||||
validate_dict method
|
||||
"""
|
||||
try:
|
||||
validate(instance=data, schema=schema)
|
||||
return True, None
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
return False, errors
|
||||
return False, errors
|
||||
|
||||
@ -4,12 +4,7 @@ PromptRefinerNode Module
|
||||
from typing import List, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from tqdm import tqdm
|
||||
from .base_node import BaseNode
|
||||
from ..utils import transform_schema
|
||||
from ..prompts import (
|
||||
@ -61,7 +56,7 @@ class PromptRefinerNode(BaseNode):
|
||||
)
|
||||
|
||||
self.additional_info = node_config.get("additional_info")
|
||||
|
||||
|
||||
self.output_schema = node_config.get("schema")
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
@ -85,7 +80,7 @@ class PromptRefinerNode(BaseNode):
|
||||
user_prompt = state['user_prompt']
|
||||
|
||||
self.simplefied_schema = transform_schema(self.output_schema.schema())
|
||||
|
||||
|
||||
if self.additional_info is not None:
|
||||
prompt = PromptTemplate(
|
||||
template=TEMPLATE_REFINER_WITH_CONTEXT,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user