diff --git a/examples/openai/search_graph_multi.py b/examples/openai/search_graph_multi.py index 01b3d634..962397c7 100644 --- a/examples/openai/search_graph_multi.py +++ b/examples/openai/search_graph_multi.py @@ -45,6 +45,7 @@ search_internet_node = SearchInternetNode( output=["urls"], node_config={ "llm_model": llm_model, + "max_results": 5, # num of search results to fetch "verbose": True, } ) diff --git a/examples/openai/search_graph_openai.py b/examples/openai/search_graph_openai.py index 0e0ca28d..486d9a62 100644 --- a/examples/openai/search_graph_openai.py +++ b/examples/openai/search_graph_openai.py @@ -19,6 +19,8 @@ graph_config = { "api_key": openai_key, "model": "gpt-3.5-turbo", }, + "max_results": 5, + "verbose": True, } # ************************************************ @@ -26,7 +28,7 @@ graph_config = { # ************************************************ search_graph = SearchGraph( - prompt="List me top 5 eyeliner products for a gift.", + prompt="List me the best escursions near Trento", config=graph_config ) diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py index 9c463e1a..75d0d304 100644 --- a/scrapegraphai/graphs/search_graph.py +++ b/scrapegraphai/graphs/search_graph.py @@ -5,12 +5,11 @@ SearchGraph Module from .base_graph import BaseGraph from ..nodes import ( SearchInternetNode, - FetchNode, - ParseNode, - RAGNode, - GenerateAnswerNode + GraphIteratorNode, + MergeAnswersNode ) from .abstract_graph import AbstractGraph +from .smart_scraper_graph import SmartScraperGraph class SearchGraph(AbstractGraph): @@ -38,6 +37,11 @@ class SearchGraph(AbstractGraph): >>> result = search_graph.run() """ + def __init__(self, prompt: str, config: dict): + + self.max_results = config.get("max_results", 3) + super().__init__(prompt, config) + def _create_graph(self) -> BaseGraph: """ Creates the graph of nodes representing the workflow for web scraping and searching. @@ -46,53 +50,53 @@ class SearchGraph(AbstractGraph): BaseGraph: A graph instance representing the web scraping and searching workflow. """ + # ************************************************ + # Create a SmartScraperGraph instance + # ************************************************ + + smart_scraper_instance = SmartScraperGraph( + prompt="", + source="", + config=self.config + ) + + # ************************************************ + # Define the graph nodes + # ************************************************ + search_internet_node = SearchInternetNode( input="user_prompt", - output=["url"], - node_config={ - "llm_model": self.llm_model - } - ) - fetch_node = FetchNode( - input="url | local_dir", - output=["doc"] - ) - parse_node = ParseNode( - input="doc", - output=["parsed_doc"], - node_config={ - "chunk_size": self.model_token - } - ) - rag_node = RAGNode( - input="user_prompt & (parsed_doc | doc)", - output=["relevant_chunks"], + output=["urls"], node_config={ "llm_model": self.llm_model, - "embedder_model": self.embedder_model + "max_results": self.max_results } ) - generate_answer_node = GenerateAnswerNode( - input="user_prompt & (relevant_chunks | parsed_doc | doc)", + graph_iterator_node = GraphIteratorNode( + input="user_prompt & urls", + output=["results"], + node_config={ + "graph_instance": smart_scraper_instance, + } + ) + + merge_answers_node = MergeAnswersNode( + input="user_prompt & results", output=["answer"], node_config={ - "llm_model": self.llm_model + "llm_model": self.llm_model, } ) return BaseGraph( nodes=[ search_internet_node, - fetch_node, - parse_node, - rag_node, - generate_answer_node, + graph_iterator_node, + merge_answers_node ], edges=[ - (search_internet_node, fetch_node), - (fetch_node, parse_node), - (parse_node, rag_node), - (rag_node, generate_answer_node) + (search_internet_node, graph_iterator_node), + (graph_iterator_node, merge_answers_node) ], entry_point=search_internet_node ) diff --git a/scrapegraphai/nodes/graph_iterator_node.py b/scrapegraphai/nodes/graph_iterator_node.py index cea43df4..663adc62 100644 --- a/scrapegraphai/nodes/graph_iterator_node.py +++ b/scrapegraphai/nodes/graph_iterator_node.py @@ -10,11 +10,8 @@ from .base_node import BaseNode class GraphIteratorNode(BaseNode): """ - A node responsible for parsing HTML content from a document. - The parsed content is split into chunks for further processing. - - This node enhances the scraping workflow by allowing for targeted extraction of - content, thereby optimizing the processing of large HTML documents. + A node responsible for instantiating and running multiple graph instances in parallel. + It creates as many graph instances as the number of elements in the input list. Attributes: verbose (bool): A flag indicating whether to show print statements during execution. @@ -33,18 +30,18 @@ class GraphIteratorNode(BaseNode): def execute(self, state: dict) -> dict: """ - Executes the node's logic to parse the HTML document content and split it into chunks. + Executes the node's logic to instantiate and run multiple graph instances in parallel. Args: - state (dict): The current state of the graph. The input keys will be used to fetch the - correct data from the state. + state (dict): The current state of the graph. The input keys will be used to fetch + the correct data from the state. Returns: - dict: The updated state with the output key containing the parsed content chunks. + dict: The updated state with the output key containing the results of the graph instances. Raises: KeyError: If the input keys are not found in the state, indicating that the - necessary information for parsing the content is missing. + necessary information for running the graph instances is missing. """ if self.verbose: @@ -79,5 +76,4 @@ class GraphIteratorNode(BaseNode): graphs_answers.append(result) state.update({self.output[0]: graphs_answers}) - return state diff --git a/scrapegraphai/nodes/merge_answers_node.py b/scrapegraphai/nodes/merge_answers_node.py index aa3e410c..a5f52220 100644 --- a/scrapegraphai/nodes/merge_answers_node.py +++ b/scrapegraphai/nodes/merge_answers_node.py @@ -9,7 +9,6 @@ from tqdm import tqdm # Imports from Langchain from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser -from langchain_core.runnables import RunnableParallel # Imports from the library from .base_node import BaseNode @@ -17,10 +16,7 @@ from .base_node import BaseNode class MergeAnswersNode(BaseNode): """ - A node that generates an answer using a large language model (LLM) based on the user's input - and the content extracted from a webpage. It constructs a prompt from the user's input - and the scraped content, feeds it to the LLM, and parses the LLM's response to produce - an answer. + A node responsible for merging the answers from multiple graph instances into a single answer. Attributes: llm_model: An instance of a language model client, configured for generating answers. @@ -42,8 +38,7 @@ class MergeAnswersNode(BaseNode): def execute(self, state: dict) -> dict: """ - Generates an answer by constructing a prompt from the user's input and the scraped - content, querying the language model, and parsing its response. + Executes the node's logic to merge the answers from multiple graph instances into a single answer. Args: state (dict): The current state of the graph. The input keys will be used