From e714a59c2e22909d2b2dd59e6800916fd66aa974 Mon Sep 17 00:00:00 2001 From: VinciGit00 Date: Thu, 25 Apr 2024 19:22:12 +0200 Subject: [PATCH] refactoring of engine --- examples/openai/custom_graph_openai.py | 5 ++--- scrapegraphai/graphs/base_graph.py | 13 ++++++------- scrapegraphai/graphs/smart_scraper_graph.py | 10 +++++----- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/examples/openai/custom_graph_openai.py b/examples/openai/custom_graph_openai.py index be5a4d55..6b3ee965 100644 --- a/examples/openai/custom_graph_openai.py +++ b/examples/openai/custom_graph_openai.py @@ -62,20 +62,19 @@ generate_answer_node = GenerateAnswerNode( # ************************************************ graph = BaseGraph( - nodes={ + nodes=[ robot_node, fetch_node, parse_node, rag_node, generate_answer_node, - }, + ], edges={ (robot_node, fetch_node), (fetch_node, parse_node), (parse_node, rag_node), (rag_node, generate_answer_node) }, - entry_point=robot_node ) # ************************************************ diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index 8df92b9a..192f1b3c 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -26,15 +26,14 @@ class BaseGraph: entry_point (BaseNode): The node instance that represents the entry point of the graph. """ - def __init__(self, nodes: dict, edges: dict, entry_point: str): + def __init__(self, nodes: list, edges: list): """ Initializes the graph with nodes, edges, and the entry point. """ - self.nodes = {node.node_name: node for node in nodes} + self.nodes = nodes self.edges = self._create_edges(edges) - self.entry_point = entry_point.node_name - def _create_edges(self, edges: dict) -> dict: + def _create_edges(self, edges: list) -> dict: """ Helper method to create a dictionary of edges from the given iterable of tuples. @@ -61,7 +60,7 @@ class BaseGraph: Returns: dict: The state after execution has completed, which may have been altered by the nodes. """ - current_node_name = self.entry_point + current_node_name = self.nodes[0] state = initial_state # variables for tracking execution info @@ -75,10 +74,10 @@ class BaseGraph: "total_cost_USD": 0.0, } - while current_node_name is not None: + for index in self.nodes: curr_time = time.time() - current_node = self.nodes[current_node_name] + current_node = index with get_openai_callback() as cb: result = current_node.execute(state) diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index e413727b..5cbc8067 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -10,6 +10,7 @@ from ..nodes import ( ) from .abstract_graph import AbstractGraph + class SmartScraperGraph(AbstractGraph): """ SmartScraper is a comprehensive web scraping tool that automates the process of extracting @@ -52,25 +53,24 @@ class SmartScraperGraph(AbstractGraph): ) return BaseGraph( - nodes={ + nodes=[ fetch_node, parse_node, rag_node, generate_answer_node, - }, + ], edges={ (fetch_node, parse_node), (parse_node, rag_node), (rag_node, generate_answer_node) - }, - entry_point=fetch_node + } ) def run(self) -> str: """ Executes the web scraping process and returns the answer to the prompt. """ - inputs = {"user_prompt": self.prompt, self.input_key: self.source} + inputs = {"user_prompt": self.prompt, self.input_key: self.source} self.final_state, self.execution_info = self.graph.execute(inputs) return self.final_state.get("answer", "No answer found.")