diff --git a/scrapegraphai/graphs/omni_search_graph.py b/scrapegraphai/graphs/omni_search_graph.py index 01fb9177..669e2df8 100644 --- a/scrapegraphai/graphs/omni_search_graph.py +++ b/scrapegraphai/graphs/omni_search_graph.py @@ -61,12 +61,12 @@ class OmniSearchGraph(AbstractGraph): BaseGraph: A graph instance representing the web scraping and searching workflow. """ - omni_scraper_instance = OmniScraperGraph( - prompt="", - source="", - config=self.copy_config, - schema=self.copy_schema - ) + # omni_scraper_instance = OmniScraperGraph( + # prompt="", + # source="", + # config=self.copy_config, + # schema=self.copy_schema + # ) search_internet_node = SearchInternetNode( input="user_prompt", @@ -81,8 +81,10 @@ class OmniSearchGraph(AbstractGraph): input="user_prompt & urls", output=["results"], node_config={ - "graph_instance": omni_scraper_instance, - } + "graph_instance": OmniScraperGraph, + "scraper_config": self.copy_config, + }, + schema=self.copy_schema ) merge_answers_node = MergeAnswersNode( diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py index e1bdd72c..461dc80c 100644 --- a/scrapegraphai/graphs/search_graph.py +++ b/scrapegraphai/graphs/search_graph.py @@ -62,12 +62,12 @@ class SearchGraph(AbstractGraph): BaseGraph: A graph instance representing the web scraping and searching workflow. """ - smart_scraper_instance = SmartScraperGraph( - prompt="", - source="", - config=self.copy_config, - schema=self.copy_schema - ) + # smart_scraper_instance = SmartScraperGraph( + # prompt="", + # source="", + # config=self.copy_config, + # schema=self.copy_schema + # ) search_internet_node = SearchInternetNode( input="user_prompt", @@ -82,8 +82,10 @@ class SearchGraph(AbstractGraph): input="user_prompt & urls", output=["results"], node_config={ - "graph_instance": smart_scraper_instance, - } + "graph_instance": SmartScraperGraph, + "scraper_config": self.copy_config + }, + schema=self.copy_schema ) merge_answers_node = MergeAnswersNode( diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index be97b832..ae92f6c5 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -89,14 +89,14 @@ class GenerateAnswerNode(BaseNode): doc = input_data[1] if self.node_config.get("schema", None) is not None: - + if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( - schema = self.node_config["schema"], - method="function_calling") # json schema works only on specific models + schema = self.node_config["schema"]) # json schema works only on specific models # default parser to empty lambda function - output_parser = lambda x: x + def output_parser(x): + return x if is_basemodel_subclass(self.node_config["schema"]): output_parser = dict format_instructions = "NA" diff --git a/scrapegraphai/nodes/graph_iterator_node.py b/scrapegraphai/nodes/graph_iterator_node.py index fe355d51..8781cf2d 100644 --- a/scrapegraphai/nodes/graph_iterator_node.py +++ b/scrapegraphai/nodes/graph_iterator_node.py @@ -2,11 +2,10 @@ GraphIterator Module """ import asyncio -import copy from typing import List, Optional from tqdm.asyncio import tqdm -from ..utils.logging import get_logger from .base_node import BaseNode +from langchain_core.pydantic_v1 import BaseModel DEFAULT_BATCHSIZE = 16 @@ -31,12 +30,14 @@ class GraphIteratorNode(BaseNode): output: List[str], node_config: Optional[dict] = None, node_name: str = "GraphIterator", + schema: Optional[BaseModel] = None, ): super().__init__(node_name, "node", input, output, 2, node_config) self.verbose = ( False if node_config is None else node_config.get("verbose", False) ) + self.schema = schema def execute(self, state: dict) -> dict: """ @@ -97,16 +98,24 @@ class GraphIteratorNode(BaseNode): urls = input_data[1] graph_instance = self.node_config.get("graph_instance", None) + scraper_config = self.node_config.get("scraper_config", None) if graph_instance is None: raise ValueError("graph instance is required for concurrent execution") - if "graph_depth" in graph_instance.config: - graph_instance.config["graph_depth"] += 1 - else: - graph_instance.config["graph_depth"] = 1 + graph_instance = [graph_instance( + prompt="", + source="", + config=scraper_config, + schema=self.schema) for _ in range(len(urls))] - graph_instance.prompt = user_prompt + for graph in graph_instance: + if "graph_depth" in graph.config: + graph.config["graph_depth"] += 1 + else: + graph.config["graph_depth"] = 1 + + graph.prompt = user_prompt participants = [] @@ -116,13 +125,12 @@ class GraphIteratorNode(BaseNode): async with semaphore: return await asyncio.to_thread(graph.run) - for url in urls: - instance = copy.copy(graph_instance) - instance.source = url + for url, graph in zip(urls, graph_instance): + graph.source = url if url.startswith("http"): - instance.input_key = "url" - participants.append(instance) - + graph.input_key = "url" + participants.append(graph) + futures = [_async_run(graph) for graph in participants] answers = await tqdm.gather(