mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
feat(base_graph): alligned with main
This commit is contained in:
parent
02745a4f63
commit
73fa31db0f
@ -6,7 +6,6 @@ import time
|
||||
import warnings
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from typing import Tuple
|
||||
from collections import deque
|
||||
|
||||
|
||||
class BaseGraph:
|
||||
@ -27,8 +26,6 @@ class BaseGraph:
|
||||
|
||||
Raises:
|
||||
Warning: If the entry point node is not the first node in the list.
|
||||
ValueError: If conditional_node does not have exactly two outgoing edges
|
||||
|
||||
|
||||
Example:
|
||||
>>> BaseGraph(
|
||||
@ -51,7 +48,7 @@ class BaseGraph:
|
||||
|
||||
self.nodes = nodes
|
||||
self.edges = self._create_edges({e for e in edges})
|
||||
self.entry_point = entry_point
|
||||
self.entry_point = entry_point.node_name
|
||||
|
||||
if nodes[0].node_name != entry_point.node_name:
|
||||
# raise a warning if the entry point is not the first node in the list
|
||||
@ -71,16 +68,13 @@ class BaseGraph:
|
||||
|
||||
edge_dict = {}
|
||||
for from_node, to_node in edges:
|
||||
if from_node in edge_dict:
|
||||
edge_dict[from_node].append(to_node)
|
||||
else:
|
||||
edge_dict[from_node] = [to_node]
|
||||
edge_dict[from_node.node_name] = to_node.node_name
|
||||
return edge_dict
|
||||
|
||||
def execute(self, initial_state: dict) -> Tuple[dict, list]:
|
||||
"""
|
||||
Executes the graph by traversing nodes in breadth-first order starting from the entry point.
|
||||
The execution follows the edges based on the result of each node's execution and continues until
|
||||
Executes the graph by traversing nodes starting from the entry point. The execution
|
||||
follows the edges based on the result of each node's execution and continues until
|
||||
it reaches a node with no outgoing edges.
|
||||
|
||||
Args:
|
||||
@ -90,6 +84,7 @@ class BaseGraph:
|
||||
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
|
||||
"""
|
||||
|
||||
current_node_name = self.nodes[0]
|
||||
state = initial_state
|
||||
|
||||
# variables for tracking execution info
|
||||
@ -103,22 +98,23 @@ class BaseGraph:
|
||||
"total_cost_USD": 0.0,
|
||||
}
|
||||
|
||||
queue = deque([self.entry_point])
|
||||
while queue:
|
||||
current_node = queue.popleft()
|
||||
for index in self.nodes:
|
||||
|
||||
curr_time = time.time()
|
||||
with get_openai_callback() as callback:
|
||||
current_node = index
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
result = current_node.execute(state)
|
||||
node_exec_time = time.time() - curr_time
|
||||
total_exec_time += node_exec_time
|
||||
|
||||
cb = {
|
||||
"node_name": current_node.node_name,
|
||||
"total_tokens": callback.total_tokens,
|
||||
"prompt_tokens": callback.prompt_tokens,
|
||||
"completion_tokens": callback.completion_tokens,
|
||||
"successful_requests": callback.successful_requests,
|
||||
"total_cost_USD": callback.total_cost,
|
||||
"node_name": index.node_name,
|
||||
"total_tokens": cb.total_tokens,
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"successful_requests": cb.successful_requests,
|
||||
"total_cost_USD": cb.total_cost,
|
||||
"exec_time": node_exec_time,
|
||||
}
|
||||
|
||||
@ -132,31 +128,21 @@ class BaseGraph:
|
||||
cb_total["successful_requests"] += cb["successful_requests"]
|
||||
cb_total["total_cost_USD"] += cb["total_cost_USD"]
|
||||
|
||||
|
||||
|
||||
current_node_connections = self.edges[current_node]
|
||||
if current_node.node_type == 'conditional_node':
|
||||
# Assert that there are exactly two out edges from the conditional node
|
||||
if len(current_node_connections) != 2:
|
||||
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
|
||||
if result["next_node"] == 0:
|
||||
queue.append(current_node_connections[0])
|
||||
else:
|
||||
queue.append(current_node_connections[1])
|
||||
# remove the conditional node result
|
||||
del result["next_node"]
|
||||
else:
|
||||
queue.extend(node for node in current_node_connections)
|
||||
if current_node.node_type == "conditional_node":
|
||||
current_node_name = result
|
||||
elif current_node_name in self.edges:
|
||||
current_node_name = self.edges[current_node_name]
|
||||
else:
|
||||
current_node_name = None
|
||||
|
||||
exec_info.append({
|
||||
"node_name": "TOTAL RESULT",
|
||||
"total_tokens": cb_total["total_tokens"],
|
||||
"prompt_tokens": cb_total["prompt_tokens"],
|
||||
"completion_tokens": cb_total["completion_tokens"],
|
||||
"successful_requests": cb_total["successful_requests"],
|
||||
"total_cost_USD": cb_total["total_cost_USD"],
|
||||
"exec_time": total_exec_time,
|
||||
})
|
||||
|
||||
exec_info.append({
|
||||
"node_name": "TOTAL RESULT",
|
||||
"total_tokens": cb_total["total_tokens"],
|
||||
"prompt_tokens": cb_total["prompt_tokens"],
|
||||
"completion_tokens": cb_total["completion_tokens"],
|
||||
"successful_requests": cb_total["successful_requests"],
|
||||
"total_cost_USD": cb_total["total_cost_USD"],
|
||||
"exec_time": total_exec_time,
|
||||
})
|
||||
|
||||
return state, exec_info
|
||||
return state, exec_info
|
||||
Loading…
Reference in New Issue
Block a user