feat: update base_graph

This commit is contained in:
Marco Vinciguerra 2024-08-06 14:01:11 +02:00
parent 66a29bc5cc
commit 0571b6da55

View File

@ -1,7 +1,11 @@
"""
base_graph module
"""
import time
import warnings
from langchain_community.callbacks import get_openai_callback
from typing import Tuple
from langchain_community.callbacks import get_openai_callback
from ..integrations import BurrBridge
# Import telemetry functions
from ..telemetry import log_graph_execution, log_event
@ -56,7 +60,7 @@ class BaseGraph:
# raise a warning if the entry point is not the first node in the list
warnings.warn(
"Careful! The entry point node is different from the first node in the graph.")
# Burr configuration
self.use_burr = use_burr
self.burr_config = burr_config or {}
@ -79,7 +83,8 @@ class BaseGraph:
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
"""
Executes the graph by traversing nodes starting from the entry point using the standard method.
Executes the graph by traversing nodes starting from the
entry point using the standard method.
Args:
initial_state (dict): The initial state to pass to the entry point node.
@ -114,23 +119,25 @@ class BaseGraph:
curr_time = time.time()
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
# check if there is a "source" key in the node config
if current_node.__class__.__name__ == "FetchNode":
# get the second key name of the state dictionary
source_type = list(state.keys())[1]
if state.get("user_prompt", None):
prompt = state["user_prompt"] if type(state["user_prompt"]) == str else None
# quick fix for local_dir source type
# Set 'prompt' if 'user_prompt' is a string, otherwise None
prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
# Convert 'local_dir' source type to 'html_dir'
if source_type == "local_dir":
source_type = "html_dir"
elif source_type == "url":
if type(state[source_type]) == list:
# iterate through the list of urls and see if they are strings
# If the source is a list, add string URLs to 'source'
if isinstance(state[source_type], list):
for url in state[source_type]:
if type(url) == str:
if isinstance(url, str):
source.append(url)
elif type(state[source_type]) == str:
# If the source is a single string, add it to 'source'
elif isinstance(state[source_type], str):
source.append(state[source_type])
# check if there is an "llm_model" variable in the class
@ -164,7 +171,6 @@ class BaseGraph:
result = current_node.execute(state)
except Exception as e:
error_node = current_node.node_name
graph_execution_time = time.time() - start_time
log_graph_execution(
graph_name=self.graph_name,
@ -221,7 +227,7 @@ class BaseGraph:
graph_execution_time = time.time() - start_time
response = state.get("answer", None) if source_type == "url" else None
content = state.get("parsed_doc", None) if response is not None else None
log_graph_execution(
graph_name=self.graph_name,
source=source,
@ -251,14 +257,13 @@ class BaseGraph:
self.initial_state = initial_state
if self.use_burr:
from ..integrations import BurrBridge
bridge = BurrBridge(self, self.burr_config)
result = bridge.execute(initial_state)
return (result["_state"], [])
else:
return self._execute_standard(initial_state)
def append_node(self, node):
"""
Adds a node to the graph.
@ -266,11 +271,11 @@ class BaseGraph:
Args:
node (BaseNode): The node instance to add to the graph.
"""
# if node name already exists in the graph, raise an exception
if node.node_name in {n.node_name for n in self.nodes}:
raise ValueError(f"Node with name '{node.node_name}' already exists in the graph. You can change it by setting the 'node_name' attribute.")
# get the last node in the list
last_node = self.nodes[-1]
# add the edge connecting the last node to the new node