Update prompt_refiner_node.py

This commit is contained in:
Matteo Vedovati 2024-09-19 10:40:18 +02:00
parent 545970ce54
commit 330c22fd5e

View File

@ -3,7 +3,7 @@ PromptRefinerNode Module
"""
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI, AzureChatOpenAI
@ -61,8 +61,7 @@ class PromptRefinerNode(BaseNode):
def execute(self, state: dict) -> dict:
"""
Generates an answer by constructing a prompt from the user's input and the scraped
content, querying the language model, and parsing its response.
Generate a refined prompt using the user's prompt, the schema, and additional context.
Args:
state (dict): The current state of the graph. The input keys will be used
@ -76,72 +75,83 @@ class PromptRefinerNode(BaseNode):
that the necessary information for generating an answer is missing.
"""
template_prompt_builder = """
**Task**: Analyze the user's request and the desired output schema to create a structured description for web scraping. Carefully examine both the user's request and the JSON schema to understand the desired data elements and their relationships.
**User's Request**:
{user_input}
**Desired JSON Output Schema**:
```json
{json_schema}
```
**Analysis Instructions**:
Genarate the breakdown of the user request and link the elements of the user's request with the json schema
This analysis will be used to guide the HTML structure examination and ultimately inform the code generation process.
Please generate only the analysis and no other text.
**Response**:
"""
template_prompt_builder_with_context = """
**Task**: Analyze the user's request, the desired output schema, and the additional context the user provided to create a structured description for web scraping. Carefully examine both the user's request and the JSON schema to understand the desired data elements and their relationships.
**User's Request**:
{user_input}
**Desired JSON Output Schema**:
```json
{json_schema}
```
**Additional Context**:
{additional_context}
**Analysis Instructions**:
Genarate the breakdown of the user request and link the elements of the user's request with the json schema
This analysis will be used to guide the HTML structure examination and ultimately inform the code generation process.
Please generate only the analysis and no other text.
**Response**:
"""
self.logger.info(f"--- Executing {self.node_name} Node ---")
input_keys = self.get_input_keys(state)
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
user_prompt = input_data[0] # get user prompt
if self.node_config.get("schema", None) is not None:
self.schema = self.node_config["schema"]
self.schema = self.node_config["schema"] # get JSON schema
if self.additional_info is not None: # add context to the prompt
pass
if self.additional_info is not None: # use additional context if present
prompt = PromptTemplate(
template=template_prompt_builder_with_context,
partial_variables={"user_input": user_prompt,
"json_schema": self.schema,
"additional_context": self.additional_info})
else:
prompt = PromptTemplate(
template=template_prompt_builder,
partial_variables={"user_input": user_prompt,
"json_schema": self.schema})
template_prompt_builder = """
You are tasked with generating a prompt that will guide an LLM in reasoning about how to identify specific elements within an HTML page for data extraction.
**Input:**
* **User Prompt:** The user's natural language description of the data they want to extract from the HTML page.
* **JSON Schema:** A JSON schema representing the desired output structure of the extracted data.
* **Additional Information (Optional):** Any supplementary details provided by the user, such as specific HTML patterns they've observed, known challenges in identifying certain elements, or preferences for particular scraping strategies.
output_parser = StrOutputParser()
**Output:**
"""
example_prompts = [
"""
"""
]
prompt = PromptTemplate(
template=template_no_chunks_prompt ,
input_variables=["question"],
partial_variables={"context": doc,
"format_instructions": format_instructions})
chain = prompt | self.llm_model | output_parser
answer = chain.invoke({"question": user_prompt})
refined_prompt = chain.invoke({})
state.update({self.output[0]: answer})
state.update({self.output[0]: refined_prompt})
return state
chains_dict = {}
for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
prompt = PromptTemplate(
template=TEMPLATE_CHUNKS,
input_variables=["question"],
partial_variables={"context": chunk,
"chunk_id": i + 1,
"format_instructions": format_instructions})
chain_name = f"chunk{i+1}"
chains_dict[chain_name] = prompt | self.llm_model | output_parser
async_runner = RunnableParallel(**chains_dict)
batch_results = async_runner.invoke({"question": user_prompt})
merge_prompt = PromptTemplate(
template = template_merge_prompt ,
input_variables=["context", "question"],
partial_variables={"format_instructions": format_instructions},
)
merge_chain = merge_prompt | self.llm_model | output_parser
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
state.update({self.output[0]: answer})
return state
else: # no schema provided
self.logger.error("No schema provided for prompt refinement.")
# TODO: Handle the case where no schema is provided => error handling
return state