diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index ed5ba54f..7c4df3d8 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -6,7 +6,6 @@ import time import warnings from langchain_community.callbacks import get_openai_callback from typing import Tuple -from collections import deque class BaseGraph: @@ -27,8 +26,6 @@ class BaseGraph: Raises: Warning: If the entry point node is not the first node in the list. - ValueError: If conditional_node does not have exactly two outgoing edges - Example: >>> BaseGraph( @@ -51,7 +48,7 @@ class BaseGraph: self.nodes = nodes self.edges = self._create_edges({e for e in edges}) - self.entry_point = entry_point + self.entry_point = entry_point.node_name if nodes[0].node_name != entry_point.node_name: # raise a warning if the entry point is not the first node in the list @@ -71,16 +68,13 @@ class BaseGraph: edge_dict = {} for from_node, to_node in edges: - if from_node in edge_dict: - edge_dict[from_node].append(to_node) - else: - edge_dict[from_node] = [to_node] + edge_dict[from_node.node_name] = to_node.node_name return edge_dict def execute(self, initial_state: dict) -> Tuple[dict, list]: """ - Executes the graph by traversing nodes in breadth-first order starting from the entry point. - The execution follows the edges based on the result of each node's execution and continues until + Executes the graph by traversing nodes starting from the entry point. The execution + follows the edges based on the result of each node's execution and continues until it reaches a node with no outgoing edges. Args: @@ -90,6 +84,7 @@ class BaseGraph: Tuple[dict, list]: A tuple containing the final state and a list of execution info. """ + current_node_name = self.nodes[0] state = initial_state # variables for tracking execution info @@ -103,22 +98,23 @@ class BaseGraph: "total_cost_USD": 0.0, } - queue = deque([self.entry_point]) - while queue: - current_node = queue.popleft() + for index in self.nodes: + curr_time = time.time() - with get_openai_callback() as callback: + current_node = index + + with get_openai_callback() as cb: result = current_node.execute(state) node_exec_time = time.time() - curr_time total_exec_time += node_exec_time cb = { - "node_name": current_node.node_name, - "total_tokens": callback.total_tokens, - "prompt_tokens": callback.prompt_tokens, - "completion_tokens": callback.completion_tokens, - "successful_requests": callback.successful_requests, - "total_cost_USD": callback.total_cost, + "node_name": index.node_name, + "total_tokens": cb.total_tokens, + "prompt_tokens": cb.prompt_tokens, + "completion_tokens": cb.completion_tokens, + "successful_requests": cb.successful_requests, + "total_cost_USD": cb.total_cost, "exec_time": node_exec_time, } @@ -132,31 +128,21 @@ class BaseGraph: cb_total["successful_requests"] += cb["successful_requests"] cb_total["total_cost_USD"] += cb["total_cost_USD"] - - - current_node_connections = self.edges[current_node] - if current_node.node_type == 'conditional_node': - # Assert that there are exactly two out edges from the conditional node - if len(current_node_connections) != 2: - raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}") - if result["next_node"] == 0: - queue.append(current_node_connections[0]) - else: - queue.append(current_node_connections[1]) - # remove the conditional node result - del result["next_node"] - else: - queue.extend(node for node in current_node_connections) + if current_node.node_type == "conditional_node": + current_node_name = result + elif current_node_name in self.edges: + current_node_name = self.edges[current_node_name] + else: + current_node_name = None + exec_info.append({ + "node_name": "TOTAL RESULT", + "total_tokens": cb_total["total_tokens"], + "prompt_tokens": cb_total["prompt_tokens"], + "completion_tokens": cb_total["completion_tokens"], + "successful_requests": cb_total["successful_requests"], + "total_cost_USD": cb_total["total_cost_USD"], + "exec_time": total_exec_time, + }) - exec_info.append({ - "node_name": "TOTAL RESULT", - "total_tokens": cb_total["total_tokens"], - "prompt_tokens": cb_total["prompt_tokens"], - "completion_tokens": cb_total["completion_tokens"], - "successful_requests": cb_total["successful_requests"], - "total_cost_USD": cb_total["total_cost_USD"], - "exec_time": total_exec_time, - }) - - return state, exec_info + return state, exec_info \ No newline at end of file