refactoring of engine

This commit is contained in:
VinciGit00 2024-04-25 19:22:12 +02:00
parent 1b004d82a1
commit e714a59c2e
3 changed files with 13 additions and 15 deletions

View File

@ -62,20 +62,19 @@ generate_answer_node = GenerateAnswerNode(
# ************************************************
graph = BaseGraph(
nodes={
nodes=[
robot_node,
fetch_node,
parse_node,
rag_node,
generate_answer_node,
},
],
edges={
(robot_node, fetch_node),
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
},
entry_point=robot_node
)
# ************************************************

View File

@ -26,15 +26,14 @@ class BaseGraph:
entry_point (BaseNode): The node instance that represents the entry point of the graph.
"""
def __init__(self, nodes: dict, edges: dict, entry_point: str):
def __init__(self, nodes: list, edges: list):
"""
Initializes the graph with nodes, edges, and the entry point.
"""
self.nodes = {node.node_name: node for node in nodes}
self.nodes = nodes
self.edges = self._create_edges(edges)
self.entry_point = entry_point.node_name
def _create_edges(self, edges: dict) -> dict:
def _create_edges(self, edges: list) -> dict:
"""
Helper method to create a dictionary of edges from the given iterable of tuples.
@ -61,7 +60,7 @@ class BaseGraph:
Returns:
dict: The state after execution has completed, which may have been altered by the nodes.
"""
current_node_name = self.entry_point
current_node_name = self.nodes[0]
state = initial_state
# variables for tracking execution info
@ -75,10 +74,10 @@ class BaseGraph:
"total_cost_USD": 0.0,
}
while current_node_name is not None:
for index in self.nodes:
curr_time = time.time()
current_node = self.nodes[current_node_name]
current_node = index
with get_openai_callback() as cb:
result = current_node.execute(state)

View File

@ -10,6 +10,7 @@ from ..nodes import (
)
from .abstract_graph import AbstractGraph
class SmartScraperGraph(AbstractGraph):
"""
SmartScraper is a comprehensive web scraping tool that automates the process of extracting
@ -52,25 +53,24 @@ class SmartScraperGraph(AbstractGraph):
)
return BaseGraph(
nodes={
nodes=[
fetch_node,
parse_node,
rag_node,
generate_answer_node,
},
],
edges={
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
},
entry_point=fetch_node
}
)
def run(self) -> str:
"""
Executes the web scraping process and returns the answer to the prompt.
"""
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)
return self.final_state.get("answer", "No answer found.")