feat(base_graph): alligned with main

This commit is contained in:
Marco Perini 2024-05-17 18:54:27 +02:00
parent 02745a4f63
commit 73fa31db0f

View File

@ -6,7 +6,6 @@ import time
import warnings import warnings
from langchain_community.callbacks import get_openai_callback from langchain_community.callbacks import get_openai_callback
from typing import Tuple from typing import Tuple
from collections import deque
class BaseGraph: class BaseGraph:
@ -27,8 +26,6 @@ class BaseGraph:
Raises: Raises:
Warning: If the entry point node is not the first node in the list. Warning: If the entry point node is not the first node in the list.
ValueError: If conditional_node does not have exactly two outgoing edges
Example: Example:
>>> BaseGraph( >>> BaseGraph(
@ -51,7 +48,7 @@ class BaseGraph:
self.nodes = nodes self.nodes = nodes
self.edges = self._create_edges({e for e in edges}) self.edges = self._create_edges({e for e in edges})
self.entry_point = entry_point self.entry_point = entry_point.node_name
if nodes[0].node_name != entry_point.node_name: if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list # raise a warning if the entry point is not the first node in the list
@ -71,16 +68,13 @@ class BaseGraph:
edge_dict = {} edge_dict = {}
for from_node, to_node in edges: for from_node, to_node in edges:
if from_node in edge_dict: edge_dict[from_node.node_name] = to_node.node_name
edge_dict[from_node].append(to_node)
else:
edge_dict[from_node] = [to_node]
return edge_dict return edge_dict
def execute(self, initial_state: dict) -> Tuple[dict, list]: def execute(self, initial_state: dict) -> Tuple[dict, list]:
""" """
Executes the graph by traversing nodes in breadth-first order starting from the entry point. Executes the graph by traversing nodes starting from the entry point. The execution
The execution follows the edges based on the result of each node's execution and continues until follows the edges based on the result of each node's execution and continues until
it reaches a node with no outgoing edges. it reaches a node with no outgoing edges.
Args: Args:
@ -90,6 +84,7 @@ class BaseGraph:
Tuple[dict, list]: A tuple containing the final state and a list of execution info. Tuple[dict, list]: A tuple containing the final state and a list of execution info.
""" """
current_node_name = self.nodes[0]
state = initial_state state = initial_state
# variables for tracking execution info # variables for tracking execution info
@ -103,22 +98,23 @@ class BaseGraph:
"total_cost_USD": 0.0, "total_cost_USD": 0.0,
} }
queue = deque([self.entry_point]) for index in self.nodes:
while queue:
current_node = queue.popleft()
curr_time = time.time() curr_time = time.time()
with get_openai_callback() as callback: current_node = index
with get_openai_callback() as cb:
result = current_node.execute(state) result = current_node.execute(state)
node_exec_time = time.time() - curr_time node_exec_time = time.time() - curr_time
total_exec_time += node_exec_time total_exec_time += node_exec_time
cb = { cb = {
"node_name": current_node.node_name, "node_name": index.node_name,
"total_tokens": callback.total_tokens, "total_tokens": cb.total_tokens,
"prompt_tokens": callback.prompt_tokens, "prompt_tokens": cb.prompt_tokens,
"completion_tokens": callback.completion_tokens, "completion_tokens": cb.completion_tokens,
"successful_requests": callback.successful_requests, "successful_requests": cb.successful_requests,
"total_cost_USD": callback.total_cost, "total_cost_USD": cb.total_cost,
"exec_time": node_exec_time, "exec_time": node_exec_time,
} }
@ -132,31 +128,21 @@ class BaseGraph:
cb_total["successful_requests"] += cb["successful_requests"] cb_total["successful_requests"] += cb["successful_requests"]
cb_total["total_cost_USD"] += cb["total_cost_USD"] cb_total["total_cost_USD"] += cb["total_cost_USD"]
if current_node.node_type == "conditional_node":
current_node_name = result
current_node_connections = self.edges[current_node] elif current_node_name in self.edges:
if current_node.node_type == 'conditional_node': current_node_name = self.edges[current_node_name]
# Assert that there are exactly two out edges from the conditional node else:
if len(current_node_connections) != 2: current_node_name = None
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
if result["next_node"] == 0:
queue.append(current_node_connections[0])
else:
queue.append(current_node_connections[1])
# remove the conditional node result
del result["next_node"]
else:
queue.extend(node for node in current_node_connections)
exec_info.append({
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
exec_info.append({ return state, exec_info
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
return state, exec_info