feat(burr-node): working burr bridge

This commit is contained in:
PeriniM 2024-05-22 00:24:38 +02:00
parent cfaf7eebd8
commit 654a042396
7 changed files with 15 additions and 66 deletions

View File

@ -90,6 +90,7 @@ graph = BaseGraph(
entry_point=fetch_node, entry_point=fetch_node,
use_burr=True, use_burr=True,
burr_config={ burr_config={
"project_name": "smart-scraper-graph",
"app_instance_id": str(uuid.uuid4()), "app_instance_id": str(uuid.uuid4()),
"inputs": { "inputs": {
"llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"), "llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"),
@ -101,9 +102,9 @@ graph = BaseGraph(
# Execute the graph # Execute the graph
# ************************************************ # ************************************************
result, execution_info = graph.execute({ result, exec_info = graph.execute({
"user_prompt": "Describe the content", "user_prompt": "List me all the projects with their description",
"url": "https://example.com/" "url": "https://perinim.github.io/projects/"
}) })
# get the answer from the result # get the answer from the result

View File

@ -29,7 +29,7 @@ dependencies = [
"playwright==1.43.0", "playwright==1.43.0",
"google==3.0.0", "google==3.0.0",
"yahoo-search-py==0.3", "yahoo-search-py==0.3",
"burr[start]" "burr[start]==0.17.1"
] ]
license = "MIT" license = "MIT"

View File

@ -6,6 +6,7 @@
# features: [] # features: []
# all-features: false # all-features: false
# with-sources: false # with-sources: false
# generate-hashes: false
-e file:. -e file:.
aiofiles==23.2.1 aiofiles==23.2.1

View File

@ -6,6 +6,7 @@
# features: [] # features: []
# all-features: false # all-features: false
# with-sources: false # with-sources: false
# generate-hashes: false
-e file:. -e file:.
aiofiles==23.2.1 aiofiles==23.2.1

View File

@ -164,6 +164,7 @@ class BaseGraph:
self.initial_state = initial_state self.initial_state = initial_state
if self.use_burr: if self.use_burr:
bridge = BurrBridge(self, self.burr_config) bridge = BurrBridge(self, self.burr_config)
return bridge.execute(initial_state) result = bridge.execute(initial_state)
return (result["_state"], [])
else: else:
return self._execute_standard(initial_state) return self._execute_standard(initial_state)

View File

@ -1,16 +0,0 @@
digraph {
graph [compound=false concentrate=false rankdir=TB ranksep=0.4]
fetch_node [label=fetch_node shape=box style=rounded]
parse_node [label=parse_node shape=box style=rounded]
rag_node [label=rag_node shape=box style=rounded]
input__llm_model [label="input: llm_model" shape=oval style=dashed]
input__llm_model -> rag_node
input__embedder_model [label="input: embedder_model" shape=oval style=dashed]
input__embedder_model -> rag_node
generate_answer_node [label=generate_answer_node shape=box style=rounded]
input__llm_model [label="input: llm_model" shape=oval style=dashed]
input__llm_model -> generate_answer_node
fetch_node -> parse_node [style=solid]
parse_node -> rag_node [style=solid]
rag_node -> generate_answer_node [style=solid]
}

View File

@ -8,7 +8,6 @@ from typing import Any, Dict, List, Tuple
from burr import tracking from burr import tracking
from burr.core import Application, ApplicationBuilder, State, Action, default from burr.core import Application, ApplicationBuilder, State, Action, default
from burr.core.action import action
from burr.lifecycle import PostRunStepHook, PreRunStepHook from burr.lifecycle import PostRunStepHook, PreRunStepHook
@ -40,7 +39,7 @@ class BurrNodeBridge(Action):
return parse_boolean_expression(self.node.input) return parse_boolean_expression(self.node.input)
def run(self, state: State, **run_kwargs) -> dict: def run(self, state: State, **run_kwargs) -> dict:
node_inputs = {key: state[key] for key in self.reads} node_inputs = {key: state[key] for key in self.reads if key in state}
result_state = self.node.execute(node_inputs, **run_kwargs) result_state = self.node.execute(node_inputs, **run_kwargs)
return result_state return result_state
@ -49,7 +48,7 @@ class BurrNodeBridge(Action):
return self.node.output return self.node.output
def update(self, result: dict, state: State) -> State: def update(self, result: dict, state: State) -> State:
return state.update(**state) return state.update(**result)
def parse_boolean_expression(expression: str) -> List[str]: def parse_boolean_expression(expression: str) -> List[str]:
@ -92,7 +91,8 @@ class BurrBridge:
def __init__(self, base_graph, burr_config): def __init__(self, base_graph, burr_config):
self.base_graph = base_graph self.base_graph = base_graph
self.burr_config = burr_config self.burr_config = burr_config
self.tracker = tracking.LocalTrackingClient(project="smart-scraper-graph") self.project_name = burr_config.get("project_name", "default-project")
self.tracker = tracking.LocalTrackingClient(project=self.project_name)
self.app_instance_id = burr_config.get("app_instance_id", "default-instance") self.app_instance_id = burr_config.get("app_instance_id", "default-instance")
self.burr_inputs = burr_config.get("inputs", {}) self.burr_inputs = burr_config.get("inputs", {})
self.burr_app = None self.burr_app = None
@ -111,7 +111,7 @@ class BurrBridge:
actions = self._create_actions() actions = self._create_actions()
transitions = self._create_transitions() transitions = self._create_transitions()
hooks = [PrintLnHook()] hooks = [PrintLnHook()]
burr_state = self._convert_state_to_burr(initial_state) burr_state = State(initial_state)
app = ( app = (
ApplicationBuilder() ApplicationBuilder()
@ -136,32 +136,10 @@ class BurrBridge:
actions = {} actions = {}
for node in self.base_graph.nodes: for node in self.base_graph.nodes:
action_func = self._create_action(node) action_func = BurrNodeBridge(node)
actions[node.node_name] = action_func actions[node.node_name] = action_func
return actions return actions
def _create_action(self, node) -> Any:
"""
Create a Burr action function from a base graph node.
Args:
node (Node): The base graph node to convert to a Burr action.
Returns:
function: The Burr action function.
"""
# @action(reads=parse_boolean_expression(node.input), writes=node.output)
# def dynamic_action(state: State, **kwargs):
# node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)}
# result_state = node.execute(node_inputs, **kwargs)
# return result_state, state.update(**result_state)
#
# return dynamic_action
# import pdb
# pdb.set_trace()
return BurrNodeBridge(node)
def _create_transitions(self) -> List[Tuple[str, str, Any]]: def _create_transitions(self) -> List[Tuple[str, str, Any]]:
""" """
Create Burr transitions from the base graph edges. Create Burr transitions from the base graph edges.
@ -175,22 +153,6 @@ class BurrBridge:
transitions.append((from_node, to_node, default)) transitions.append((from_node, to_node, default))
return transitions return transitions
def _convert_state_to_burr(self, state: Dict[str, Any]) -> State:
"""
Convert a dictionary state to a Burr state.
Args:
state (dict): The dictionary state to convert.
Returns:
State: The Burr state instance.
"""
burr_state = State()
for key, value in state.items():
setattr(burr_state, key, value)
return burr_state
def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]: def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]:
""" """
Convert a Burr state to a dictionary state. Convert a Burr state to a dictionary state.
@ -223,7 +185,6 @@ class BurrBridge:
# TODO: to fix final nodes detection # TODO: to fix final nodes detection
final_nodes = [self.burr_app.graph.actions[-1].name] final_nodes = [self.burr_app.graph.actions[-1].name]
# TODO: fix inputs
last_action, result, final_state = self.burr_app.run( last_action, result, final_state = self.burr_app.run(
halt_after=final_nodes, halt_after=final_nodes,
inputs=self.burr_inputs inputs=self.burr_inputs