From b2ebabd32f77889d1b2a2a2884b2902136d0ff11 Mon Sep 17 00:00:00 2001 From: VinciGit00 Date: Thu, 25 Apr 2024 19:45:49 +0200 Subject: [PATCH] Update base_graph.py --- scrapegraphai/graphs/base_graph.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index 192f1b3c..53a14e08 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -2,6 +2,7 @@ Module for creating the base graphs """ import time +import warnings from langchain_community.callbacks import get_openai_callback @@ -26,12 +27,19 @@ class BaseGraph: entry_point (BaseNode): The node instance that represents the entry point of the graph. """ - def __init__(self, nodes: list, edges: list): + def __init__(self, nodes: list, edges: dict, entry_point: str): """ Initializes the graph with nodes, edges, and the entry point. """ - self.nodes = nodes + + self.nodes = {node.node_name: node for node in nodes} self.edges = self._create_edges(edges) + self.entry_point = entry_point.node_name + + if nodes[0].node_name != entry_point.node_name: + # raise a warning if the entry point is not the first node in the list + warnings.warn( + "Careful! The entry point node is different from the first node if the graph.") def _create_edges(self, edges: list) -> dict: """