mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
refactoring of engine
This commit is contained in:
parent
1b004d82a1
commit
e714a59c2e
@ -62,20 +62,19 @@ generate_answer_node = GenerateAnswerNode(
|
||||
# ************************************************
|
||||
|
||||
graph = BaseGraph(
|
||||
nodes={
|
||||
nodes=[
|
||||
robot_node,
|
||||
fetch_node,
|
||||
parse_node,
|
||||
rag_node,
|
||||
generate_answer_node,
|
||||
},
|
||||
],
|
||||
edges={
|
||||
(robot_node, fetch_node),
|
||||
(fetch_node, parse_node),
|
||||
(parse_node, rag_node),
|
||||
(rag_node, generate_answer_node)
|
||||
},
|
||||
entry_point=robot_node
|
||||
)
|
||||
|
||||
# ************************************************
|
||||
|
||||
@ -26,15 +26,14 @@ class BaseGraph:
|
||||
entry_point (BaseNode): The node instance that represents the entry point of the graph.
|
||||
"""
|
||||
|
||||
def __init__(self, nodes: dict, edges: dict, entry_point: str):
|
||||
def __init__(self, nodes: list, edges: list):
|
||||
"""
|
||||
Initializes the graph with nodes, edges, and the entry point.
|
||||
"""
|
||||
self.nodes = {node.node_name: node for node in nodes}
|
||||
self.nodes = nodes
|
||||
self.edges = self._create_edges(edges)
|
||||
self.entry_point = entry_point.node_name
|
||||
|
||||
def _create_edges(self, edges: dict) -> dict:
|
||||
def _create_edges(self, edges: list) -> dict:
|
||||
"""
|
||||
Helper method to create a dictionary of edges from the given iterable of tuples.
|
||||
|
||||
@ -61,7 +60,7 @@ class BaseGraph:
|
||||
Returns:
|
||||
dict: The state after execution has completed, which may have been altered by the nodes.
|
||||
"""
|
||||
current_node_name = self.entry_point
|
||||
current_node_name = self.nodes[0]
|
||||
state = initial_state
|
||||
|
||||
# variables for tracking execution info
|
||||
@ -75,10 +74,10 @@ class BaseGraph:
|
||||
"total_cost_USD": 0.0,
|
||||
}
|
||||
|
||||
while current_node_name is not None:
|
||||
for index in self.nodes:
|
||||
|
||||
curr_time = time.time()
|
||||
current_node = self.nodes[current_node_name]
|
||||
current_node = index
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
result = current_node.execute(state)
|
||||
|
||||
@ -10,6 +10,7 @@ from ..nodes import (
|
||||
)
|
||||
from .abstract_graph import AbstractGraph
|
||||
|
||||
|
||||
class SmartScraperGraph(AbstractGraph):
|
||||
"""
|
||||
SmartScraper is a comprehensive web scraping tool that automates the process of extracting
|
||||
@ -52,25 +53,24 @@ class SmartScraperGraph(AbstractGraph):
|
||||
)
|
||||
|
||||
return BaseGraph(
|
||||
nodes={
|
||||
nodes=[
|
||||
fetch_node,
|
||||
parse_node,
|
||||
rag_node,
|
||||
generate_answer_node,
|
||||
},
|
||||
],
|
||||
edges={
|
||||
(fetch_node, parse_node),
|
||||
(parse_node, rag_node),
|
||||
(rag_node, generate_answer_node)
|
||||
},
|
||||
entry_point=fetch_node
|
||||
}
|
||||
)
|
||||
|
||||
def run(self) -> str:
|
||||
"""
|
||||
Executes the web scraping process and returns the answer to the prompt.
|
||||
"""
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user