feat(burr-node): working burr bridge

This commit is contained in:
PeriniM 2024-05-22 00:24:38 +02:00
parent cfaf7eebd8
commit 654a042396
7 changed files with 15 additions and 66 deletions

View File

@ -90,6 +90,7 @@ graph = BaseGraph(
entry_point=fetch_node,
use_burr=True,
burr_config={
"project_name": "smart-scraper-graph",
"app_instance_id": str(uuid.uuid4()),
"inputs": {
"llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"),
@ -101,9 +102,9 @@ graph = BaseGraph(
# Execute the graph
# ************************************************
result, execution_info = graph.execute({
"user_prompt": "Describe the content",
"url": "https://example.com/"
result, exec_info = graph.execute({
"user_prompt": "List me all the projects with their description",
"url": "https://perinim.github.io/projects/"
})
# get the answer from the result

View File

@ -29,7 +29,7 @@ dependencies = [
"playwright==1.43.0",
"google==3.0.0",
"yahoo-search-py==0.3",
"burr[start]"
"burr[start]==0.17.1"
]
license = "MIT"

View File

@ -6,6 +6,7 @@
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
-e file:.
aiofiles==23.2.1

View File

@ -6,6 +6,7 @@
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
-e file:.
aiofiles==23.2.1

View File

@ -164,6 +164,7 @@ class BaseGraph:
self.initial_state = initial_state
if self.use_burr:
bridge = BurrBridge(self, self.burr_config)
return bridge.execute(initial_state)
result = bridge.execute(initial_state)
return (result["_state"], [])
else:
return self._execute_standard(initial_state)

View File

@ -1,16 +0,0 @@
digraph {
graph [compound=false concentrate=false rankdir=TB ranksep=0.4]
fetch_node [label=fetch_node shape=box style=rounded]
parse_node [label=parse_node shape=box style=rounded]
rag_node [label=rag_node shape=box style=rounded]
input__llm_model [label="input: llm_model" shape=oval style=dashed]
input__llm_model -> rag_node
input__embedder_model [label="input: embedder_model" shape=oval style=dashed]
input__embedder_model -> rag_node
generate_answer_node [label=generate_answer_node shape=box style=rounded]
input__llm_model [label="input: llm_model" shape=oval style=dashed]
input__llm_model -> generate_answer_node
fetch_node -> parse_node [style=solid]
parse_node -> rag_node [style=solid]
rag_node -> generate_answer_node [style=solid]
}

View File

@ -8,7 +8,6 @@ from typing import Any, Dict, List, Tuple
from burr import tracking
from burr.core import Application, ApplicationBuilder, State, Action, default
from burr.core.action import action
from burr.lifecycle import PostRunStepHook, PreRunStepHook
@ -40,7 +39,7 @@ class BurrNodeBridge(Action):
return parse_boolean_expression(self.node.input)
def run(self, state: State, **run_kwargs) -> dict:
node_inputs = {key: state[key] for key in self.reads}
node_inputs = {key: state[key] for key in self.reads if key in state}
result_state = self.node.execute(node_inputs, **run_kwargs)
return result_state
@ -49,7 +48,7 @@ class BurrNodeBridge(Action):
return self.node.output
def update(self, result: dict, state: State) -> State:
return state.update(**state)
return state.update(**result)
def parse_boolean_expression(expression: str) -> List[str]:
@ -92,7 +91,8 @@ class BurrBridge:
def __init__(self, base_graph, burr_config):
self.base_graph = base_graph
self.burr_config = burr_config
self.tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
self.project_name = burr_config.get("project_name", "default-project")
self.tracker = tracking.LocalTrackingClient(project=self.project_name)
self.app_instance_id = burr_config.get("app_instance_id", "default-instance")
self.burr_inputs = burr_config.get("inputs", {})
self.burr_app = None
@ -111,7 +111,7 @@ class BurrBridge:
actions = self._create_actions()
transitions = self._create_transitions()
hooks = [PrintLnHook()]
burr_state = self._convert_state_to_burr(initial_state)
burr_state = State(initial_state)
app = (
ApplicationBuilder()
@ -136,32 +136,10 @@ class BurrBridge:
actions = {}
for node in self.base_graph.nodes:
action_func = self._create_action(node)
action_func = BurrNodeBridge(node)
actions[node.node_name] = action_func
return actions
def _create_action(self, node) -> Any:
"""
Create a Burr action function from a base graph node.
Args:
node (Node): The base graph node to convert to a Burr action.
Returns:
function: The Burr action function.
"""
# @action(reads=parse_boolean_expression(node.input), writes=node.output)
# def dynamic_action(state: State, **kwargs):
# node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)}
# result_state = node.execute(node_inputs, **kwargs)
# return result_state, state.update(**result_state)
#
# return dynamic_action
# import pdb
# pdb.set_trace()
return BurrNodeBridge(node)
def _create_transitions(self) -> List[Tuple[str, str, Any]]:
"""
Create Burr transitions from the base graph edges.
@ -175,22 +153,6 @@ class BurrBridge:
transitions.append((from_node, to_node, default))
return transitions
def _convert_state_to_burr(self, state: Dict[str, Any]) -> State:
"""
Convert a dictionary state to a Burr state.
Args:
state (dict): The dictionary state to convert.
Returns:
State: The Burr state instance.
"""
burr_state = State()
for key, value in state.items():
setattr(burr_state, key, value)
return burr_state
def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]:
"""
Convert a Burr state to a dictionary state.
@ -223,7 +185,6 @@ class BurrBridge:
# TODO: to fix final nodes detection
final_nodes = [self.burr_app.graph.actions[-1].name]
# TODO: fix inputs
last_action, result, final_state = self.burr_app.run(
halt_after=final_nodes,
inputs=self.burr_inputs