mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
feat(burr-node): working burr bridge
This commit is contained in:
parent
cfaf7eebd8
commit
654a042396
@ -90,6 +90,7 @@ graph = BaseGraph(
|
||||
entry_point=fetch_node,
|
||||
use_burr=True,
|
||||
burr_config={
|
||||
"project_name": "smart-scraper-graph",
|
||||
"app_instance_id": str(uuid.uuid4()),
|
||||
"inputs": {
|
||||
"llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"),
|
||||
@ -101,9 +102,9 @@ graph = BaseGraph(
|
||||
# Execute the graph
|
||||
# ************************************************
|
||||
|
||||
result, execution_info = graph.execute({
|
||||
"user_prompt": "Describe the content",
|
||||
"url": "https://example.com/"
|
||||
result, exec_info = graph.execute({
|
||||
"user_prompt": "List me all the projects with their description",
|
||||
"url": "https://perinim.github.io/projects/"
|
||||
})
|
||||
|
||||
# get the answer from the result
|
||||
|
||||
@ -29,7 +29,7 @@ dependencies = [
|
||||
"playwright==1.43.0",
|
||||
"google==3.0.0",
|
||||
"yahoo-search-py==0.3",
|
||||
"burr[start]"
|
||||
"burr[start]==0.17.1"
|
||||
]
|
||||
|
||||
license = "MIT"
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
# generate-hashes: false
|
||||
|
||||
-e file:.
|
||||
aiofiles==23.2.1
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
# generate-hashes: false
|
||||
|
||||
-e file:.
|
||||
aiofiles==23.2.1
|
||||
|
||||
@ -164,6 +164,7 @@ class BaseGraph:
|
||||
self.initial_state = initial_state
|
||||
if self.use_burr:
|
||||
bridge = BurrBridge(self, self.burr_config)
|
||||
return bridge.execute(initial_state)
|
||||
result = bridge.execute(initial_state)
|
||||
return (result["_state"], [])
|
||||
else:
|
||||
return self._execute_standard(initial_state)
|
||||
@ -1,16 +0,0 @@
|
||||
digraph {
|
||||
graph [compound=false concentrate=false rankdir=TB ranksep=0.4]
|
||||
fetch_node [label=fetch_node shape=box style=rounded]
|
||||
parse_node [label=parse_node shape=box style=rounded]
|
||||
rag_node [label=rag_node shape=box style=rounded]
|
||||
input__llm_model [label="input: llm_model" shape=oval style=dashed]
|
||||
input__llm_model -> rag_node
|
||||
input__embedder_model [label="input: embedder_model" shape=oval style=dashed]
|
||||
input__embedder_model -> rag_node
|
||||
generate_answer_node [label=generate_answer_node shape=box style=rounded]
|
||||
input__llm_model [label="input: llm_model" shape=oval style=dashed]
|
||||
input__llm_model -> generate_answer_node
|
||||
fetch_node -> parse_node [style=solid]
|
||||
parse_node -> rag_node [style=solid]
|
||||
rag_node -> generate_answer_node [style=solid]
|
||||
}
|
||||
@ -8,7 +8,6 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
from burr import tracking
|
||||
from burr.core import Application, ApplicationBuilder, State, Action, default
|
||||
from burr.core.action import action
|
||||
from burr.lifecycle import PostRunStepHook, PreRunStepHook
|
||||
|
||||
|
||||
@ -40,7 +39,7 @@ class BurrNodeBridge(Action):
|
||||
return parse_boolean_expression(self.node.input)
|
||||
|
||||
def run(self, state: State, **run_kwargs) -> dict:
|
||||
node_inputs = {key: state[key] for key in self.reads}
|
||||
node_inputs = {key: state[key] for key in self.reads if key in state}
|
||||
result_state = self.node.execute(node_inputs, **run_kwargs)
|
||||
return result_state
|
||||
|
||||
@ -49,7 +48,7 @@ class BurrNodeBridge(Action):
|
||||
return self.node.output
|
||||
|
||||
def update(self, result: dict, state: State) -> State:
|
||||
return state.update(**state)
|
||||
return state.update(**result)
|
||||
|
||||
|
||||
def parse_boolean_expression(expression: str) -> List[str]:
|
||||
@ -92,7 +91,8 @@ class BurrBridge:
|
||||
def __init__(self, base_graph, burr_config):
|
||||
self.base_graph = base_graph
|
||||
self.burr_config = burr_config
|
||||
self.tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
|
||||
self.project_name = burr_config.get("project_name", "default-project")
|
||||
self.tracker = tracking.LocalTrackingClient(project=self.project_name)
|
||||
self.app_instance_id = burr_config.get("app_instance_id", "default-instance")
|
||||
self.burr_inputs = burr_config.get("inputs", {})
|
||||
self.burr_app = None
|
||||
@ -111,7 +111,7 @@ class BurrBridge:
|
||||
actions = self._create_actions()
|
||||
transitions = self._create_transitions()
|
||||
hooks = [PrintLnHook()]
|
||||
burr_state = self._convert_state_to_burr(initial_state)
|
||||
burr_state = State(initial_state)
|
||||
|
||||
app = (
|
||||
ApplicationBuilder()
|
||||
@ -136,32 +136,10 @@ class BurrBridge:
|
||||
|
||||
actions = {}
|
||||
for node in self.base_graph.nodes:
|
||||
action_func = self._create_action(node)
|
||||
action_func = BurrNodeBridge(node)
|
||||
actions[node.node_name] = action_func
|
||||
return actions
|
||||
|
||||
def _create_action(self, node) -> Any:
|
||||
"""
|
||||
Create a Burr action function from a base graph node.
|
||||
|
||||
Args:
|
||||
node (Node): The base graph node to convert to a Burr action.
|
||||
|
||||
Returns:
|
||||
function: The Burr action function.
|
||||
"""
|
||||
|
||||
# @action(reads=parse_boolean_expression(node.input), writes=node.output)
|
||||
# def dynamic_action(state: State, **kwargs):
|
||||
# node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)}
|
||||
# result_state = node.execute(node_inputs, **kwargs)
|
||||
# return result_state, state.update(**result_state)
|
||||
#
|
||||
# return dynamic_action
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
return BurrNodeBridge(node)
|
||||
|
||||
def _create_transitions(self) -> List[Tuple[str, str, Any]]:
|
||||
"""
|
||||
Create Burr transitions from the base graph edges.
|
||||
@ -175,22 +153,6 @@ class BurrBridge:
|
||||
transitions.append((from_node, to_node, default))
|
||||
return transitions
|
||||
|
||||
def _convert_state_to_burr(self, state: Dict[str, Any]) -> State:
|
||||
"""
|
||||
Convert a dictionary state to a Burr state.
|
||||
|
||||
Args:
|
||||
state (dict): The dictionary state to convert.
|
||||
|
||||
Returns:
|
||||
State: The Burr state instance.
|
||||
"""
|
||||
|
||||
burr_state = State()
|
||||
for key, value in state.items():
|
||||
setattr(burr_state, key, value)
|
||||
return burr_state
|
||||
|
||||
def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Burr state to a dictionary state.
|
||||
@ -223,7 +185,6 @@ class BurrBridge:
|
||||
# TODO: to fix final nodes detection
|
||||
final_nodes = [self.burr_app.graph.actions[-1].name]
|
||||
|
||||
# TODO: fix inputs
|
||||
last_action, result, final_state = self.burr_app.run(
|
||||
halt_after=final_nodes,
|
||||
inputs=self.burr_inputs
|
||||
|
||||
Loading…
Reference in New Issue
Block a user