Added logs

This commit is contained in:
Matteo Vedovati 2024-09-21 18:46:16 +02:00
parent f38c5e1d8f
commit 2ff0f0113f
4 changed files with 50 additions and 35 deletions

View File

@ -17,8 +17,10 @@ from ..nodes import (
class CodeGeneratorGraph(AbstractGraph):
"""
...
CodeGeneratorGraph is a script generator pipeline that generates the function extract_data(html: str) -> dict() for
extracting the desired information from an HTML page. The generated code is in Python and uses the BeautifulSoup library.
It requires a user prompt, a source URL, and an output schema.
Attributes:
prompt (str): The prompt for the graph.
source (str): The source of the graph.

View File

@ -26,7 +26,7 @@ import string
class GenerateCodeNode(BaseNode):
"""
...
A node that generates Python code for a function that extracts data from HTML based on an output schema.
Attributes:
llm_model: An instance of a language model client, configured for generating answers.
@ -80,7 +80,7 @@ class GenerateCodeNode(BaseNode):
def execute(self, state: dict) -> dict:
"""
...
Generates Python code for a function that extracts data from HTML based on an output schema.
Args:
state (dict): The current state of the graph. The input keys will be used
@ -92,6 +92,7 @@ class GenerateCodeNode(BaseNode):
Raises:
KeyError: If the input keys are not found in the state, indicating
that the necessary information for generating an answer is missing.
RuntimeError: If the maximum number of iterations is reached without obtaining the desired code.
"""
self.logger.info(f"--- Executing {self.node_name} Node ---")
@ -135,25 +136,31 @@ class GenerateCodeNode(BaseNode):
return state
def overall_reasoning_loop(self, state: dict) -> dict:
self.logger.info(f"--- (Generating Code) ---")
state["generated_code"] = self.generate_initial_code(state)
state["generated_code"] = self.extract_code(state["generated_code"])
while state["iteration"] < self.max_iterations["overall"]:
state["iteration"] += 1
if self.verbose:
self.logger.info(f"--- Iteration {state['iteration']} ---")
self.logger.info(f"--- (Checking Code Syntax) ---")
state = self.syntax_reasoning_loop(state)
if state["errors"]["syntax"]:
continue
self.logger.info(f"--- (Executing the Generated Code) ---")
state = self.execution_reasoning_loop(state)
if state["errors"]["execution"]:
continue
self.logger.info(f"--- (Validate the Code Output Schema) ---")
state = self.validation_reasoning_loop(state)
if state["errors"]["validation"]:
continue
self.logger.info(f"--- (Checking if the informations exctrcated are the ones Requested) ---")
state = self.semantic_comparison_loop(state)
if state["errors"]["semantic"]:
continue
@ -161,6 +168,11 @@ class GenerateCodeNode(BaseNode):
# If we've made it here, the code is valid and produces the correct output
break
if state["iteration"] == self.max_iterations["overall"] and (state["errors"]["syntax"] or state["errors"]["execution"] or state["errors"]["validation"] or state["errors"]["semantic"]):
raise RuntimeError("Max iterations reached without obtaining the desired code.")
self.logger.info(f"--- (Code Generated Correctly) ---")
return state
def syntax_reasoning_loop(self, state: dict) -> dict:
@ -171,7 +183,9 @@ class GenerateCodeNode(BaseNode):
return state
state["errors"]["syntax"] = [syntax_message]
self.logger.info(f"--- (Synax Error Found: {syntax_message}) ---")
analysis = self.syntax_focused_analysis(state)
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
state["generated_code"] = self.syntax_focused_code_generation(state, analysis)
state["generated_code"] = self.extract_code(state["generated_code"])
return state
@ -185,7 +199,9 @@ class GenerateCodeNode(BaseNode):
return state
state["errors"]["execution"] = [execution_result]
self.logger.info(f"--- (Code Execution Error: {execution_result}) ---")
analysis = self.execution_focused_analysis(state)
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
state["generated_code"] = self.execution_focused_code_generation(state, analysis)
state["generated_code"] = self.extract_code(state["generated_code"])
return state
@ -198,7 +214,9 @@ class GenerateCodeNode(BaseNode):
return state
state["errors"]["validation"] = errors
self.logger.info(f"--- (Code Output not compliant to the deisred Output Schema) ---")
analysis = self.validation_focused_analysis(state)
self.logger.info(f"--- (Regenerating Code to make the Output compliant to the deisred Output Schema) ---")
state["generated_code"] = self.validation_focused_code_generation(state, analysis)
state["generated_code"] = self.extract_code(state["generated_code"])
return state
@ -211,7 +229,9 @@ class GenerateCodeNode(BaseNode):
return state
state["errors"]["semantic"] = comparison_result["differences"]
self.logger.info(f"--- (The informations exctrcated are not the all ones requested) ---")
analysis = self.semantic_focused_analysis(state, comparison_result)
self.logger.info(f"--- (Regenerating Code to obtain all the infromation requested) ---")
state["generated_code"] = self.semantic_focused_code_generation(state, analysis)
state["generated_code"] = self.extract_code(state["generated_code"])
return state

View File

@ -16,8 +16,8 @@ from ..utils import reduce_html
class HtmlAnalyzerNode(BaseNode):
"""
...
A node that generates an analysis of the provided HTML code based on the information to be extracted.
Attributes:
llm_model: An instance of a language model client, configured for generating answers.
verbose (bool): A flag indicating whether to show print statements during execution.
@ -60,7 +60,7 @@ class HtmlAnalyzerNode(BaseNode):
def execute(self, state: dict) -> dict:
"""
...
Generates an analysis of the provided HTML code based on the information to be extracted.
Args:
state (dict): The current state of the graph. The input keys will be used

View File

@ -59,6 +59,8 @@ class PromptRefinerNode(BaseNode):
)
self.additional_info = node_config.get("additional_info")
self.output_schema = node_config.get("schema") # get JSON output schema
def execute(self, state: dict) -> dict:
"""
@ -137,33 +139,24 @@ class PromptRefinerNode(BaseNode):
user_prompt = state['user_prompt'] # get user prompt
if self.node_config.get("schema", None) is not None:
self.simplefied_schema = transform_schema(self.output_schema.schema()) # get JSON schema
if self.additional_info is not None: # use additional context if present
prompt = PromptTemplate(
template=template_prompt_builder_with_context,
partial_variables={"user_input": user_prompt,
"json_schema": str(self.simplefied_schema),
"additional_context": self.additional_info})
else:
prompt = PromptTemplate(
template=template_prompt_builder,
partial_variables={"user_input": user_prompt,
"json_schema": str(self.simplefied_schema)})
self.simplefied_schema = transform_schema(self.node_config["schema"].schema()) # get JSON schema
if self.additional_info is not None: # use additional context if present
prompt = PromptTemplate(
template=template_prompt_builder_with_context,
partial_variables={"user_input": user_prompt,
"json_schema": str(self.simplefied_schema),
"additional_context": self.additional_info})
else:
prompt = PromptTemplate(
template=template_prompt_builder,
partial_variables={"user_input": user_prompt,
"json_schema": str(self.simplefied_schema)})
output_parser = StrOutputParser()
output_parser = StrOutputParser()
chain = prompt | self.llm_model | output_parser
refined_prompt = chain.invoke({})
chain = prompt | self.llm_model | output_parser
refined_prompt = chain.invoke({})
state.update({self.output[0]: refined_prompt})
return state
else: # no schema provided
self.logger.error("No schema provided for prompt refinement.")
# TODO: Handle the case where no schema is provided => error handling
return state
state.update({self.output[0]: refined_prompt})
return state