fix: Fixed pydantic error on SearchGraphs

Changed instatiation location of iterated graph classes
This commit is contained in:
Lorenzo Paleari 2024-09-13 01:56:58 +02:00
parent 88b2c469ae
commit 039ba2e95a
No known key found for this signature in database
GPG Key ID: 010F47E3CB681DED
4 changed files with 45 additions and 33 deletions

View File

@ -61,12 +61,12 @@ class OmniSearchGraph(AbstractGraph):
BaseGraph: A graph instance representing the web scraping and searching workflow. BaseGraph: A graph instance representing the web scraping and searching workflow.
""" """
omni_scraper_instance = OmniScraperGraph( # omni_scraper_instance = OmniScraperGraph(
prompt="", # prompt="",
source="", # source="",
config=self.copy_config, # config=self.copy_config,
schema=self.copy_schema # schema=self.copy_schema
) # )
search_internet_node = SearchInternetNode( search_internet_node = SearchInternetNode(
input="user_prompt", input="user_prompt",
@ -81,8 +81,10 @@ class OmniSearchGraph(AbstractGraph):
input="user_prompt & urls", input="user_prompt & urls",
output=["results"], output=["results"],
node_config={ node_config={
"graph_instance": omni_scraper_instance, "graph_instance": OmniScraperGraph,
} "scraper_config": self.copy_config,
},
schema=self.copy_schema
) )
merge_answers_node = MergeAnswersNode( merge_answers_node = MergeAnswersNode(

View File

@ -62,12 +62,12 @@ class SearchGraph(AbstractGraph):
BaseGraph: A graph instance representing the web scraping and searching workflow. BaseGraph: A graph instance representing the web scraping and searching workflow.
""" """
smart_scraper_instance = SmartScraperGraph( # smart_scraper_instance = SmartScraperGraph(
prompt="", # prompt="",
source="", # source="",
config=self.copy_config, # config=self.copy_config,
schema=self.copy_schema # schema=self.copy_schema
) # )
search_internet_node = SearchInternetNode( search_internet_node = SearchInternetNode(
input="user_prompt", input="user_prompt",
@ -82,8 +82,10 @@ class SearchGraph(AbstractGraph):
input="user_prompt & urls", input="user_prompt & urls",
output=["results"], output=["results"],
node_config={ node_config={
"graph_instance": smart_scraper_instance, "graph_instance": SmartScraperGraph,
} "scraper_config": self.copy_config
},
schema=self.copy_schema
) )
merge_answers_node = MergeAnswersNode( merge_answers_node = MergeAnswersNode(

View File

@ -92,11 +92,11 @@ class GenerateAnswerNode(BaseNode):
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output( self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"], schema = self.node_config["schema"]) # json schema works only on specific models
method="function_calling") # json schema works only on specific models
# default parser to empty lambda function # default parser to empty lambda function
output_parser = lambda x: x def output_parser(x):
return x
if is_basemodel_subclass(self.node_config["schema"]): if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict output_parser = dict
format_instructions = "NA" format_instructions = "NA"

View File

@ -2,11 +2,10 @@
GraphIterator Module GraphIterator Module
""" """
import asyncio import asyncio
import copy
from typing import List, Optional from typing import List, Optional
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from ..utils.logging import get_logger
from .base_node import BaseNode from .base_node import BaseNode
from langchain_core.pydantic_v1 import BaseModel
DEFAULT_BATCHSIZE = 16 DEFAULT_BATCHSIZE = 16
@ -31,12 +30,14 @@ class GraphIteratorNode(BaseNode):
output: List[str], output: List[str],
node_config: Optional[dict] = None, node_config: Optional[dict] = None,
node_name: str = "GraphIterator", node_name: str = "GraphIterator",
schema: Optional[BaseModel] = None,
): ):
super().__init__(node_name, "node", input, output, 2, node_config) super().__init__(node_name, "node", input, output, 2, node_config)
self.verbose = ( self.verbose = (
False if node_config is None else node_config.get("verbose", False) False if node_config is None else node_config.get("verbose", False)
) )
self.schema = schema
def execute(self, state: dict) -> dict: def execute(self, state: dict) -> dict:
""" """
@ -97,16 +98,24 @@ class GraphIteratorNode(BaseNode):
urls = input_data[1] urls = input_data[1]
graph_instance = self.node_config.get("graph_instance", None) graph_instance = self.node_config.get("graph_instance", None)
scraper_config = self.node_config.get("scraper_config", None)
if graph_instance is None: if graph_instance is None:
raise ValueError("graph instance is required for concurrent execution") raise ValueError("graph instance is required for concurrent execution")
if "graph_depth" in graph_instance.config: graph_instance = [graph_instance(
graph_instance.config["graph_depth"] += 1 prompt="",
else: source="",
graph_instance.config["graph_depth"] = 1 config=scraper_config,
schema=self.schema) for _ in range(len(urls))]
graph_instance.prompt = user_prompt for graph in graph_instance:
if "graph_depth" in graph.config:
graph.config["graph_depth"] += 1
else:
graph.config["graph_depth"] = 1
graph.prompt = user_prompt
participants = [] participants = []
@ -116,12 +125,11 @@ class GraphIteratorNode(BaseNode):
async with semaphore: async with semaphore:
return await asyncio.to_thread(graph.run) return await asyncio.to_thread(graph.run)
for url in urls: for url, graph in zip(urls, graph_instance):
instance = copy.copy(graph_instance) graph.source = url
instance.source = url
if url.startswith("http"): if url.startswith("http"):
instance.input_key = "url" graph.input_key = "url"
participants.append(instance) participants.append(graph)
futures = [_async_run(graph) for graph in participants] futures = [_async_run(graph) for graph in participants]