diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index 192f1b3c..53a14e08 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -2,6 +2,7 @@ Module for creating the base graphs """ import time +import warnings from langchain_community.callbacks import get_openai_callback @@ -26,12 +27,19 @@ class BaseGraph: entry_point (BaseNode): The node instance that represents the entry point of the graph. """ - def __init__(self, nodes: list, edges: list): + def __init__(self, nodes: list, edges: dict, entry_point: str): """ Initializes the graph with nodes, edges, and the entry point. """ - self.nodes = nodes + + self.nodes = {node.node_name: node for node in nodes} self.edges = self._create_edges(edges) + self.entry_point = entry_point.node_name + + if nodes[0].node_name != entry_point.node_name: + # raise a warning if the entry point is not the first node in the list + warnings.warn( + "Careful! The entry point node is different from the first node if the graph.") def _create_edges(self, edges: list) -> dict: """