mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
feat(burr-node): working burr bridge
This commit is contained in:
parent
cfaf7eebd8
commit
654a042396
@ -90,6 +90,7 @@ graph = BaseGraph(
|
|||||||
entry_point=fetch_node,
|
entry_point=fetch_node,
|
||||||
use_burr=True,
|
use_burr=True,
|
||||||
burr_config={
|
burr_config={
|
||||||
|
"project_name": "smart-scraper-graph",
|
||||||
"app_instance_id": str(uuid.uuid4()),
|
"app_instance_id": str(uuid.uuid4()),
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"),
|
"llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"),
|
||||||
@ -101,9 +102,9 @@ graph = BaseGraph(
|
|||||||
# Execute the graph
|
# Execute the graph
|
||||||
# ************************************************
|
# ************************************************
|
||||||
|
|
||||||
result, execution_info = graph.execute({
|
result, exec_info = graph.execute({
|
||||||
"user_prompt": "Describe the content",
|
"user_prompt": "List me all the projects with their description",
|
||||||
"url": "https://example.com/"
|
"url": "https://perinim.github.io/projects/"
|
||||||
})
|
})
|
||||||
|
|
||||||
# get the answer from the result
|
# get the answer from the result
|
||||||
|
|||||||
@ -29,7 +29,7 @@ dependencies = [
|
|||||||
"playwright==1.43.0",
|
"playwright==1.43.0",
|
||||||
"google==3.0.0",
|
"google==3.0.0",
|
||||||
"yahoo-search-py==0.3",
|
"yahoo-search-py==0.3",
|
||||||
"burr[start]"
|
"burr[start]==0.17.1"
|
||||||
]
|
]
|
||||||
|
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
# features: []
|
# features: []
|
||||||
# all-features: false
|
# all-features: false
|
||||||
# with-sources: false
|
# with-sources: false
|
||||||
|
# generate-hashes: false
|
||||||
|
|
||||||
-e file:.
|
-e file:.
|
||||||
aiofiles==23.2.1
|
aiofiles==23.2.1
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
# features: []
|
# features: []
|
||||||
# all-features: false
|
# all-features: false
|
||||||
# with-sources: false
|
# with-sources: false
|
||||||
|
# generate-hashes: false
|
||||||
|
|
||||||
-e file:.
|
-e file:.
|
||||||
aiofiles==23.2.1
|
aiofiles==23.2.1
|
||||||
|
|||||||
@ -164,6 +164,7 @@ class BaseGraph:
|
|||||||
self.initial_state = initial_state
|
self.initial_state = initial_state
|
||||||
if self.use_burr:
|
if self.use_burr:
|
||||||
bridge = BurrBridge(self, self.burr_config)
|
bridge = BurrBridge(self, self.burr_config)
|
||||||
return bridge.execute(initial_state)
|
result = bridge.execute(initial_state)
|
||||||
|
return (result["_state"], [])
|
||||||
else:
|
else:
|
||||||
return self._execute_standard(initial_state)
|
return self._execute_standard(initial_state)
|
||||||
@ -1,16 +0,0 @@
|
|||||||
digraph {
|
|
||||||
graph [compound=false concentrate=false rankdir=TB ranksep=0.4]
|
|
||||||
fetch_node [label=fetch_node shape=box style=rounded]
|
|
||||||
parse_node [label=parse_node shape=box style=rounded]
|
|
||||||
rag_node [label=rag_node shape=box style=rounded]
|
|
||||||
input__llm_model [label="input: llm_model" shape=oval style=dashed]
|
|
||||||
input__llm_model -> rag_node
|
|
||||||
input__embedder_model [label="input: embedder_model" shape=oval style=dashed]
|
|
||||||
input__embedder_model -> rag_node
|
|
||||||
generate_answer_node [label=generate_answer_node shape=box style=rounded]
|
|
||||||
input__llm_model [label="input: llm_model" shape=oval style=dashed]
|
|
||||||
input__llm_model -> generate_answer_node
|
|
||||||
fetch_node -> parse_node [style=solid]
|
|
||||||
parse_node -> rag_node [style=solid]
|
|
||||||
rag_node -> generate_answer_node [style=solid]
|
|
||||||
}
|
|
||||||
@ -8,7 +8,6 @@ from typing import Any, Dict, List, Tuple
|
|||||||
|
|
||||||
from burr import tracking
|
from burr import tracking
|
||||||
from burr.core import Application, ApplicationBuilder, State, Action, default
|
from burr.core import Application, ApplicationBuilder, State, Action, default
|
||||||
from burr.core.action import action
|
|
||||||
from burr.lifecycle import PostRunStepHook, PreRunStepHook
|
from burr.lifecycle import PostRunStepHook, PreRunStepHook
|
||||||
|
|
||||||
|
|
||||||
@ -40,7 +39,7 @@ class BurrNodeBridge(Action):
|
|||||||
return parse_boolean_expression(self.node.input)
|
return parse_boolean_expression(self.node.input)
|
||||||
|
|
||||||
def run(self, state: State, **run_kwargs) -> dict:
|
def run(self, state: State, **run_kwargs) -> dict:
|
||||||
node_inputs = {key: state[key] for key in self.reads}
|
node_inputs = {key: state[key] for key in self.reads if key in state}
|
||||||
result_state = self.node.execute(node_inputs, **run_kwargs)
|
result_state = self.node.execute(node_inputs, **run_kwargs)
|
||||||
return result_state
|
return result_state
|
||||||
|
|
||||||
@ -49,7 +48,7 @@ class BurrNodeBridge(Action):
|
|||||||
return self.node.output
|
return self.node.output
|
||||||
|
|
||||||
def update(self, result: dict, state: State) -> State:
|
def update(self, result: dict, state: State) -> State:
|
||||||
return state.update(**state)
|
return state.update(**result)
|
||||||
|
|
||||||
|
|
||||||
def parse_boolean_expression(expression: str) -> List[str]:
|
def parse_boolean_expression(expression: str) -> List[str]:
|
||||||
@ -92,7 +91,8 @@ class BurrBridge:
|
|||||||
def __init__(self, base_graph, burr_config):
|
def __init__(self, base_graph, burr_config):
|
||||||
self.base_graph = base_graph
|
self.base_graph = base_graph
|
||||||
self.burr_config = burr_config
|
self.burr_config = burr_config
|
||||||
self.tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
|
self.project_name = burr_config.get("project_name", "default-project")
|
||||||
|
self.tracker = tracking.LocalTrackingClient(project=self.project_name)
|
||||||
self.app_instance_id = burr_config.get("app_instance_id", "default-instance")
|
self.app_instance_id = burr_config.get("app_instance_id", "default-instance")
|
||||||
self.burr_inputs = burr_config.get("inputs", {})
|
self.burr_inputs = burr_config.get("inputs", {})
|
||||||
self.burr_app = None
|
self.burr_app = None
|
||||||
@ -111,7 +111,7 @@ class BurrBridge:
|
|||||||
actions = self._create_actions()
|
actions = self._create_actions()
|
||||||
transitions = self._create_transitions()
|
transitions = self._create_transitions()
|
||||||
hooks = [PrintLnHook()]
|
hooks = [PrintLnHook()]
|
||||||
burr_state = self._convert_state_to_burr(initial_state)
|
burr_state = State(initial_state)
|
||||||
|
|
||||||
app = (
|
app = (
|
||||||
ApplicationBuilder()
|
ApplicationBuilder()
|
||||||
@ -136,32 +136,10 @@ class BurrBridge:
|
|||||||
|
|
||||||
actions = {}
|
actions = {}
|
||||||
for node in self.base_graph.nodes:
|
for node in self.base_graph.nodes:
|
||||||
action_func = self._create_action(node)
|
action_func = BurrNodeBridge(node)
|
||||||
actions[node.node_name] = action_func
|
actions[node.node_name] = action_func
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
def _create_action(self, node) -> Any:
|
|
||||||
"""
|
|
||||||
Create a Burr action function from a base graph node.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node (Node): The base graph node to convert to a Burr action.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
function: The Burr action function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# @action(reads=parse_boolean_expression(node.input), writes=node.output)
|
|
||||||
# def dynamic_action(state: State, **kwargs):
|
|
||||||
# node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)}
|
|
||||||
# result_state = node.execute(node_inputs, **kwargs)
|
|
||||||
# return result_state, state.update(**result_state)
|
|
||||||
#
|
|
||||||
# return dynamic_action
|
|
||||||
# import pdb
|
|
||||||
# pdb.set_trace()
|
|
||||||
return BurrNodeBridge(node)
|
|
||||||
|
|
||||||
def _create_transitions(self) -> List[Tuple[str, str, Any]]:
|
def _create_transitions(self) -> List[Tuple[str, str, Any]]:
|
||||||
"""
|
"""
|
||||||
Create Burr transitions from the base graph edges.
|
Create Burr transitions from the base graph edges.
|
||||||
@ -175,22 +153,6 @@ class BurrBridge:
|
|||||||
transitions.append((from_node, to_node, default))
|
transitions.append((from_node, to_node, default))
|
||||||
return transitions
|
return transitions
|
||||||
|
|
||||||
def _convert_state_to_burr(self, state: Dict[str, Any]) -> State:
|
|
||||||
"""
|
|
||||||
Convert a dictionary state to a Burr state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (dict): The dictionary state to convert.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
State: The Burr state instance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
burr_state = State()
|
|
||||||
for key, value in state.items():
|
|
||||||
setattr(burr_state, key, value)
|
|
||||||
return burr_state
|
|
||||||
|
|
||||||
def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]:
|
def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert a Burr state to a dictionary state.
|
Convert a Burr state to a dictionary state.
|
||||||
@ -223,7 +185,6 @@ class BurrBridge:
|
|||||||
# TODO: to fix final nodes detection
|
# TODO: to fix final nodes detection
|
||||||
final_nodes = [self.burr_app.graph.actions[-1].name]
|
final_nodes = [self.burr_app.graph.actions[-1].name]
|
||||||
|
|
||||||
# TODO: fix inputs
|
|
||||||
last_action, result, final_state = self.burr_app.run(
|
last_action, result, final_state = self.burr_app.run(
|
||||||
halt_after=final_nodes,
|
halt_after=final_nodes,
|
||||||
inputs=self.burr_inputs
|
inputs=self.burr_inputs
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user