mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
feat: update base_graph
This commit is contained in:
parent
66a29bc5cc
commit
0571b6da55
@ -1,7 +1,11 @@
|
|||||||
|
"""
|
||||||
|
base_graph module
|
||||||
|
"""
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from langchain_community.callbacks import get_openai_callback
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
from langchain_community.callbacks import get_openai_callback
|
||||||
|
from ..integrations import BurrBridge
|
||||||
|
|
||||||
# Import telemetry functions
|
# Import telemetry functions
|
||||||
from ..telemetry import log_graph_execution, log_event
|
from ..telemetry import log_graph_execution, log_event
|
||||||
@ -79,7 +83,8 @@ class BaseGraph:
|
|||||||
|
|
||||||
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
|
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
|
||||||
"""
|
"""
|
||||||
Executes the graph by traversing nodes starting from the entry point using the standard method.
|
Executes the graph by traversing nodes starting from the
|
||||||
|
entry point using the standard method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
initial_state (dict): The initial state to pass to the entry point node.
|
initial_state (dict): The initial state to pass to the entry point node.
|
||||||
@ -114,23 +119,25 @@ class BaseGraph:
|
|||||||
curr_time = time.time()
|
curr_time = time.time()
|
||||||
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
|
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
|
||||||
|
|
||||||
|
|
||||||
# check if there is a "source" key in the node config
|
# check if there is a "source" key in the node config
|
||||||
if current_node.__class__.__name__ == "FetchNode":
|
if current_node.__class__.__name__ == "FetchNode":
|
||||||
# get the second key name of the state dictionary
|
# get the second key name of the state dictionary
|
||||||
source_type = list(state.keys())[1]
|
source_type = list(state.keys())[1]
|
||||||
if state.get("user_prompt", None):
|
if state.get("user_prompt", None):
|
||||||
prompt = state["user_prompt"] if type(state["user_prompt"]) == str else None
|
# Set 'prompt' if 'user_prompt' is a string, otherwise None
|
||||||
# quick fix for local_dir source type
|
prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
|
||||||
|
|
||||||
|
# Convert 'local_dir' source type to 'html_dir'
|
||||||
if source_type == "local_dir":
|
if source_type == "local_dir":
|
||||||
source_type = "html_dir"
|
source_type = "html_dir"
|
||||||
elif source_type == "url":
|
elif source_type == "url":
|
||||||
if type(state[source_type]) == list:
|
# If the source is a list, add string URLs to 'source'
|
||||||
# iterate through the list of urls and see if they are strings
|
if isinstance(state[source_type], list):
|
||||||
for url in state[source_type]:
|
for url in state[source_type]:
|
||||||
if type(url) == str:
|
if isinstance(url, str):
|
||||||
source.append(url)
|
source.append(url)
|
||||||
elif type(state[source_type]) == str:
|
# If the source is a single string, add it to 'source'
|
||||||
|
elif isinstance(state[source_type], str):
|
||||||
source.append(state[source_type])
|
source.append(state[source_type])
|
||||||
|
|
||||||
# check if there is an "llm_model" variable in the class
|
# check if there is an "llm_model" variable in the class
|
||||||
@ -164,7 +171,6 @@ class BaseGraph:
|
|||||||
result = current_node.execute(state)
|
result = current_node.execute(state)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_node = current_node.node_name
|
error_node = current_node.node_name
|
||||||
|
|
||||||
graph_execution_time = time.time() - start_time
|
graph_execution_time = time.time() - start_time
|
||||||
log_graph_execution(
|
log_graph_execution(
|
||||||
graph_name=self.graph_name,
|
graph_name=self.graph_name,
|
||||||
@ -251,7 +257,6 @@ class BaseGraph:
|
|||||||
|
|
||||||
self.initial_state = initial_state
|
self.initial_state = initial_state
|
||||||
if self.use_burr:
|
if self.use_burr:
|
||||||
from ..integrations import BurrBridge
|
|
||||||
|
|
||||||
bridge = BurrBridge(self, self.burr_config)
|
bridge = BurrBridge(self, self.burr_config)
|
||||||
result = bridge.execute(initial_state)
|
result = bridge.execute(initial_state)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user