mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
feat(base_graph): alligned with main
This commit is contained in:
parent
02745a4f63
commit
73fa31db0f
@ -6,7 +6,6 @@ import time
|
|||||||
import warnings
|
import warnings
|
||||||
from langchain_community.callbacks import get_openai_callback
|
from langchain_community.callbacks import get_openai_callback
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
|
|
||||||
class BaseGraph:
|
class BaseGraph:
|
||||||
@ -27,8 +26,6 @@ class BaseGraph:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Warning: If the entry point node is not the first node in the list.
|
Warning: If the entry point node is not the first node in the list.
|
||||||
ValueError: If conditional_node does not have exactly two outgoing edges
|
|
||||||
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> BaseGraph(
|
>>> BaseGraph(
|
||||||
@ -51,7 +48,7 @@ class BaseGraph:
|
|||||||
|
|
||||||
self.nodes = nodes
|
self.nodes = nodes
|
||||||
self.edges = self._create_edges({e for e in edges})
|
self.edges = self._create_edges({e for e in edges})
|
||||||
self.entry_point = entry_point
|
self.entry_point = entry_point.node_name
|
||||||
|
|
||||||
if nodes[0].node_name != entry_point.node_name:
|
if nodes[0].node_name != entry_point.node_name:
|
||||||
# raise a warning if the entry point is not the first node in the list
|
# raise a warning if the entry point is not the first node in the list
|
||||||
@ -71,16 +68,13 @@ class BaseGraph:
|
|||||||
|
|
||||||
edge_dict = {}
|
edge_dict = {}
|
||||||
for from_node, to_node in edges:
|
for from_node, to_node in edges:
|
||||||
if from_node in edge_dict:
|
edge_dict[from_node.node_name] = to_node.node_name
|
||||||
edge_dict[from_node].append(to_node)
|
|
||||||
else:
|
|
||||||
edge_dict[from_node] = [to_node]
|
|
||||||
return edge_dict
|
return edge_dict
|
||||||
|
|
||||||
def execute(self, initial_state: dict) -> Tuple[dict, list]:
|
def execute(self, initial_state: dict) -> Tuple[dict, list]:
|
||||||
"""
|
"""
|
||||||
Executes the graph by traversing nodes in breadth-first order starting from the entry point.
|
Executes the graph by traversing nodes starting from the entry point. The execution
|
||||||
The execution follows the edges based on the result of each node's execution and continues until
|
follows the edges based on the result of each node's execution and continues until
|
||||||
it reaches a node with no outgoing edges.
|
it reaches a node with no outgoing edges.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -90,6 +84,7 @@ class BaseGraph:
|
|||||||
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
|
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
current_node_name = self.nodes[0]
|
||||||
state = initial_state
|
state = initial_state
|
||||||
|
|
||||||
# variables for tracking execution info
|
# variables for tracking execution info
|
||||||
@ -103,22 +98,23 @@ class BaseGraph:
|
|||||||
"total_cost_USD": 0.0,
|
"total_cost_USD": 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
queue = deque([self.entry_point])
|
for index in self.nodes:
|
||||||
while queue:
|
|
||||||
current_node = queue.popleft()
|
|
||||||
curr_time = time.time()
|
curr_time = time.time()
|
||||||
with get_openai_callback() as callback:
|
current_node = index
|
||||||
|
|
||||||
|
with get_openai_callback() as cb:
|
||||||
result = current_node.execute(state)
|
result = current_node.execute(state)
|
||||||
node_exec_time = time.time() - curr_time
|
node_exec_time = time.time() - curr_time
|
||||||
total_exec_time += node_exec_time
|
total_exec_time += node_exec_time
|
||||||
|
|
||||||
cb = {
|
cb = {
|
||||||
"node_name": current_node.node_name,
|
"node_name": index.node_name,
|
||||||
"total_tokens": callback.total_tokens,
|
"total_tokens": cb.total_tokens,
|
||||||
"prompt_tokens": callback.prompt_tokens,
|
"prompt_tokens": cb.prompt_tokens,
|
||||||
"completion_tokens": callback.completion_tokens,
|
"completion_tokens": cb.completion_tokens,
|
||||||
"successful_requests": callback.successful_requests,
|
"successful_requests": cb.successful_requests,
|
||||||
"total_cost_USD": callback.total_cost,
|
"total_cost_USD": cb.total_cost,
|
||||||
"exec_time": node_exec_time,
|
"exec_time": node_exec_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,31 +128,21 @@ class BaseGraph:
|
|||||||
cb_total["successful_requests"] += cb["successful_requests"]
|
cb_total["successful_requests"] += cb["successful_requests"]
|
||||||
cb_total["total_cost_USD"] += cb["total_cost_USD"]
|
cb_total["total_cost_USD"] += cb["total_cost_USD"]
|
||||||
|
|
||||||
|
if current_node.node_type == "conditional_node":
|
||||||
|
current_node_name = result
|
||||||
current_node_connections = self.edges[current_node]
|
elif current_node_name in self.edges:
|
||||||
if current_node.node_type == 'conditional_node':
|
current_node_name = self.edges[current_node_name]
|
||||||
# Assert that there are exactly two out edges from the conditional node
|
else:
|
||||||
if len(current_node_connections) != 2:
|
current_node_name = None
|
||||||
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
|
|
||||||
if result["next_node"] == 0:
|
|
||||||
queue.append(current_node_connections[0])
|
|
||||||
else:
|
|
||||||
queue.append(current_node_connections[1])
|
|
||||||
# remove the conditional node result
|
|
||||||
del result["next_node"]
|
|
||||||
else:
|
|
||||||
queue.extend(node for node in current_node_connections)
|
|
||||||
|
|
||||||
|
exec_info.append({
|
||||||
|
"node_name": "TOTAL RESULT",
|
||||||
|
"total_tokens": cb_total["total_tokens"],
|
||||||
|
"prompt_tokens": cb_total["prompt_tokens"],
|
||||||
|
"completion_tokens": cb_total["completion_tokens"],
|
||||||
|
"successful_requests": cb_total["successful_requests"],
|
||||||
|
"total_cost_USD": cb_total["total_cost_USD"],
|
||||||
|
"exec_time": total_exec_time,
|
||||||
|
})
|
||||||
|
|
||||||
exec_info.append({
|
return state, exec_info
|
||||||
"node_name": "TOTAL RESULT",
|
|
||||||
"total_tokens": cb_total["total_tokens"],
|
|
||||||
"prompt_tokens": cb_total["prompt_tokens"],
|
|
||||||
"completion_tokens": cb_total["completion_tokens"],
|
|
||||||
"successful_requests": cb_total["successful_requests"],
|
|
||||||
"total_cost_USD": cb_total["total_cost_USD"],
|
|
||||||
"exec_time": total_exec_time,
|
|
||||||
})
|
|
||||||
|
|
||||||
return state, exec_info
|
|
||||||
Loading…
Reference in New Issue
Block a user