mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
feat: update base_graph
This commit is contained in:
parent
66a29bc5cc
commit
0571b6da55
@ -1,7 +1,11 @@
|
||||
"""
|
||||
base_graph module
|
||||
"""
|
||||
import time
|
||||
import warnings
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from typing import Tuple
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from ..integrations import BurrBridge
|
||||
|
||||
# Import telemetry functions
|
||||
from ..telemetry import log_graph_execution, log_event
|
||||
@ -56,7 +60,7 @@ class BaseGraph:
|
||||
# raise a warning if the entry point is not the first node in the list
|
||||
warnings.warn(
|
||||
"Careful! The entry point node is different from the first node in the graph.")
|
||||
|
||||
|
||||
# Burr configuration
|
||||
self.use_burr = use_burr
|
||||
self.burr_config = burr_config or {}
|
||||
@ -79,7 +83,8 @@ class BaseGraph:
|
||||
|
||||
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
|
||||
"""
|
||||
Executes the graph by traversing nodes starting from the entry point using the standard method.
|
||||
Executes the graph by traversing nodes starting from the
|
||||
entry point using the standard method.
|
||||
|
||||
Args:
|
||||
initial_state (dict): The initial state to pass to the entry point node.
|
||||
@ -114,23 +119,25 @@ class BaseGraph:
|
||||
curr_time = time.time()
|
||||
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
|
||||
|
||||
|
||||
# check if there is a "source" key in the node config
|
||||
if current_node.__class__.__name__ == "FetchNode":
|
||||
# get the second key name of the state dictionary
|
||||
source_type = list(state.keys())[1]
|
||||
if state.get("user_prompt", None):
|
||||
prompt = state["user_prompt"] if type(state["user_prompt"]) == str else None
|
||||
# quick fix for local_dir source type
|
||||
# Set 'prompt' if 'user_prompt' is a string, otherwise None
|
||||
prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
|
||||
|
||||
# Convert 'local_dir' source type to 'html_dir'
|
||||
if source_type == "local_dir":
|
||||
source_type = "html_dir"
|
||||
elif source_type == "url":
|
||||
if type(state[source_type]) == list:
|
||||
# iterate through the list of urls and see if they are strings
|
||||
# If the source is a list, add string URLs to 'source'
|
||||
if isinstance(state[source_type], list):
|
||||
for url in state[source_type]:
|
||||
if type(url) == str:
|
||||
if isinstance(url, str):
|
||||
source.append(url)
|
||||
elif type(state[source_type]) == str:
|
||||
# If the source is a single string, add it to 'source'
|
||||
elif isinstance(state[source_type], str):
|
||||
source.append(state[source_type])
|
||||
|
||||
# check if there is an "llm_model" variable in the class
|
||||
@ -164,7 +171,6 @@ class BaseGraph:
|
||||
result = current_node.execute(state)
|
||||
except Exception as e:
|
||||
error_node = current_node.node_name
|
||||
|
||||
graph_execution_time = time.time() - start_time
|
||||
log_graph_execution(
|
||||
graph_name=self.graph_name,
|
||||
@ -221,7 +227,7 @@ class BaseGraph:
|
||||
graph_execution_time = time.time() - start_time
|
||||
response = state.get("answer", None) if source_type == "url" else None
|
||||
content = state.get("parsed_doc", None) if response is not None else None
|
||||
|
||||
|
||||
log_graph_execution(
|
||||
graph_name=self.graph_name,
|
||||
source=source,
|
||||
@ -251,14 +257,13 @@ class BaseGraph:
|
||||
|
||||
self.initial_state = initial_state
|
||||
if self.use_burr:
|
||||
from ..integrations import BurrBridge
|
||||
|
||||
|
||||
bridge = BurrBridge(self, self.burr_config)
|
||||
result = bridge.execute(initial_state)
|
||||
return (result["_state"], [])
|
||||
else:
|
||||
return self._execute_standard(initial_state)
|
||||
|
||||
|
||||
def append_node(self, node):
|
||||
"""
|
||||
Adds a node to the graph.
|
||||
@ -266,11 +271,11 @@ class BaseGraph:
|
||||
Args:
|
||||
node (BaseNode): The node instance to add to the graph.
|
||||
"""
|
||||
|
||||
|
||||
# if node name already exists in the graph, raise an exception
|
||||
if node.node_name in {n.node_name for n in self.nodes}:
|
||||
raise ValueError(f"Node with name '{node.node_name}' already exists in the graph. You can change it by setting the 'node_name' attribute.")
|
||||
|
||||
|
||||
# get the last node in the list
|
||||
last_node = self.nodes[-1]
|
||||
# add the edge connecting the last node to the new node
|
||||
|
||||
Loading…
Reference in New Issue
Block a user