Update base_graph.py

This commit is contained in:
VinciGit00 2024-04-25 19:45:49 +02:00
parent e714a59c2e
commit b2ebabd32f

View File

@ -2,6 +2,7 @@
Module for creating the base graphs
"""
import time
import warnings
from langchain_community.callbacks import get_openai_callback
@ -26,12 +27,19 @@ class BaseGraph:
entry_point (BaseNode): The node instance that represents the entry point of the graph.
"""
def __init__(self, nodes: list, edges: list):
def __init__(self, nodes: list, edges: dict, entry_point: str):
"""
Initializes the graph with nodes, edges, and the entry point.
"""
self.nodes = nodes
self.nodes = {node.node_name: node for node in nodes}
self.edges = self._create_edges(edges)
self.entry_point = entry_point.node_name
if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list
warnings.warn(
"Careful! The entry point node is different from the first node if the graph.")
def _create_edges(self, edges: list) -> dict:
"""