refactoring of the graphs

This commit is contained in:
VinciGit00 2024-04-26 11:08:56 +02:00
parent b2ebabd32f
commit cd8d3e7a4f
7 changed files with 33 additions and 27 deletions

View File

@ -69,12 +69,13 @@ graph = BaseGraph(
rag_node,
generate_answer_node,
],
edges={
edges=[
(robot_node, fetch_node),
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
},
],
entry_point=robot_node
)
# ************************************************

View File

@ -13,6 +13,8 @@ pylint pylint scrapegraphai/**/*.py scrapegraphai/*.py tests/**/*.py
cd tests
poetry install
# Run pytest
if ! pytest; then
echo "Pytest failed. Aborting commit and push."

View File

@ -11,29 +11,29 @@ class BaseGraph:
BaseGraph manages the execution flow of a graph composed of interconnected nodes.
Attributes:
nodes (dict): A dictionary mapping each node's name to its corresponding node instance.
edges (dict): A dictionary representing the directed edges of the graph where each
nodes (list): A dictionary mapping each node's name to its corresponding node instance.
edges (list): A dictionary representing the directed edges of the graph where each
key-value pair corresponds to the from-node and to-node relationship.
entry_point (str): The name of the entry point node from which the graph execution begins.
Methods:
execute(initial_state): Executes the graph's nodes starting from the entry point and
execute(initial_state): Executes the graph's nodes starting from the entry point and
traverses the graph based on the provided initial state.
Args:
nodes (iterable): An iterable of node instances that will be part of the graph.
edges (iterable): An iterable of tuples where each tuple represents a directed edge
edges (iterable): An iterable of tuples where each tuple represents a directed edge
in the graph, defined by a pair of nodes (from_node, to_node).
entry_point (BaseNode): The node instance that represents the entry point of the graph.
"""
def __init__(self, nodes: list, edges: dict, entry_point: str):
def __init__(self, nodes: list, edges: list, entry_point: str):
"""
Initializes the graph with nodes, edges, and the entry point.
"""
self.nodes = {node.node_name: node for node in nodes}
self.edges = self._create_edges(edges)
self.nodes = nodes
self.edges = self._create_edges({e for e in edges})
self.entry_point = entry_point.node_name
if nodes[0].node_name != entry_point.node_name:
@ -58,8 +58,8 @@ class BaseGraph:
def execute(self, initial_state: dict) -> dict:
"""
Executes the graph by traversing nodes starting from the entry point. The execution
follows the edges based on the result of each node's execution and continues until
Executes the graph by traversing nodes starting from the entry point. The execution
follows the edges based on the result of each node's execution and continues until
it reaches a node with no outgoing edges.
Args:
@ -68,6 +68,7 @@ class BaseGraph:
Returns:
dict: The state after execution has completed, which may have been altered by the nodes.
"""
print(self.nodes)
current_node_name = self.nodes[0]
state = initial_state

View File

@ -1,4 +1,4 @@
"""
"""
Module for creating the smart scraper
"""
from .base_graph import BaseGraph
@ -57,17 +57,17 @@ class ScriptCreatorGraph(AbstractGraph):
)
return BaseGraph(
nodes={
nodes=[
fetch_node,
parse_node,
rag_node,
generate_scraper_node,
},
edges={
],
edges=[
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_scraper_node)
},
],
entry_point=fetch_node
)

View File

@ -11,6 +11,7 @@ from ..nodes import (
)
from .abstract_graph import AbstractGraph
class SearchGraph(AbstractGraph):
"""
Module for searching info on the internet
@ -49,19 +50,19 @@ class SearchGraph(AbstractGraph):
)
return BaseGraph(
nodes={
nodes=[
search_internet_node,
fetch_node,
parse_node,
rag_node,
generate_answer_node,
},
edges={
],
edges=[
(search_internet_node, fetch_node),
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
},
],
entry_point=search_internet_node
)

View File

@ -1,4 +1,4 @@
"""
"""
Module for creating the smart scraper
"""
from .base_graph import BaseGraph
@ -59,11 +59,12 @@ class SmartScraperGraph(AbstractGraph):
rag_node,
generate_answer_node,
],
edges={
edges=[
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
}
],
entry_point=fetch_node
)
def run(self) -> str:

View File

@ -62,19 +62,19 @@ class SpeechGraph(AbstractGraph):
)
return BaseGraph(
nodes={
nodes=[
fetch_node,
parse_node,
rag_node,
generate_answer_node,
text_to_speech_node
},
edges={
],
edges=[
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node),
(generate_answer_node, text_to_speech_node)
},
],
entry_point=fetch_node
)