From 654a04239640a89d9fa408ccb2e4485247ab84df Mon Sep 17 00:00:00 2001 From: PeriniM Date: Wed, 22 May 2024 00:24:38 +0200 Subject: [PATCH] feat(burr-node): working burr bridge --- examples/openai/burr_integration_openai.py | 7 +-- pyproject.toml | 2 +- requirements-dev.lock | 1 + requirements.lock | 1 + scrapegraphai/graphs/base_graph.py | 3 +- scrapegraphai/graphs/smart_scraper_graph | 16 ------- scrapegraphai/integrations/burr_bridge.py | 51 +++------------------- 7 files changed, 15 insertions(+), 66 deletions(-) delete mode 100644 scrapegraphai/graphs/smart_scraper_graph diff --git a/examples/openai/burr_integration_openai.py b/examples/openai/burr_integration_openai.py index 41f2d817..7d531c05 100644 --- a/examples/openai/burr_integration_openai.py +++ b/examples/openai/burr_integration_openai.py @@ -90,6 +90,7 @@ graph = BaseGraph( entry_point=fetch_node, use_burr=True, burr_config={ + "project_name": "smart-scraper-graph", "app_instance_id": str(uuid.uuid4()), "inputs": { "llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"), @@ -101,9 +102,9 @@ graph = BaseGraph( # Execute the graph # ************************************************ -result, execution_info = graph.execute({ - "user_prompt": "Describe the content", - "url": "https://example.com/" +result, exec_info = graph.execute({ + "user_prompt": "List me all the projects with their description", + "url": "https://perinim.github.io/projects/" }) # get the answer from the result diff --git a/pyproject.toml b/pyproject.toml index 5f85f19a..19360e4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "playwright==1.43.0", "google==3.0.0", "yahoo-search-py==0.3", - "burr[start]" + "burr[start]==0.17.1" ] license = "MIT" diff --git a/requirements-dev.lock b/requirements-dev.lock index 89789099..7458fe01 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -6,6 +6,7 @@ # features: [] # all-features: false # with-sources: false +# generate-hashes: false -e file:. aiofiles==23.2.1 diff --git a/requirements.lock b/requirements.lock index b0872619..ed73ca98 100644 --- a/requirements.lock +++ b/requirements.lock @@ -6,6 +6,7 @@ # features: [] # all-features: false # with-sources: false +# generate-hashes: false -e file:. aiofiles==23.2.1 diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index 06791528..07615a78 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -164,6 +164,7 @@ class BaseGraph: self.initial_state = initial_state if self.use_burr: bridge = BurrBridge(self, self.burr_config) - return bridge.execute(initial_state) + result = bridge.execute(initial_state) + return (result["_state"], []) else: return self._execute_standard(initial_state) \ No newline at end of file diff --git a/scrapegraphai/graphs/smart_scraper_graph b/scrapegraphai/graphs/smart_scraper_graph deleted file mode 100644 index fe361b4d..00000000 --- a/scrapegraphai/graphs/smart_scraper_graph +++ /dev/null @@ -1,16 +0,0 @@ -digraph { - graph [compound=false concentrate=false rankdir=TB ranksep=0.4] - fetch_node [label=fetch_node shape=box style=rounded] - parse_node [label=parse_node shape=box style=rounded] - rag_node [label=rag_node shape=box style=rounded] - input__llm_model [label="input: llm_model" shape=oval style=dashed] - input__llm_model -> rag_node - input__embedder_model [label="input: embedder_model" shape=oval style=dashed] - input__embedder_model -> rag_node - generate_answer_node [label=generate_answer_node shape=box style=rounded] - input__llm_model [label="input: llm_model" shape=oval style=dashed] - input__llm_model -> generate_answer_node - fetch_node -> parse_node [style=solid] - parse_node -> rag_node [style=solid] - rag_node -> generate_answer_node [style=solid] -} diff --git a/scrapegraphai/integrations/burr_bridge.py b/scrapegraphai/integrations/burr_bridge.py index 3b687015..bd8df466 100644 --- a/scrapegraphai/integrations/burr_bridge.py +++ b/scrapegraphai/integrations/burr_bridge.py @@ -8,7 +8,6 @@ from typing import Any, Dict, List, Tuple from burr import tracking from burr.core import Application, ApplicationBuilder, State, Action, default -from burr.core.action import action from burr.lifecycle import PostRunStepHook, PreRunStepHook @@ -40,7 +39,7 @@ class BurrNodeBridge(Action): return parse_boolean_expression(self.node.input) def run(self, state: State, **run_kwargs) -> dict: - node_inputs = {key: state[key] for key in self.reads} + node_inputs = {key: state[key] for key in self.reads if key in state} result_state = self.node.execute(node_inputs, **run_kwargs) return result_state @@ -49,7 +48,7 @@ class BurrNodeBridge(Action): return self.node.output def update(self, result: dict, state: State) -> State: - return state.update(**state) + return state.update(**result) def parse_boolean_expression(expression: str) -> List[str]: @@ -92,7 +91,8 @@ class BurrBridge: def __init__(self, base_graph, burr_config): self.base_graph = base_graph self.burr_config = burr_config - self.tracker = tracking.LocalTrackingClient(project="smart-scraper-graph") + self.project_name = burr_config.get("project_name", "default-project") + self.tracker = tracking.LocalTrackingClient(project=self.project_name) self.app_instance_id = burr_config.get("app_instance_id", "default-instance") self.burr_inputs = burr_config.get("inputs", {}) self.burr_app = None @@ -111,7 +111,7 @@ class BurrBridge: actions = self._create_actions() transitions = self._create_transitions() hooks = [PrintLnHook()] - burr_state = self._convert_state_to_burr(initial_state) + burr_state = State(initial_state) app = ( ApplicationBuilder() @@ -136,32 +136,10 @@ class BurrBridge: actions = {} for node in self.base_graph.nodes: - action_func = self._create_action(node) + action_func = BurrNodeBridge(node) actions[node.node_name] = action_func return actions - def _create_action(self, node) -> Any: - """ - Create a Burr action function from a base graph node. - - Args: - node (Node): The base graph node to convert to a Burr action. - - Returns: - function: The Burr action function. - """ - - # @action(reads=parse_boolean_expression(node.input), writes=node.output) - # def dynamic_action(state: State, **kwargs): - # node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)} - # result_state = node.execute(node_inputs, **kwargs) - # return result_state, state.update(**result_state) - # - # return dynamic_action - # import pdb - # pdb.set_trace() - return BurrNodeBridge(node) - def _create_transitions(self) -> List[Tuple[str, str, Any]]: """ Create Burr transitions from the base graph edges. @@ -175,22 +153,6 @@ class BurrBridge: transitions.append((from_node, to_node, default)) return transitions - def _convert_state_to_burr(self, state: Dict[str, Any]) -> State: - """ - Convert a dictionary state to a Burr state. - - Args: - state (dict): The dictionary state to convert. - - Returns: - State: The Burr state instance. - """ - - burr_state = State() - for key, value in state.items(): - setattr(burr_state, key, value) - return burr_state - def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]: """ Convert a Burr state to a dictionary state. @@ -223,7 +185,6 @@ class BurrBridge: # TODO: to fix final nodes detection final_nodes = [self.burr_app.graph.actions[-1].name] - # TODO: fix inputs last_action, result, final_state = self.burr_app.run( halt_after=final_nodes, inputs=self.burr_inputs