feat(base_graph): alligned with main

This commit is contained in:
Marco Perini 2024-05-17 18:54:27 +02:00
parent 02745a4f63
commit 73fa31db0f

View File

@ -6,7 +6,6 @@ import time
import warnings
from langchain_community.callbacks import get_openai_callback
from typing import Tuple
from collections import deque
class BaseGraph:
@ -27,8 +26,6 @@ class BaseGraph:
Raises:
Warning: If the entry point node is not the first node in the list.
ValueError: If conditional_node does not have exactly two outgoing edges
Example:
>>> BaseGraph(
@ -51,7 +48,7 @@ class BaseGraph:
self.nodes = nodes
self.edges = self._create_edges({e for e in edges})
self.entry_point = entry_point
self.entry_point = entry_point.node_name
if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list
@ -71,16 +68,13 @@ class BaseGraph:
edge_dict = {}
for from_node, to_node in edges:
if from_node in edge_dict:
edge_dict[from_node].append(to_node)
else:
edge_dict[from_node] = [to_node]
edge_dict[from_node.node_name] = to_node.node_name
return edge_dict
def execute(self, initial_state: dict) -> Tuple[dict, list]:
"""
Executes the graph by traversing nodes in breadth-first order starting from the entry point.
The execution follows the edges based on the result of each node's execution and continues until
Executes the graph by traversing nodes starting from the entry point. The execution
follows the edges based on the result of each node's execution and continues until
it reaches a node with no outgoing edges.
Args:
@ -90,6 +84,7 @@ class BaseGraph:
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
"""
current_node_name = self.nodes[0]
state = initial_state
# variables for tracking execution info
@ -103,22 +98,23 @@ class BaseGraph:
"total_cost_USD": 0.0,
}
queue = deque([self.entry_point])
while queue:
current_node = queue.popleft()
for index in self.nodes:
curr_time = time.time()
with get_openai_callback() as callback:
current_node = index
with get_openai_callback() as cb:
result = current_node.execute(state)
node_exec_time = time.time() - curr_time
total_exec_time += node_exec_time
cb = {
"node_name": current_node.node_name,
"total_tokens": callback.total_tokens,
"prompt_tokens": callback.prompt_tokens,
"completion_tokens": callback.completion_tokens,
"successful_requests": callback.successful_requests,
"total_cost_USD": callback.total_cost,
"node_name": index.node_name,
"total_tokens": cb.total_tokens,
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"successful_requests": cb.successful_requests,
"total_cost_USD": cb.total_cost,
"exec_time": node_exec_time,
}
@ -132,31 +128,21 @@ class BaseGraph:
cb_total["successful_requests"] += cb["successful_requests"]
cb_total["total_cost_USD"] += cb["total_cost_USD"]
current_node_connections = self.edges[current_node]
if current_node.node_type == 'conditional_node':
# Assert that there are exactly two out edges from the conditional node
if len(current_node_connections) != 2:
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
if result["next_node"] == 0:
queue.append(current_node_connections[0])
else:
queue.append(current_node_connections[1])
# remove the conditional node result
del result["next_node"]
else:
queue.extend(node for node in current_node_connections)
if current_node.node_type == "conditional_node":
current_node_name = result
elif current_node_name in self.edges:
current_node_name = self.edges[current_node_name]
else:
current_node_name = None
exec_info.append({
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
exec_info.append({
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
return state, exec_info
return state, exec_info