From cd8d3e7a4f92f2b4e8a41e663c14bee73757ca86 Mon Sep 17 00:00:00 2001 From: VinciGit00 Date: Fri, 26 Apr 2024 11:08:56 +0200 Subject: [PATCH] refactoring of the graphs --- examples/openai/custom_graph_openai.py | 5 +++-- .../commit_and_push_with_tests.sh | 2 ++ scrapegraphai/graphs/base_graph.py | 19 ++++++++++--------- scrapegraphai/graphs/script_creator_graph.py | 10 +++++----- scrapegraphai/graphs/search_graph.py | 9 +++++---- scrapegraphai/graphs/smart_scraper_graph.py | 7 ++++--- scrapegraphai/graphs/speech_graph.py | 8 ++++---- 7 files changed, 33 insertions(+), 27 deletions(-) diff --git a/examples/openai/custom_graph_openai.py b/examples/openai/custom_graph_openai.py index 6b3ee965..175c51ab 100644 --- a/examples/openai/custom_graph_openai.py +++ b/examples/openai/custom_graph_openai.py @@ -69,12 +69,13 @@ graph = BaseGraph( rag_node, generate_answer_node, ], - edges={ + edges=[ (robot_node, fetch_node), (fetch_node, parse_node), (parse_node, rag_node), (rag_node, generate_answer_node) - }, + ], + entry_point=robot_node ) # ************************************************ diff --git a/manual deployment/commit_and_push_with_tests.sh b/manual deployment/commit_and_push_with_tests.sh index 9cf7c1af..d97fe67f 100755 --- a/manual deployment/commit_and_push_with_tests.sh +++ b/manual deployment/commit_and_push_with_tests.sh @@ -13,6 +13,8 @@ pylint pylint scrapegraphai/**/*.py scrapegraphai/*.py tests/**/*.py cd tests +poetry install + # Run pytest if ! pytest; then echo "Pytest failed. Aborting commit and push." diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index 53a14e08..c2ebfb0b 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -11,29 +11,29 @@ class BaseGraph: BaseGraph manages the execution flow of a graph composed of interconnected nodes. Attributes: - nodes (dict): A dictionary mapping each node's name to its corresponding node instance. - edges (dict): A dictionary representing the directed edges of the graph where each + nodes (list): A dictionary mapping each node's name to its corresponding node instance. + edges (list): A dictionary representing the directed edges of the graph where each key-value pair corresponds to the from-node and to-node relationship. entry_point (str): The name of the entry point node from which the graph execution begins. Methods: - execute(initial_state): Executes the graph's nodes starting from the entry point and + execute(initial_state): Executes the graph's nodes starting from the entry point and traverses the graph based on the provided initial state. Args: nodes (iterable): An iterable of node instances that will be part of the graph. - edges (iterable): An iterable of tuples where each tuple represents a directed edge + edges (iterable): An iterable of tuples where each tuple represents a directed edge in the graph, defined by a pair of nodes (from_node, to_node). entry_point (BaseNode): The node instance that represents the entry point of the graph. """ - def __init__(self, nodes: list, edges: dict, entry_point: str): + def __init__(self, nodes: list, edges: list, entry_point: str): """ Initializes the graph with nodes, edges, and the entry point. """ - self.nodes = {node.node_name: node for node in nodes} - self.edges = self._create_edges(edges) + self.nodes = nodes + self.edges = self._create_edges({e for e in edges}) self.entry_point = entry_point.node_name if nodes[0].node_name != entry_point.node_name: @@ -58,8 +58,8 @@ class BaseGraph: def execute(self, initial_state: dict) -> dict: """ - Executes the graph by traversing nodes starting from the entry point. The execution - follows the edges based on the result of each node's execution and continues until + Executes the graph by traversing nodes starting from the entry point. The execution + follows the edges based on the result of each node's execution and continues until it reaches a node with no outgoing edges. Args: @@ -68,6 +68,7 @@ class BaseGraph: Returns: dict: The state after execution has completed, which may have been altered by the nodes. """ + print(self.nodes) current_node_name = self.nodes[0] state = initial_state diff --git a/scrapegraphai/graphs/script_creator_graph.py b/scrapegraphai/graphs/script_creator_graph.py index 06cc7a81..fa86eeb4 100644 --- a/scrapegraphai/graphs/script_creator_graph.py +++ b/scrapegraphai/graphs/script_creator_graph.py @@ -1,4 +1,4 @@ -""" +""" Module for creating the smart scraper """ from .base_graph import BaseGraph @@ -57,17 +57,17 @@ class ScriptCreatorGraph(AbstractGraph): ) return BaseGraph( - nodes={ + nodes=[ fetch_node, parse_node, rag_node, generate_scraper_node, - }, - edges={ + ], + edges=[ (fetch_node, parse_node), (parse_node, rag_node), (rag_node, generate_scraper_node) - }, + ], entry_point=fetch_node ) diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py index ad21e485..b48965dd 100644 --- a/scrapegraphai/graphs/search_graph.py +++ b/scrapegraphai/graphs/search_graph.py @@ -11,6 +11,7 @@ from ..nodes import ( ) from .abstract_graph import AbstractGraph + class SearchGraph(AbstractGraph): """ Module for searching info on the internet @@ -49,19 +50,19 @@ class SearchGraph(AbstractGraph): ) return BaseGraph( - nodes={ + nodes=[ search_internet_node, fetch_node, parse_node, rag_node, generate_answer_node, - }, - edges={ + ], + edges=[ (search_internet_node, fetch_node), (fetch_node, parse_node), (parse_node, rag_node), (rag_node, generate_answer_node) - }, + ], entry_point=search_internet_node ) diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index 5cbc8067..5a520224 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -1,4 +1,4 @@ -""" +""" Module for creating the smart scraper """ from .base_graph import BaseGraph @@ -59,11 +59,12 @@ class SmartScraperGraph(AbstractGraph): rag_node, generate_answer_node, ], - edges={ + edges=[ (fetch_node, parse_node), (parse_node, rag_node), (rag_node, generate_answer_node) - } + ], + entry_point=fetch_node ) def run(self) -> str: diff --git a/scrapegraphai/graphs/speech_graph.py b/scrapegraphai/graphs/speech_graph.py index f050acb4..2b10077f 100644 --- a/scrapegraphai/graphs/speech_graph.py +++ b/scrapegraphai/graphs/speech_graph.py @@ -62,19 +62,19 @@ class SpeechGraph(AbstractGraph): ) return BaseGraph( - nodes={ + nodes=[ fetch_node, parse_node, rag_node, generate_answer_node, text_to_speech_node - }, - edges={ + ], + edges=[ (fetch_node, parse_node), (parse_node, rag_node), (rag_node, generate_answer_node), (generate_answer_node, text_to_speech_node) - }, + ], entry_point=fetch_node )