feat: update base_graph

This commit is contained in:
Marco Vinciguerra 2024-08-06 14:01:11 +02:00
parent 66a29bc5cc
commit 0571b6da55

View File

@ -1,7 +1,11 @@
"""
base_graph module
"""
import time import time
import warnings import warnings
from langchain_community.callbacks import get_openai_callback
from typing import Tuple from typing import Tuple
from langchain_community.callbacks import get_openai_callback
from ..integrations import BurrBridge
# Import telemetry functions # Import telemetry functions
from ..telemetry import log_graph_execution, log_event from ..telemetry import log_graph_execution, log_event
@ -79,7 +83,8 @@ class BaseGraph:
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]: def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
""" """
Executes the graph by traversing nodes starting from the entry point using the standard method. Executes the graph by traversing nodes starting from the
entry point using the standard method.
Args: Args:
initial_state (dict): The initial state to pass to the entry point node. initial_state (dict): The initial state to pass to the entry point node.
@ -114,23 +119,25 @@ class BaseGraph:
curr_time = time.time() curr_time = time.time()
current_node = next(node for node in self.nodes if node.node_name == current_node_name) current_node = next(node for node in self.nodes if node.node_name == current_node_name)
# check if there is a "source" key in the node config # check if there is a "source" key in the node config
if current_node.__class__.__name__ == "FetchNode": if current_node.__class__.__name__ == "FetchNode":
# get the second key name of the state dictionary # get the second key name of the state dictionary
source_type = list(state.keys())[1] source_type = list(state.keys())[1]
if state.get("user_prompt", None): if state.get("user_prompt", None):
prompt = state["user_prompt"] if type(state["user_prompt"]) == str else None # Set 'prompt' if 'user_prompt' is a string, otherwise None
# quick fix for local_dir source type prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
# Convert 'local_dir' source type to 'html_dir'
if source_type == "local_dir": if source_type == "local_dir":
source_type = "html_dir" source_type = "html_dir"
elif source_type == "url": elif source_type == "url":
if type(state[source_type]) == list: # If the source is a list, add string URLs to 'source'
# iterate through the list of urls and see if they are strings if isinstance(state[source_type], list):
for url in state[source_type]: for url in state[source_type]:
if type(url) == str: if isinstance(url, str):
source.append(url) source.append(url)
elif type(state[source_type]) == str: # If the source is a single string, add it to 'source'
elif isinstance(state[source_type], str):
source.append(state[source_type]) source.append(state[source_type])
# check if there is an "llm_model" variable in the class # check if there is an "llm_model" variable in the class
@ -164,7 +171,6 @@ class BaseGraph:
result = current_node.execute(state) result = current_node.execute(state)
except Exception as e: except Exception as e:
error_node = current_node.node_name error_node = current_node.node_name
graph_execution_time = time.time() - start_time graph_execution_time = time.time() - start_time
log_graph_execution( log_graph_execution(
graph_name=self.graph_name, graph_name=self.graph_name,
@ -251,7 +257,6 @@ class BaseGraph:
self.initial_state = initial_state self.initial_state = initial_state
if self.use_burr: if self.use_burr:
from ..integrations import BurrBridge
bridge = BurrBridge(self, self.burr_config) bridge = BurrBridge(self, self.burr_config)
result = bridge.execute(initial_state) result = bridge.execute(initial_state)