From 0571b6da55920bfe691feef2e1ecb5f3760dabf7 Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Tue, 6 Aug 2024 14:01:11 +0200 Subject: [PATCH] feat: update base_graph --- scrapegraphai/graphs/base_graph.py | 39 +++++++++++++++++------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index 21f564d7..052d501c 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -1,7 +1,11 @@ +""" +base_graph module +""" import time import warnings -from langchain_community.callbacks import get_openai_callback from typing import Tuple +from langchain_community.callbacks import get_openai_callback +from ..integrations import BurrBridge # Import telemetry functions from ..telemetry import log_graph_execution, log_event @@ -56,7 +60,7 @@ class BaseGraph: # raise a warning if the entry point is not the first node in the list warnings.warn( "Careful! The entry point node is different from the first node in the graph.") - + # Burr configuration self.use_burr = use_burr self.burr_config = burr_config or {} @@ -79,7 +83,8 @@ class BaseGraph: def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]: """ - Executes the graph by traversing nodes starting from the entry point using the standard method. + Executes the graph by traversing nodes starting from the + entry point using the standard method. Args: initial_state (dict): The initial state to pass to the entry point node. @@ -114,23 +119,25 @@ class BaseGraph: curr_time = time.time() current_node = next(node for node in self.nodes if node.node_name == current_node_name) - # check if there is a "source" key in the node config if current_node.__class__.__name__ == "FetchNode": # get the second key name of the state dictionary source_type = list(state.keys())[1] if state.get("user_prompt", None): - prompt = state["user_prompt"] if type(state["user_prompt"]) == str else None - # quick fix for local_dir source type + # Set 'prompt' if 'user_prompt' is a string, otherwise None + prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None + + # Convert 'local_dir' source type to 'html_dir' if source_type == "local_dir": source_type = "html_dir" elif source_type == "url": - if type(state[source_type]) == list: - # iterate through the list of urls and see if they are strings + # If the source is a list, add string URLs to 'source' + if isinstance(state[source_type], list): for url in state[source_type]: - if type(url) == str: + if isinstance(url, str): source.append(url) - elif type(state[source_type]) == str: + # If the source is a single string, add it to 'source' + elif isinstance(state[source_type], str): source.append(state[source_type]) # check if there is an "llm_model" variable in the class @@ -164,7 +171,6 @@ class BaseGraph: result = current_node.execute(state) except Exception as e: error_node = current_node.node_name - graph_execution_time = time.time() - start_time log_graph_execution( graph_name=self.graph_name, @@ -221,7 +227,7 @@ class BaseGraph: graph_execution_time = time.time() - start_time response = state.get("answer", None) if source_type == "url" else None content = state.get("parsed_doc", None) if response is not None else None - + log_graph_execution( graph_name=self.graph_name, source=source, @@ -251,14 +257,13 @@ class BaseGraph: self.initial_state = initial_state if self.use_burr: - from ..integrations import BurrBridge - + bridge = BurrBridge(self, self.burr_config) result = bridge.execute(initial_state) return (result["_state"], []) else: return self._execute_standard(initial_state) - + def append_node(self, node): """ Adds a node to the graph. @@ -266,11 +271,11 @@ class BaseGraph: Args: node (BaseNode): The node instance to add to the graph. """ - + # if node name already exists in the graph, raise an exception if node.node_name in {n.node_name for n in self.nodes}: raise ValueError(f"Node with name '{node.node_name}' already exists in the graph. You can change it by setting the 'node_name' attribute.") - + # get the last node in the list last_node = self.nodes[-1] # add the edge connecting the last node to the new node