mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
feat: update base_graph
This commit is contained in:
parent
66a29bc5cc
commit
0571b6da55
@ -1,7 +1,11 @@
|
|||||||
|
"""
|
||||||
|
base_graph module
|
||||||
|
"""
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from langchain_community.callbacks import get_openai_callback
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
from langchain_community.callbacks import get_openai_callback
|
||||||
|
from ..integrations import BurrBridge
|
||||||
|
|
||||||
# Import telemetry functions
|
# Import telemetry functions
|
||||||
from ..telemetry import log_graph_execution, log_event
|
from ..telemetry import log_graph_execution, log_event
|
||||||
@ -56,7 +60,7 @@ class BaseGraph:
|
|||||||
# raise a warning if the entry point is not the first node in the list
|
# raise a warning if the entry point is not the first node in the list
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Careful! The entry point node is different from the first node in the graph.")
|
"Careful! The entry point node is different from the first node in the graph.")
|
||||||
|
|
||||||
# Burr configuration
|
# Burr configuration
|
||||||
self.use_burr = use_burr
|
self.use_burr = use_burr
|
||||||
self.burr_config = burr_config or {}
|
self.burr_config = burr_config or {}
|
||||||
@ -79,7 +83,8 @@ class BaseGraph:
|
|||||||
|
|
||||||
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
|
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
|
||||||
"""
|
"""
|
||||||
Executes the graph by traversing nodes starting from the entry point using the standard method.
|
Executes the graph by traversing nodes starting from the
|
||||||
|
entry point using the standard method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
initial_state (dict): The initial state to pass to the entry point node.
|
initial_state (dict): The initial state to pass to the entry point node.
|
||||||
@ -114,23 +119,25 @@ class BaseGraph:
|
|||||||
curr_time = time.time()
|
curr_time = time.time()
|
||||||
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
|
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
|
||||||
|
|
||||||
|
|
||||||
# check if there is a "source" key in the node config
|
# check if there is a "source" key in the node config
|
||||||
if current_node.__class__.__name__ == "FetchNode":
|
if current_node.__class__.__name__ == "FetchNode":
|
||||||
# get the second key name of the state dictionary
|
# get the second key name of the state dictionary
|
||||||
source_type = list(state.keys())[1]
|
source_type = list(state.keys())[1]
|
||||||
if state.get("user_prompt", None):
|
if state.get("user_prompt", None):
|
||||||
prompt = state["user_prompt"] if type(state["user_prompt"]) == str else None
|
# Set 'prompt' if 'user_prompt' is a string, otherwise None
|
||||||
# quick fix for local_dir source type
|
prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
|
||||||
|
|
||||||
|
# Convert 'local_dir' source type to 'html_dir'
|
||||||
if source_type == "local_dir":
|
if source_type == "local_dir":
|
||||||
source_type = "html_dir"
|
source_type = "html_dir"
|
||||||
elif source_type == "url":
|
elif source_type == "url":
|
||||||
if type(state[source_type]) == list:
|
# If the source is a list, add string URLs to 'source'
|
||||||
# iterate through the list of urls and see if they are strings
|
if isinstance(state[source_type], list):
|
||||||
for url in state[source_type]:
|
for url in state[source_type]:
|
||||||
if type(url) == str:
|
if isinstance(url, str):
|
||||||
source.append(url)
|
source.append(url)
|
||||||
elif type(state[source_type]) == str:
|
# If the source is a single string, add it to 'source'
|
||||||
|
elif isinstance(state[source_type], str):
|
||||||
source.append(state[source_type])
|
source.append(state[source_type])
|
||||||
|
|
||||||
# check if there is an "llm_model" variable in the class
|
# check if there is an "llm_model" variable in the class
|
||||||
@ -164,7 +171,6 @@ class BaseGraph:
|
|||||||
result = current_node.execute(state)
|
result = current_node.execute(state)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_node = current_node.node_name
|
error_node = current_node.node_name
|
||||||
|
|
||||||
graph_execution_time = time.time() - start_time
|
graph_execution_time = time.time() - start_time
|
||||||
log_graph_execution(
|
log_graph_execution(
|
||||||
graph_name=self.graph_name,
|
graph_name=self.graph_name,
|
||||||
@ -221,7 +227,7 @@ class BaseGraph:
|
|||||||
graph_execution_time = time.time() - start_time
|
graph_execution_time = time.time() - start_time
|
||||||
response = state.get("answer", None) if source_type == "url" else None
|
response = state.get("answer", None) if source_type == "url" else None
|
||||||
content = state.get("parsed_doc", None) if response is not None else None
|
content = state.get("parsed_doc", None) if response is not None else None
|
||||||
|
|
||||||
log_graph_execution(
|
log_graph_execution(
|
||||||
graph_name=self.graph_name,
|
graph_name=self.graph_name,
|
||||||
source=source,
|
source=source,
|
||||||
@ -251,14 +257,13 @@ class BaseGraph:
|
|||||||
|
|
||||||
self.initial_state = initial_state
|
self.initial_state = initial_state
|
||||||
if self.use_burr:
|
if self.use_burr:
|
||||||
from ..integrations import BurrBridge
|
|
||||||
|
|
||||||
bridge = BurrBridge(self, self.burr_config)
|
bridge = BurrBridge(self, self.burr_config)
|
||||||
result = bridge.execute(initial_state)
|
result = bridge.execute(initial_state)
|
||||||
return (result["_state"], [])
|
return (result["_state"], [])
|
||||||
else:
|
else:
|
||||||
return self._execute_standard(initial_state)
|
return self._execute_standard(initial_state)
|
||||||
|
|
||||||
def append_node(self, node):
|
def append_node(self, node):
|
||||||
"""
|
"""
|
||||||
Adds a node to the graph.
|
Adds a node to the graph.
|
||||||
@ -266,11 +271,11 @@ class BaseGraph:
|
|||||||
Args:
|
Args:
|
||||||
node (BaseNode): The node instance to add to the graph.
|
node (BaseNode): The node instance to add to the graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# if node name already exists in the graph, raise an exception
|
# if node name already exists in the graph, raise an exception
|
||||||
if node.node_name in {n.node_name for n in self.nodes}:
|
if node.node_name in {n.node_name for n in self.nodes}:
|
||||||
raise ValueError(f"Node with name '{node.node_name}' already exists in the graph. You can change it by setting the 'node_name' attribute.")
|
raise ValueError(f"Node with name '{node.node_name}' already exists in the graph. You can change it by setting the 'node_name' attribute.")
|
||||||
|
|
||||||
# get the last node in the list
|
# get the last node in the list
|
||||||
last_node = self.nodes[-1]
|
last_node = self.nodes[-1]
|
||||||
# add the edge connecting the last node to the new node
|
# add the edge connecting the last node to the new node
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user