mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
Update prompt_refiner_node.py
This commit is contained in:
parent
545970ce54
commit
330c22fd5e
@ -3,7 +3,7 @@ PromptRefinerNode Module
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||
@ -61,8 +61,7 @@ class PromptRefinerNode(BaseNode):
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
"""
|
||||
Generates an answer by constructing a prompt from the user's input and the scraped
|
||||
content, querying the language model, and parsing its response.
|
||||
Generate a refined prompt using the user's prompt, the schema, and additional context.
|
||||
|
||||
Args:
|
||||
state (dict): The current state of the graph. The input keys will be used
|
||||
@ -76,72 +75,83 @@ class PromptRefinerNode(BaseNode):
|
||||
that the necessary information for generating an answer is missing.
|
||||
"""
|
||||
|
||||
template_prompt_builder = """
|
||||
**Task**: Analyze the user's request and the desired output schema to create a structured description for web scraping. Carefully examine both the user's request and the JSON schema to understand the desired data elements and their relationships.
|
||||
|
||||
**User's Request**:
|
||||
{user_input}
|
||||
|
||||
**Desired JSON Output Schema**:
|
||||
```json
|
||||
{json_schema}
|
||||
```
|
||||
|
||||
**Analysis Instructions**:
|
||||
Genarate the breakdown of the user request and link the elements of the user's request with the json schema
|
||||
|
||||
This analysis will be used to guide the HTML structure examination and ultimately inform the code generation process.
|
||||
Please generate only the analysis and no other text.
|
||||
|
||||
**Response**:
|
||||
"""
|
||||
|
||||
template_prompt_builder_with_context = """
|
||||
**Task**: Analyze the user's request, the desired output schema, and the additional context the user provided to create a structured description for web scraping. Carefully examine both the user's request and the JSON schema to understand the desired data elements and their relationships.
|
||||
|
||||
**User's Request**:
|
||||
{user_input}
|
||||
|
||||
**Desired JSON Output Schema**:
|
||||
```json
|
||||
{json_schema}
|
||||
```
|
||||
|
||||
**Additional Context**:
|
||||
{additional_context}
|
||||
|
||||
**Analysis Instructions**:
|
||||
Genarate the breakdown of the user request and link the elements of the user's request with the json schema
|
||||
|
||||
This analysis will be used to guide the HTML structure examination and ultimately inform the code generation process.
|
||||
Please generate only the analysis and no other text.
|
||||
|
||||
**Response**:
|
||||
"""
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
input_data = [state[key] for key in input_keys]
|
||||
user_prompt = input_data[0]
|
||||
user_prompt = input_data[0] # get user prompt
|
||||
|
||||
if self.node_config.get("schema", None) is not None:
|
||||
|
||||
self.schema = self.node_config["schema"]
|
||||
self.schema = self.node_config["schema"] # get JSON schema
|
||||
|
||||
if self.additional_info is not None: # add context to the prompt
|
||||
pass
|
||||
if self.additional_info is not None: # use additional context if present
|
||||
prompt = PromptTemplate(
|
||||
template=template_prompt_builder_with_context,
|
||||
partial_variables={"user_input": user_prompt,
|
||||
"json_schema": self.schema,
|
||||
"additional_context": self.additional_info})
|
||||
else:
|
||||
prompt = PromptTemplate(
|
||||
template=template_prompt_builder,
|
||||
partial_variables={"user_input": user_prompt,
|
||||
"json_schema": self.schema})
|
||||
|
||||
template_prompt_builder = """
|
||||
You are tasked with generating a prompt that will guide an LLM in reasoning about how to identify specific elements within an HTML page for data extraction.
|
||||
**Input:**
|
||||
|
||||
* **User Prompt:** The user's natural language description of the data they want to extract from the HTML page.
|
||||
* **JSON Schema:** A JSON schema representing the desired output structure of the extracted data.
|
||||
* **Additional Information (Optional):** Any supplementary details provided by the user, such as specific HTML patterns they've observed, known challenges in identifying certain elements, or preferences for particular scraping strategies.
|
||||
output_parser = StrOutputParser()
|
||||
|
||||
**Output:**
|
||||
"""
|
||||
|
||||
example_prompts = [
|
||||
"""
|
||||
|
||||
"""
|
||||
]
|
||||
|
||||
prompt = PromptTemplate(
|
||||
template=template_no_chunks_prompt ,
|
||||
input_variables=["question"],
|
||||
partial_variables={"context": doc,
|
||||
"format_instructions": format_instructions})
|
||||
chain = prompt | self.llm_model | output_parser
|
||||
answer = chain.invoke({"question": user_prompt})
|
||||
refined_prompt = chain.invoke({})
|
||||
|
||||
state.update({self.output[0]: answer})
|
||||
state.update({self.output[0]: refined_prompt})
|
||||
return state
|
||||
|
||||
chains_dict = {}
|
||||
for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
|
||||
|
||||
prompt = PromptTemplate(
|
||||
template=TEMPLATE_CHUNKS,
|
||||
input_variables=["question"],
|
||||
partial_variables={"context": chunk,
|
||||
"chunk_id": i + 1,
|
||||
"format_instructions": format_instructions})
|
||||
chain_name = f"chunk{i+1}"
|
||||
chains_dict[chain_name] = prompt | self.llm_model | output_parser
|
||||
|
||||
async_runner = RunnableParallel(**chains_dict)
|
||||
|
||||
batch_results = async_runner.invoke({"question": user_prompt})
|
||||
|
||||
merge_prompt = PromptTemplate(
|
||||
template = template_merge_prompt ,
|
||||
input_variables=["context", "question"],
|
||||
partial_variables={"format_instructions": format_instructions},
|
||||
)
|
||||
|
||||
merge_chain = merge_prompt | self.llm_model | output_parser
|
||||
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
|
||||
|
||||
state.update({self.output[0]: answer})
|
||||
return state
|
||||
else: # no schema provided
|
||||
self.logger.error("No schema provided for prompt refinement.")
|
||||
|
||||
# TODO: Handle the case where no schema is provided => error handling
|
||||
|
||||
return state
|
||||
|
||||
Loading…
Reference in New Issue
Block a user