update for pylint

This commit is contained in:
Marco Vinciguerra 2024-09-27 09:34:56 +02:00
parent c181fea3bd
commit 7c39b06640
2 changed files with 99 additions and 75 deletions

View File

@ -2,20 +2,20 @@
GenerateCodeNode Module
"""
from typing import Any, Dict, List, Optional
from langchain.prompts import PromptTemplate
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_community.chat_models import ChatOllama
import ast
import sys
from io import StringIO
from bs4 import BeautifulSoup
import re
from tqdm import tqdm
from .base_node import BaseNode
import json
from pydantic import ValidationError
from langchain.prompts import PromptTemplate
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_models import ChatOllama
from bs4 import BeautifulSoup
from ..prompts import (
TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON
)
from ..utils import (transform_schema,
extract_code,
syntax_focused_analysis, syntax_focused_code_generation,
@ -23,15 +23,14 @@ from ..utils import (transform_schema,
validation_focused_analysis, validation_focused_code_generation,
semantic_focused_analysis, semantic_focused_code_generation,
are_content_equal)
from .base_node import BaseNode
from jsonschema import validate, ValidationError
import json
from ..prompts import (
TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON
)
class GenerateCodeNode(BaseNode):
"""
A node that generates Python code for a function that extracts data from HTML based on a output schema.
A node that generates Python code for a function that extracts data
from HTML based on a output schema.
Attributes:
llm_model: An instance of a language model client, configured for generating answers.
@ -72,7 +71,7 @@ class GenerateCodeNode(BaseNode):
)
self.additional_info = node_config.get("additional_info")
self.max_iterations = node_config.get("max_iterations", {
"overall": 10,
"syntax": 3,
@ -80,7 +79,7 @@ class GenerateCodeNode(BaseNode):
"validation": 3,
"semantic": 3
})
self.output_schema = node_config.get("schema")
def execute(self, state: dict) -> dict:
@ -97,25 +96,26 @@ class GenerateCodeNode(BaseNode):
Raises:
KeyError: If the input keys are not found in the state, indicating
that the necessary information for generating an answer is missing.
RuntimeError: If the maximum number of iterations is reached without obtaining the desired code.
RuntimeError: If the maximum number of iterations is
reached without obtaining the desired code.
"""
self.logger.info(f"--- Executing {self.node_name} Node ---")
input_keys = self.get_input_keys(state)
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
refined_prompt = input_data[1]
html_info = input_data[2]
reduced_html = input_data[3]
answer = input_data[4]
answer = input_data[4]
self.raw_html = state['original_html'][0].page_content
simplefied_schema = str(transform_schema(self.output_schema.schema()))
reasoning_state = {
"user_input": user_prompt,
"json_schema": simplefied_schema,
@ -133,89 +133,101 @@ class GenerateCodeNode(BaseNode):
},
"iteration": 0
}
final_state = self.overall_reasoning_loop(reasoning_state)
state.update({self.output[0]: final_state["generated_code"]})
return state
def overall_reasoning_loop(self, state: dict) -> dict:
"""
overrall_reasoning_loop
"""
self.logger.info(f"--- (Generating Code) ---")
state["generated_code"] = self.generate_initial_code(state)
state["generated_code"] = extract_code(state["generated_code"])
while state["iteration"] < self.max_iterations["overall"]:
state["iteration"] += 1
if self.verbose:
self.logger.info(f"--- Iteration {state['iteration']} ---")
self.logger.info(f"--- (Checking Code Syntax) ---")
state = self.syntax_reasoning_loop(state)
if state["errors"]["syntax"]:
continue
self.logger.info(f"--- (Executing the Generated Code) ---")
state = self.execution_reasoning_loop(state)
if state["errors"]["execution"]:
continue
self.logger.info(f"--- (Validate the Code Output Schema) ---")
state = self.validation_reasoning_loop(state)
if state["errors"]["validation"]:
continue
self.logger.info(f"--- (Checking if the informations exctrcated are the ones Requested) ---")
state = self.semantic_comparison_loop(state)
if state["errors"]["semantic"]:
continue
continue
break
if state["iteration"] == self.max_iterations["overall"] and (state["errors"]["syntax"] or state["errors"]["execution"] or state["errors"]["validation"] or state["errors"]["semantic"]):
raise RuntimeError("Max iterations reached without obtaining the desired code.")
self.logger.info(f"--- (Code Generated Correctly) ---")
return state
def syntax_reasoning_loop(self, state: dict) -> dict:
"""
syntax reasoning loop
"""
for _ in range(self.max_iterations["syntax"]):
syntax_valid, syntax_message = self.syntax_check(state["generated_code"])
if syntax_valid:
state["errors"]["syntax"] = []
return state
state["errors"]["syntax"] = [syntax_message]
self.logger.info(f"--- (Synax Error Found: {syntax_message}) ---")
analysis = syntax_focused_analysis(state, self.llm_model)
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
state["generated_code"] = syntax_focused_code_generation(state, analysis, self.llm_model)
self.logger.info(f"""--- (Regenerating Code
to fix the Error) ---""")
state["generated_code"] = syntax_focused_code_generation(state,
analysis, self.llm_model)
state["generated_code"] = extract_code(state["generated_code"])
return state
def execution_reasoning_loop(self, state: dict) -> dict:
"""
execution of the reasoning loop
"""
for _ in range(self.max_iterations["execution"]):
execution_success, execution_result = self.create_sandbox_and_execute(state["generated_code"])
if execution_success:
state["execution_result"] = execution_result
state["errors"]["execution"] = []
return state
state["errors"]["execution"] = [execution_result]
self.logger.info(f"--- (Code Execution Error: {execution_result}) ---")
analysis = execution_focused_analysis(state, self.llm_model)
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
state["generated_code"] = execution_focused_code_generation(state, analysis, self.llm_model)
state["generated_code"] = execution_focused_code_generation(state,
analysis, self.llm_model)
state["generated_code"] = extract_code(state["generated_code"])
return state
def validation_reasoning_loop(self, state: dict) -> dict:
for _ in range(self.max_iterations["validation"]):
validation, errors = self.validate_dict(state["execution_result"], self.output_schema.schema())
validation, errors = self.validate_dict(state["execution_result"],
self.output_schema.schema())
if validation:
state["errors"]["validation"] = []
return state
state["errors"]["validation"] = errors
self.logger.info(f"--- (Code Output not compliant to the deisred Output Schema) ---")
analysis = validation_focused_analysis(state, self.llm_model)
@ -223,14 +235,14 @@ class GenerateCodeNode(BaseNode):
state["generated_code"] = validation_focused_code_generation(state, analysis, self.llm_model)
state["generated_code"] = extract_code(state["generated_code"])
return state
def semantic_comparison_loop(self, state: dict) -> dict:
for _ in range(self.max_iterations["semantic"]):
comparison_result = self.semantic_comparison(state["execution_result"], state["reference_answer"])
if comparison_result["are_semantically_equivalent"]:
state["errors"]["semantic"] = []
return state
state["errors"]["semantic"] = comparison_result["differences"]
self.logger.info(f"--- (The informations exctrcated are not the all ones requested) ---")
analysis = semantic_focused_analysis(state, comparison_result, self.llm_model)
@ -238,8 +250,11 @@ class GenerateCodeNode(BaseNode):
state["generated_code"] = semantic_focused_code_generation(state, analysis, self.llm_model)
state["generated_code"] = extract_code(state["generated_code"])
return state
def generate_initial_code(self, state: dict) -> str:
"""
function for generating the initial code
"""
prompt = PromptTemplate(
template=TEMPLATE_INIT_CODE_GENERATION,
partial_variables={
@ -255,22 +270,29 @@ class GenerateCodeNode(BaseNode):
chain = prompt | self.llm_model | output_parser
generated_code = chain.invoke({})
return generated_code
def semantic_comparison(self, generated_result: Any, reference_result: Any) -> Dict[str, Any]:
"""
semtantic comparison formula
"""
reference_result_dict = self.output_schema(**reference_result).dict()
# Check if generated result and reference result are actually equal
if are_content_equal(generated_result, reference_result_dict):
return {
"are_semantically_equivalent": True,
"differences": [],
"explanation": "The generated result and reference result are exactly equal."
}
response_schemas = [
ResponseSchema(name="are_semantically_equivalent", description="Boolean indicating if the results are semantically equivalent"),
ResponseSchema(name="differences", description="List of semantic differences between the results, if any"),
ResponseSchema(name="explanation", description="Detailed explanation of the comparison and reasoning")
ResponseSchema(name="are_semantically_equivalent",
description="""Boolean indicating if the
results are semantically equivalent"""),
ResponseSchema(name="differences",
description="""List of semantic differences
between the results, if any"""),
ResponseSchema(name="explanation",
description="""Detailed explanation of the
comparison and reasoning""")
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
@ -285,8 +307,11 @@ class GenerateCodeNode(BaseNode):
"generated_result": json.dumps(generated_result, indent=2),
"reference_result": json.dumps(reference_result_dict, indent=2)
})
def syntax_check(self, code):
"""
syntax checker
"""
try:
ast.parse(code)
return True, "Syntax is correct."
@ -294,36 +319,40 @@ class GenerateCodeNode(BaseNode):
return False, f"Syntax error: {str(e)}"
def create_sandbox_and_execute(self, function_code):
# Create a sandbox environment
"""
Create a sandbox environment
"""
sandbox_globals = {
'BeautifulSoup': BeautifulSoup,
're': re,
'__builtins__': __builtins__,
}
old_stdout = sys.stdout
sys.stdout = StringIO()
try:
exec(function_code, sandbox_globals)
extract_data = sandbox_globals.get('extract_data')
if not extract_data:
raise NameError("Function 'extract_data' not found in the generated code.")
result = extract_data(self.raw_html)
result = extract_data(self.raw_html)
return True, result
except Exception as e:
return False, f"Error during execution: {str(e)}"
finally:
sys.stdout = old_stdout
def validate_dict(self, data: dict, schema):
"""
validate_dict method
"""
try:
validate(instance=data, schema=schema)
return True, None
except ValidationError as e:
errors = e.errors()
return False, errors
return False, errors

View File

@ -4,12 +4,7 @@ PromptRefinerNode Module
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_mistralai import ChatMistralAI
from langchain_community.chat_models import ChatOllama
from tqdm import tqdm
from .base_node import BaseNode
from ..utils import transform_schema
from ..prompts import (
@ -61,7 +56,7 @@ class PromptRefinerNode(BaseNode):
)
self.additional_info = node_config.get("additional_info")
self.output_schema = node_config.get("schema")
def execute(self, state: dict) -> dict:
@ -85,7 +80,7 @@ class PromptRefinerNode(BaseNode):
user_prompt = state['user_prompt']
self.simplefied_schema = transform_schema(self.output_schema.schema())
if self.additional_info is not None:
prompt = PromptTemplate(
template=TEMPLATE_REFINER_WITH_CONTEXT,