diff --git a/scrapegraphai/nodes/prompt_refiner_node.py b/scrapegraphai/nodes/prompt_refiner_node.py index 3bd219ed..1748aec0 100644 --- a/scrapegraphai/nodes/prompt_refiner_node.py +++ b/scrapegraphai/nodes/prompt_refiner_node.py @@ -3,7 +3,7 @@ PromptRefinerNode Module """ from typing import List, Optional from langchain.prompts import PromptTemplate -from langchain_core.output_parsers import JsonOutputParser +from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnableParallel from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI, AzureChatOpenAI @@ -61,8 +61,7 @@ class PromptRefinerNode(BaseNode): def execute(self, state: dict) -> dict: """ - Generates an answer by constructing a prompt from the user's input and the scraped - content, querying the language model, and parsing its response. + Generate a refined prompt using the user's prompt, the schema, and additional context. Args: state (dict): The current state of the graph. The input keys will be used @@ -76,72 +75,83 @@ class PromptRefinerNode(BaseNode): that the necessary information for generating an answer is missing. """ + template_prompt_builder = """ + **Task**: Analyze the user's request and the desired output schema to create a structured description for web scraping. Carefully examine both the user's request and the JSON schema to understand the desired data elements and their relationships. + + **User's Request**: + {user_input} + + **Desired JSON Output Schema**: + ```json + {json_schema} + ``` + + **Analysis Instructions**: + Genarate the breakdown of the user request and link the elements of the user's request with the json schema + + This analysis will be used to guide the HTML structure examination and ultimately inform the code generation process. + Please generate only the analysis and no other text. + + **Response**: + """ + + template_prompt_builder_with_context = """ + **Task**: Analyze the user's request, the desired output schema, and the additional context the user provided to create a structured description for web scraping. Carefully examine both the user's request and the JSON schema to understand the desired data elements and their relationships. + + **User's Request**: + {user_input} + + **Desired JSON Output Schema**: + ```json + {json_schema} + ``` + + **Additional Context**: + {additional_context} + + **Analysis Instructions**: + Genarate the breakdown of the user request and link the elements of the user's request with the json schema + + This analysis will be used to guide the HTML structure examination and ultimately inform the code generation process. + Please generate only the analysis and no other text. + + **Response**: + """ + self.logger.info(f"--- Executing {self.node_name} Node ---") input_keys = self.get_input_keys(state) input_data = [state[key] for key in input_keys] - user_prompt = input_data[0] + user_prompt = input_data[0] # get user prompt if self.node_config.get("schema", None) is not None: - self.schema = self.node_config["schema"] + self.schema = self.node_config["schema"] # get JSON schema - if self.additional_info is not None: # add context to the prompt - pass + if self.additional_info is not None: # use additional context if present + prompt = PromptTemplate( + template=template_prompt_builder_with_context, + partial_variables={"user_input": user_prompt, + "json_schema": self.schema, + "additional_context": self.additional_info}) + else: + prompt = PromptTemplate( + template=template_prompt_builder, + partial_variables={"user_input": user_prompt, + "json_schema": self.schema}) - template_prompt_builder = """ - You are tasked with generating a prompt that will guide an LLM in reasoning about how to identify specific elements within an HTML page for data extraction. - **Input:** - - * **User Prompt:** The user's natural language description of the data they want to extract from the HTML page. - * **JSON Schema:** A JSON schema representing the desired output structure of the extracted data. - * **Additional Information (Optional):** Any supplementary details provided by the user, such as specific HTML patterns they've observed, known challenges in identifying certain elements, or preferences for particular scraping strategies. + output_parser = StrOutputParser() - **Output:** - """ - - example_prompts = [ - """ - - """ - ] - - prompt = PromptTemplate( - template=template_no_chunks_prompt , - input_variables=["question"], - partial_variables={"context": doc, - "format_instructions": format_instructions}) chain = prompt | self.llm_model | output_parser - answer = chain.invoke({"question": user_prompt}) + refined_prompt = chain.invoke({}) - state.update({self.output[0]: answer}) + state.update({self.output[0]: refined_prompt}) return state - chains_dict = {} - for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)): - - prompt = PromptTemplate( - template=TEMPLATE_CHUNKS, - input_variables=["question"], - partial_variables={"context": chunk, - "chunk_id": i + 1, - "format_instructions": format_instructions}) - chain_name = f"chunk{i+1}" - chains_dict[chain_name] = prompt | self.llm_model | output_parser - - async_runner = RunnableParallel(**chains_dict) - - batch_results = async_runner.invoke({"question": user_prompt}) - - merge_prompt = PromptTemplate( - template = template_merge_prompt , - input_variables=["context", "question"], - partial_variables={"format_instructions": format_instructions}, - ) - - merge_chain = merge_prompt | self.llm_model | output_parser - answer = merge_chain.invoke({"context": batch_results, "question": user_prompt}) - - state.update({self.output[0]: answer}) - return state + else: # no schema provided + self.logger.error("No schema provided for prompt refinement.") + + # TODO: Handle the case where no schema is provided => error handling + + return state