From 6cbd84f254ebc1f1c68699273bdd8fcdb0fe26d4 Mon Sep 17 00:00:00 2001 From: Marco Perini Date: Tue, 21 May 2024 19:57:10 +0200 Subject: [PATCH] feat(burr-bridge): BurrBridge class to integrate inside BaseGraph --- .python-version | 1 - examples/openai/burr_integration_openai.py | 109 ++++++++++ pyproject.toml | 1 + requirements-dev.lock | 157 ++++++++++++++ requirements.lock | 156 ++++++++++++++ scrapegraphai/graphs/__init__.py | 1 - scrapegraphai/graphs/base_graph.py | 69 +++--- .../graphs/smart_scraper_graph_burr.py | 2 +- scrapegraphai/graphs/turbo_scraper.py | 146 ------------- scrapegraphai/integrations/__init__.py | 1 + scrapegraphai/integrations/burr_bridge.py | 198 ++++++++++++++++++ 11 files changed, 668 insertions(+), 173 deletions(-) delete mode 100644 .python-version create mode 100644 examples/openai/burr_integration_openai.py delete mode 100644 scrapegraphai/graphs/turbo_scraper.py create mode 100644 scrapegraphai/integrations/__init__.py create mode 100644 scrapegraphai/integrations/burr_bridge.py diff --git a/.python-version b/.python-version deleted file mode 100644 index 8e34c813..00000000 --- a/.python-version +++ /dev/null @@ -1 +0,0 @@ -3.9.19 diff --git a/examples/openai/burr_integration_openai.py b/examples/openai/burr_integration_openai.py new file mode 100644 index 00000000..0c95c231 --- /dev/null +++ b/examples/openai/burr_integration_openai.py @@ -0,0 +1,109 @@ +""" +Example of custom graph using existing nodes +""" + +import os +from dotenv import load_dotenv + +from langchain_openai import OpenAIEmbeddings +from scrapegraphai.models import OpenAI +from scrapegraphai.graphs import BaseGraph +from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode +load_dotenv() + +# ************************************************ +# Define the configuration for the graph +# ************************************************ + +openai_key = os.getenv("OPENAI_APIKEY") + +graph_config = { + "llm": { + "api_key": openai_key, + "model": "gpt-3.5-turbo", + "temperature": 0, + "streaming": False + }, +} + +# ************************************************ +# Define the graph nodes +# ************************************************ + +llm_model = OpenAI(graph_config["llm"]) +embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key) + +# define the nodes for the graph + +fetch_node = FetchNode( + input="url | local_dir", + output=["doc", "link_urls", "img_urls"], + node_config={ + "verbose": True, + "headless": True, + } +) +parse_node = ParseNode( + input="doc", + output=["parsed_doc"], + node_config={ + "chunk_size": 4096, + "verbose": True, + } +) +rag_node = RAGNode( + input="user_prompt & (parsed_doc | doc)", + output=["relevant_chunks"], + node_config={ + "llm_model": llm_model, + "embedder_model": embedder, + "verbose": True, + } +) +generate_answer_node = GenerateAnswerNode( + input="user_prompt & (relevant_chunks | parsed_doc | doc)", + output=["answer"], + node_config={ + "llm_model": llm_model, + "verbose": True, + } +) + +# ************************************************ +# Create the graph by defining the connections +# ************************************************ + +graph = BaseGraph( + nodes=[ + fetch_node, + parse_node, + rag_node, + generate_answer_node, + ], + edges=[ + (fetch_node, parse_node), + (parse_node, rag_node), + (rag_node, generate_answer_node) + ], + entry_point=fetch_node, + use_burr=True, + burr_config={ + "app_instance_id": "custom_graph_openai", + "inputs": { + "llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"), + } + } +) + +# ************************************************ +# Execute the graph +# ************************************************ + +result, execution_info = graph.execute({ + "user_prompt": "Describe the content", + "url": "https://example.com/" +}) + +# get the answer from the result +result = result.get("answer", "No answer found.") +print(result) diff --git a/pyproject.toml b/pyproject.toml index d862966e..5f85f19a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "playwright==1.43.0", "google==3.0.0", "yahoo-search-py==0.3", + "burr[start]" ] license = "MIT" diff --git a/requirements-dev.lock b/requirements-dev.lock index 7c37321b..89789099 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -8,11 +8,15 @@ # with-sources: false -e file:. +aiofiles==23.2.1 + # via burr aiohttp==3.9.5 # via langchain # via langchain-community aiosignal==1.3.1 # via aiohttp +altair==5.3.0 + # via streamlit annotated-types==0.6.0 # via pydantic anthropic==0.25.9 @@ -22,27 +26,51 @@ anyio==4.3.0 # via groq # via httpx # via openai + # via starlette + # via watchfiles async-timeout==4.0.3 # via aiohttp # via langchain attrs==23.2.0 # via aiohttp + # via jsonschema + # via referencing beautifulsoup4==4.12.3 # via google # via scrapegraphai +blinker==1.8.2 + # via streamlit boto3==1.34.105 # via langchain-aws botocore==1.34.105 # via boto3 # via s3transfer +burr==0.17.1 + # via scrapegraphai cachetools==5.3.3 # via google-auth + # via streamlit certifi==2024.2.2 # via httpcore # via httpx # via requests charset-normalizer==3.3.2 # via requests +click==8.1.7 + # via burr + # via streamlit + # via typer + # via uvicorn +colorama==0.4.6 + # via click + # via loguru + # via pytest + # via tqdm + # via uvicorn +contourpy==1.2.1 + # via matplotlib +cycler==0.12.1 + # via matplotlib dataclasses-json==0.6.6 # via langchain # via langchain-community @@ -52,13 +80,26 @@ distro==1.9.0 # via anthropic # via groq # via openai +dnspython==2.6.1 + # via email-validator +email-validator==2.1.1 + # via fastapi exceptiongroup==1.2.1 # via anyio # via pytest faiss-cpu==1.8.0 # via scrapegraphai +fastapi==0.111.0 + # via burr + # via fastapi-pagination +fastapi-cli==0.0.4 + # via fastapi +fastapi-pagination==0.12.24 + # via burr filelock==3.14.0 # via huggingface-hub +fonttools==4.51.0 + # via matplotlib free-proxy==1.1.1 # via scrapegraphai frozenlist==1.4.1 @@ -66,6 +107,10 @@ frozenlist==1.4.1 # via aiosignal fsspec==2024.3.1 # via huggingface-hub +gitdb==4.0.11 + # via gitpython +gitpython==3.1.43 + # via streamlit google==3.0.0 # via scrapegraphai google-ai-generativelanguage==0.6.3 @@ -90,6 +135,7 @@ googleapis-common-protos==1.63.0 # via google-api-core # via grpcio-status graphviz==0.20.3 + # via burr # via scrapegraphai greenlet==3.0.3 # via playwright @@ -103,6 +149,7 @@ grpcio-status==1.62.2 # via google-api-core h11==0.14.0 # via httpcore + # via uvicorn html2text==2024.2.26 # via scrapegraphai httpcore==1.0.5 @@ -110,8 +157,11 @@ httpcore==1.0.5 httplib2==0.22.0 # via google-api-python-client # via google-auth-httplib2 +httptools==0.6.1 + # via uvicorn httpx==0.27.0 # via anthropic + # via fastapi # via groq # via openai # via yahoo-search-py @@ -119,11 +169,17 @@ huggingface-hub==0.23.0 # via tokenizers idna==3.7 # via anyio + # via email-validator # via httpx # via requests # via yarl iniconfig==2.0.0 # via pytest +jinja2==3.1.4 + # via altair + # via burr + # via fastapi + # via pydeck jmespath==1.0.1 # via boto3 # via botocore @@ -132,6 +188,12 @@ jsonpatch==1.33 # via langchain-core jsonpointer==2.4 # via jsonpatch +jsonschema==4.22.0 + # via altair +jsonschema-specifications==2023.12.1 + # via jsonschema +kiwisolver==1.4.5 + # via matplotlib langchain==0.1.15 # via scrapegraphai langchain-anthropic==0.1.11 @@ -161,10 +223,20 @@ langsmith==0.1.57 # via langchain # via langchain-community # via langchain-core +loguru==0.7.2 + # via burr lxml==5.2.2 # via free-proxy +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 marshmallow==3.21.2 # via dataclasses-json +matplotlib==3.9.0 + # via burr +mdurl==0.1.2 + # via markdown-it-py minify-html==0.15.0 # via scrapegraphai multidict==6.0.5 @@ -173,22 +245,40 @@ multidict==6.0.5 mypy-extensions==1.0.0 # via typing-inspect numpy==1.26.4 + # via altair + # via contourpy # via faiss-cpu # via langchain # via langchain-aws # via langchain-community + # via matplotlib # via pandas + # via pyarrow + # via pydeck + # via sf-hamilton + # via streamlit openai==1.30.1 + # via burr # via langchain-openai orjson==3.10.3 + # via fastapi # via langsmith packaging==23.2 + # via altair # via huggingface-hub # via langchain-core # via marshmallow + # via matplotlib # via pytest + # via streamlit pandas==2.2.2 + # via altair # via scrapegraphai + # via sf-hamilton + # via streamlit +pillow==10.3.0 + # via matplotlib + # via streamlit playwright==1.43.0 # via scrapegraphai pluggy==1.5.0 @@ -203,6 +293,9 @@ protobuf==4.25.3 # via googleapis-common-protos # via grpcio-status # via proto-plus + # via streamlit +pyarrow==16.1.0 + # via streamlit pyasn1==0.6.0 # via pyasn1-modules # via rsa @@ -210,6 +303,9 @@ pyasn1-modules==0.4.0 # via google-auth pydantic==2.7.1 # via anthropic + # via burr + # via fastapi + # via fastapi-pagination # via google-generativeai # via groq # via langchain @@ -219,18 +315,27 @@ pydantic==2.7.1 # via yahoo-search-py pydantic-core==2.18.2 # via pydantic +pydeck==0.9.1 + # via streamlit pyee==11.1.0 # via playwright +pygments==2.18.0 + # via rich pyparsing==3.1.2 # via httplib2 + # via matplotlib pytest==8.0.0 # via pytest-mock pytest-mock==3.14.0 python-dateutil==2.9.0.post0 # via botocore + # via matplotlib # via pandas python-dotenv==1.0.1 # via scrapegraphai + # via uvicorn +python-multipart==0.0.9 + # via fastapi pytz==2024.1 # via pandas pyyaml==6.0.1 @@ -238,24 +343,42 @@ pyyaml==6.0.1 # via langchain # via langchain-community # via langchain-core + # via uvicorn +referencing==0.35.1 + # via jsonschema + # via jsonschema-specifications regex==2024.5.10 # via tiktoken requests==2.31.0 + # via burr # via free-proxy # via google-api-core # via huggingface-hub # via langchain # via langchain-community # via langsmith + # via streamlit # via tiktoken +rich==13.7.1 + # via streamlit + # via typer +rpds-py==0.18.1 + # via jsonschema + # via referencing rsa==4.9 # via google-auth s3transfer==0.10.1 # via boto3 selectolax==0.3.21 # via yahoo-search-py +sf-hamilton==1.62.0 + # via burr +shellingham==1.5.4 + # via typer six==1.16.0 # via python-dateutil +smmap==5.0.1 + # via gitdb sniffio==1.3.1 # via anthropic # via anyio @@ -267,25 +390,41 @@ soupsieve==2.5 sqlalchemy==2.0.30 # via langchain # via langchain-community +starlette==0.37.2 + # via fastapi +streamlit==1.34.0 + # via burr tenacity==8.3.0 # via langchain # via langchain-community # via langchain-core + # via streamlit tiktoken==0.6.0 # via langchain-openai # via scrapegraphai tokenizers==0.19.1 # via anthropic +toml==0.10.2 + # via streamlit tomli==2.0.1 # via pytest +toolz==0.12.1 + # via altair +tornado==6.4 + # via streamlit tqdm==4.66.4 # via google-generativeai # via huggingface-hub # via openai # via scrapegraphai +typer==0.12.3 + # via fastapi-cli typing-extensions==4.11.0 + # via altair # via anthropic # via anyio + # via fastapi + # via fastapi-pagination # via google-generativeai # via groq # via huggingface-hub @@ -293,18 +432,36 @@ typing-extensions==4.11.0 # via pydantic # via pydantic-core # via pyee + # via sf-hamilton # via sqlalchemy + # via streamlit + # via typer # via typing-inspect + # via uvicorn typing-inspect==0.9.0 # via dataclasses-json + # via sf-hamilton tzdata==2024.1 # via pandas +ujson==5.10.0 + # via fastapi uritemplate==4.1.1 # via google-api-python-client urllib3==1.26.18 # via botocore # via requests # via yahoo-search-py +uvicorn==0.29.0 + # via burr + # via fastapi +watchdog==4.0.0 + # via streamlit +watchfiles==0.21.0 + # via uvicorn +websockets==12.0 + # via uvicorn +win32-setctime==1.1.0 + # via loguru yahoo-search-py==0.3 # via scrapegraphai yarl==1.9.4 diff --git a/requirements.lock b/requirements.lock index c02d4522..b0872619 100644 --- a/requirements.lock +++ b/requirements.lock @@ -8,11 +8,15 @@ # with-sources: false -e file:. +aiofiles==23.2.1 + # via burr aiohttp==3.9.5 # via langchain # via langchain-community aiosignal==1.3.1 # via aiohttp +altair==5.3.0 + # via streamlit annotated-types==0.6.0 # via pydantic anthropic==0.25.9 @@ -22,27 +26,50 @@ anyio==4.3.0 # via groq # via httpx # via openai + # via starlette + # via watchfiles async-timeout==4.0.3 # via aiohttp # via langchain attrs==23.2.0 # via aiohttp + # via jsonschema + # via referencing beautifulsoup4==4.12.3 # via google # via scrapegraphai +blinker==1.8.2 + # via streamlit boto3==1.34.105 # via langchain-aws botocore==1.34.105 # via boto3 # via s3transfer +burr==0.17.1 + # via scrapegraphai cachetools==5.3.3 # via google-auth + # via streamlit certifi==2024.2.2 # via httpcore # via httpx # via requests charset-normalizer==3.3.2 # via requests +click==8.1.7 + # via burr + # via streamlit + # via typer + # via uvicorn +colorama==0.4.6 + # via click + # via loguru + # via tqdm + # via uvicorn +contourpy==1.2.1 + # via matplotlib +cycler==0.12.1 + # via matplotlib dataclasses-json==0.6.6 # via langchain # via langchain-community @@ -52,12 +79,25 @@ distro==1.9.0 # via anthropic # via groq # via openai +dnspython==2.6.1 + # via email-validator +email-validator==2.1.1 + # via fastapi exceptiongroup==1.2.1 # via anyio faiss-cpu==1.8.0 # via scrapegraphai +fastapi==0.111.0 + # via burr + # via fastapi-pagination +fastapi-cli==0.0.4 + # via fastapi +fastapi-pagination==0.12.24 + # via burr filelock==3.14.0 # via huggingface-hub +fonttools==4.51.0 + # via matplotlib free-proxy==1.1.1 # via scrapegraphai frozenlist==1.4.1 @@ -65,6 +105,10 @@ frozenlist==1.4.1 # via aiosignal fsspec==2024.3.1 # via huggingface-hub +gitdb==4.0.11 + # via gitpython +gitpython==3.1.43 + # via streamlit google==3.0.0 # via scrapegraphai google-ai-generativelanguage==0.6.3 @@ -89,6 +133,7 @@ googleapis-common-protos==1.63.0 # via google-api-core # via grpcio-status graphviz==0.20.3 + # via burr # via scrapegraphai greenlet==3.0.3 # via playwright @@ -102,6 +147,7 @@ grpcio-status==1.62.2 # via google-api-core h11==0.14.0 # via httpcore + # via uvicorn html2text==2024.2.26 # via scrapegraphai httpcore==1.0.5 @@ -109,8 +155,11 @@ httpcore==1.0.5 httplib2==0.22.0 # via google-api-python-client # via google-auth-httplib2 +httptools==0.6.1 + # via uvicorn httpx==0.27.0 # via anthropic + # via fastapi # via groq # via openai # via yahoo-search-py @@ -118,9 +167,15 @@ huggingface-hub==0.23.0 # via tokenizers idna==3.7 # via anyio + # via email-validator # via httpx # via requests # via yarl +jinja2==3.1.4 + # via altair + # via burr + # via fastapi + # via pydeck jmespath==1.0.1 # via boto3 # via botocore @@ -129,6 +184,12 @@ jsonpatch==1.33 # via langchain-core jsonpointer==2.4 # via jsonpatch +jsonschema==4.22.0 + # via altair +jsonschema-specifications==2023.12.1 + # via jsonschema +kiwisolver==1.4.5 + # via matplotlib langchain==0.1.15 # via scrapegraphai langchain-anthropic==0.1.11 @@ -158,10 +219,20 @@ langsmith==0.1.57 # via langchain # via langchain-community # via langchain-core +loguru==0.7.2 + # via burr lxml==5.2.2 # via free-proxy +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.5 + # via jinja2 marshmallow==3.21.2 # via dataclasses-json +matplotlib==3.9.0 + # via burr +mdurl==0.1.2 + # via markdown-it-py minify-html==0.15.0 # via scrapegraphai multidict==6.0.5 @@ -170,21 +241,39 @@ multidict==6.0.5 mypy-extensions==1.0.0 # via typing-inspect numpy==1.26.4 + # via altair + # via contourpy # via faiss-cpu # via langchain # via langchain-aws # via langchain-community + # via matplotlib # via pandas + # via pyarrow + # via pydeck + # via sf-hamilton + # via streamlit openai==1.30.1 + # via burr # via langchain-openai orjson==3.10.3 + # via fastapi # via langsmith packaging==23.2 + # via altair # via huggingface-hub # via langchain-core # via marshmallow + # via matplotlib + # via streamlit pandas==2.2.2 + # via altair # via scrapegraphai + # via sf-hamilton + # via streamlit +pillow==10.3.0 + # via matplotlib + # via streamlit playwright==1.43.0 # via scrapegraphai proto-plus==1.23.0 @@ -197,6 +286,9 @@ protobuf==4.25.3 # via googleapis-common-protos # via grpcio-status # via proto-plus + # via streamlit +pyarrow==16.1.0 + # via streamlit pyasn1==0.6.0 # via pyasn1-modules # via rsa @@ -204,6 +296,9 @@ pyasn1-modules==0.4.0 # via google-auth pydantic==2.7.1 # via anthropic + # via burr + # via fastapi + # via fastapi-pagination # via google-generativeai # via groq # via langchain @@ -213,15 +308,24 @@ pydantic==2.7.1 # via yahoo-search-py pydantic-core==2.18.2 # via pydantic +pydeck==0.9.1 + # via streamlit pyee==11.1.0 # via playwright +pygments==2.18.0 + # via rich pyparsing==3.1.2 # via httplib2 + # via matplotlib python-dateutil==2.9.0.post0 # via botocore + # via matplotlib # via pandas python-dotenv==1.0.1 # via scrapegraphai + # via uvicorn +python-multipart==0.0.9 + # via fastapi pytz==2024.1 # via pandas pyyaml==6.0.1 @@ -229,24 +333,42 @@ pyyaml==6.0.1 # via langchain # via langchain-community # via langchain-core + # via uvicorn +referencing==0.35.1 + # via jsonschema + # via jsonschema-specifications regex==2024.5.10 # via tiktoken requests==2.31.0 + # via burr # via free-proxy # via google-api-core # via huggingface-hub # via langchain # via langchain-community # via langsmith + # via streamlit # via tiktoken +rich==13.7.1 + # via streamlit + # via typer +rpds-py==0.18.1 + # via jsonschema + # via referencing rsa==4.9 # via google-auth s3transfer==0.10.1 # via boto3 selectolax==0.3.21 # via yahoo-search-py +sf-hamilton==1.62.0 + # via burr +shellingham==1.5.4 + # via typer six==1.16.0 # via python-dateutil +smmap==5.0.1 + # via gitdb sniffio==1.3.1 # via anthropic # via anyio @@ -258,23 +380,39 @@ soupsieve==2.5 sqlalchemy==2.0.30 # via langchain # via langchain-community +starlette==0.37.2 + # via fastapi +streamlit==1.34.0 + # via burr tenacity==8.3.0 # via langchain # via langchain-community # via langchain-core + # via streamlit tiktoken==0.6.0 # via langchain-openai # via scrapegraphai tokenizers==0.19.1 # via anthropic +toml==0.10.2 + # via streamlit +toolz==0.12.1 + # via altair +tornado==6.4 + # via streamlit tqdm==4.66.4 # via google-generativeai # via huggingface-hub # via openai # via scrapegraphai +typer==0.12.3 + # via fastapi-cli typing-extensions==4.11.0 + # via altair # via anthropic # via anyio + # via fastapi + # via fastapi-pagination # via google-generativeai # via groq # via huggingface-hub @@ -282,18 +420,36 @@ typing-extensions==4.11.0 # via pydantic # via pydantic-core # via pyee + # via sf-hamilton # via sqlalchemy + # via streamlit + # via typer # via typing-inspect + # via uvicorn typing-inspect==0.9.0 # via dataclasses-json + # via sf-hamilton tzdata==2024.1 # via pandas +ujson==5.10.0 + # via fastapi uritemplate==4.1.1 # via google-api-python-client urllib3==1.26.18 # via botocore # via requests # via yahoo-search-py +uvicorn==0.29.0 + # via burr + # via fastapi +watchdog==4.0.0 + # via streamlit +watchfiles==0.21.0 + # via uvicorn +websockets==12.0 + # via uvicorn +win32-setctime==1.1.0 + # via loguru yahoo-search-py==0.3 # via scrapegraphai yarl==1.9.4 diff --git a/scrapegraphai/graphs/__init__.py b/scrapegraphai/graphs/__init__.py index 10eb6d8e..15f4a4ec 100644 --- a/scrapegraphai/graphs/__init__.py +++ b/scrapegraphai/graphs/__init__.py @@ -15,4 +15,3 @@ from .csv_scraper_graph import CSVScraperGraph from .pdf_scraper_graph import PDFScraperGraph from .omni_scraper_graph import OmniScraperGraph from .omni_search_graph import OmniSearchGraph -from .turbo_scraper import TurboScraperGraph diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index 867d774f..06791528 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -7,6 +7,8 @@ import warnings from langchain_community.callbacks import get_openai_callback from typing import Tuple +from ..integrations import BurrBridge + class BaseGraph: """ @@ -40,20 +42,27 @@ class BaseGraph: ... (parse_node, rag_node), ... (rag_node, generate_answer_node) ... ], - ... entry_point=fetch_node + ... entry_point=fetch_node, + ... use_burr=True, + ... burr_config={"app_instance_id": "example-instance"} ... ) """ - def __init__(self, nodes: list, edges: list, entry_point: str): + def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool = False, burr_config: dict = None): self.nodes = nodes self.edges = self._create_edges({e for e in edges}) self.entry_point = entry_point.node_name + self.initial_state = {} if nodes[0].node_name != entry_point.node_name: # raise a warning if the entry point is not the first node in the list warnings.warn( "Careful! The entry point node is different from the first node if the graph.") + + # Burr configuration + self.use_burr = use_burr + self.burr_config = burr_config or {} def _create_edges(self, edges: list) -> dict: """ @@ -71,11 +80,9 @@ class BaseGraph: edge_dict[from_node.node_name] = to_node.node_name return edge_dict - def execute(self, initial_state: dict) -> Tuple[dict, list]: + def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]: """ - Executes the graph by traversing nodes starting from the entry point. The execution - follows the edges based on the result of each node's execution and continues until - it reaches a node with no outgoing edges. + Executes the graph by traversing nodes starting from the entry point using the standard method. Args: initial_state (dict): The initial state to pass to the entry point node. @@ -83,8 +90,7 @@ class BaseGraph: Returns: Tuple[dict, list]: A tuple containing the final state and a list of execution info. """ - - current_node_name = self.nodes[0] + current_node_name = self.entry_point state = initial_state # variables for tracking execution info @@ -98,18 +104,17 @@ class BaseGraph: "total_cost_USD": 0.0, } - for index in self.nodes: - + while current_node_name: curr_time = time.time() - current_node = index + current_node = next(node for node in self.nodes if node.node_name == current_node_name) with get_openai_callback() as cb: result = current_node.execute(state) node_exec_time = time.time() - curr_time total_exec_time += node_exec_time - cb = { - "node_name": index.node_name, + cb_data = { + "node_name": current_node.node_name, "total_tokens": cb.total_tokens, "prompt_tokens": cb.prompt_tokens, "completion_tokens": cb.completion_tokens, @@ -118,15 +123,13 @@ class BaseGraph: "exec_time": node_exec_time, } - exec_info.append( - cb - ) + exec_info.append(cb_data) - cb_total["total_tokens"] += cb["total_tokens"] - cb_total["prompt_tokens"] += cb["prompt_tokens"] - cb_total["completion_tokens"] += cb["completion_tokens"] - cb_total["successful_requests"] += cb["successful_requests"] - cb_total["total_cost_USD"] += cb["total_cost_USD"] + cb_total["total_tokens"] += cb_data["total_tokens"] + cb_total["prompt_tokens"] += cb_data["prompt_tokens"] + cb_total["completion_tokens"] += cb_data["completion_tokens"] + cb_total["successful_requests"] += cb_data["successful_requests"] + cb_total["total_cost_USD"] += cb_data["total_cost_USD"] if current_node.node_type == "conditional_node": current_node_name = result @@ -137,12 +140,30 @@ class BaseGraph: exec_info.append({ "node_name": "TOTAL RESULT", - "total_tokens": cb_total["total_tokens"], - "prompt_tokens": cb_total["prompt_tokens"], + "total_tokens": cb_total["total_tokens"], + "prompt_tokens": cb_total["prompt_tokens"], "completion_tokens": cb_total["completion_tokens"], "successful_requests": cb_total["successful_requests"], - "total_cost_USD": cb_total["total_cost_USD"], + "total_cost_USD": cb_total["total_cost_USD"], "exec_time": total_exec_time, }) return state, exec_info + + def execute(self, initial_state: dict) -> Tuple[dict, list]: + """ + Executes the graph by either using BurrBridge or the standard method. + + Args: + initial_state (dict): The initial state to pass to the entry point node. + + Returns: + Tuple[dict, list]: A tuple containing the final state and a list of execution info. + """ + + self.initial_state = initial_state + if self.use_burr: + bridge = BurrBridge(self, self.burr_config) + return bridge.execute(initial_state) + else: + return self._execute_standard(initial_state) \ No newline at end of file diff --git a/scrapegraphai/graphs/smart_scraper_graph_burr.py b/scrapegraphai/graphs/smart_scraper_graph_burr.py index 388200a5..eccdf908 100644 --- a/scrapegraphai/graphs/smart_scraper_graph_burr.py +++ b/scrapegraphai/graphs/smart_scraper_graph_burr.py @@ -25,7 +25,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter from tqdm import tqdm if __name__ == '__main__': - from scrapegraphai.utils.remover import remover + from scrapegraphai.utils import cleanup_html else: from ..utils.remover import remover diff --git a/scrapegraphai/graphs/turbo_scraper.py b/scrapegraphai/graphs/turbo_scraper.py deleted file mode 100644 index 2881fd76..00000000 --- a/scrapegraphai/graphs/turbo_scraper.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -SmartScraperGraph Module -""" - -from .base_graph import BaseGraph -from ..nodes import ( - FetchNode, - ParseNode, - RAGNode, - SearchLinksWithContext, - GraphIteratorNode, - MergeAnswersNode -) -from .search_graph import SearchGraph -from .abstract_graph import AbstractGraph - - -class SmartScraperGraph(AbstractGraph): - """ - SmartScraper is a scraping pipeline that automates the process of - extracting information from web pages - using a natural language model to interpret and answer prompts. - - Attributes: - prompt (str): The prompt for the graph. - source (str): The source of the graph. - config (dict): Configuration parameters for the graph. - llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, - configured for generating embeddings. - verbose (bool): A flag indicating whether to show print statements during execution. - headless (bool): A flag indicating whether to run the graph in headless mode. - - Args: - prompt (str): The prompt for the graph. - source (str): The source of the graph. - config (dict): Configuration parameters for the graph. - - Example: - >>> smart_scraper = SmartScraperGraph( - ... "List me all the attractions in Chioggia.", - ... "https://en.wikipedia.org/wiki/Chioggia", - ... {"llm": {"model": "gpt-3.5-turbo"}} - ... ) - >>> result = smart_scraper.run() - ) - """ - - def __init__(self, prompt: str, source: str, config: dict): - super().__init__(prompt, config, source) - - self.input_key = "url" if source.startswith("http") else "local_dir" - - def _create_graph(self) -> BaseGraph: - """ - Creates the graph of nodes representing the workflow for web scraping. - - Returns: - BaseGraph: A graph instance representing the web scraping workflow. - """ - smart_scraper_graph = SmartScraperGraph( - prompt="", - source="", - config=self.llm_model - ) - fetch_node = FetchNode( - input="url | local_dir", - output=["doc"] - ) - - parse_node = ParseNode( - input="doc", - output=["parsed_doc"], - node_config={ - "chunk_size": self.model_token - } - ) - - rag_node = RAGNode( - input="user_prompt & (parsed_doc | doc)", - output=["relevant_chunks"], - node_config={ - "llm_model": self.llm_model, - "embedder_model": self.embedder_model - } - ) - - search_link_with_context_node = SearchLinksWithContext( - input="user_prompt & (relevant_chunks | parsed_doc | doc)", - output=["answer"], - node_config={ - "llm_model": self.llm_model - } - ) - - graph_iterator_node = GraphIteratorNode( - input="user_prompt & urls", - output=["results"], - node_config={ - "graph_instance": smart_scraper_graph, - "verbose": True, - } - ) - - merge_answers_node = MergeAnswersNode( - input="user_prompt & results", - output=["answer"], - node_config={ - "llm_model": self.llm_model, - "verbose": True, - } - ) - - return BaseGraph( - nodes=[ - fetch_node, - parse_node, - rag_node, - search_link_with_context_node, - graph_iterator_node, - merge_answers_node - - ], - edges=[ - (fetch_node, parse_node), - (parse_node, rag_node), - (rag_node, search_link_with_context_node), - (search_link_with_context_node, graph_iterator_node), - (graph_iterator_node, merge_answers_node), - - ], - entry_point=fetch_node - ) - - def run(self) -> str: - """ - Executes the scraping process and returns the answer to the prompt. - - Returns: - str: The answer to the prompt. - """ - - inputs = {"user_prompt": self.prompt, self.input_key: self.source} - self.final_state, self.execution_info = self.graph.execute(inputs) - - return self.final_state.get("answer", "No answer found.") diff --git a/scrapegraphai/integrations/__init__.py b/scrapegraphai/integrations/__init__.py new file mode 100644 index 00000000..97589cd0 --- /dev/null +++ b/scrapegraphai/integrations/__init__.py @@ -0,0 +1 @@ +from .burr_bridge import BurrBridge \ No newline at end of file diff --git a/scrapegraphai/integrations/burr_bridge.py b/scrapegraphai/integrations/burr_bridge.py new file mode 100644 index 00000000..27e39c83 --- /dev/null +++ b/scrapegraphai/integrations/burr_bridge.py @@ -0,0 +1,198 @@ +""" +Bridge class to integrate Burr into ScrapeGraphAI graphs +[Burr](https://github.com/DAGWorks-Inc/burr) +""" + +import re +from typing import Any, Dict, List, Tuple + +from burr import tracking +from burr.core import Application, ApplicationBuilder, State, Action, default +from burr.core.action import action +from burr.lifecycle import PostRunStepHook, PreRunStepHook + + +class PrintLnHook(PostRunStepHook, PreRunStepHook): + """ + Hook to print the action name before and after it is executed. + """ + + def pre_run_step(self, *, state: "State", action: "Action", **future_kwargs: Any): + print(f"Starting action: {action.name}") + + def post_run_step(self, *, state: "State", action: "Action", **future_kwargs: Any): + print(f"Finishing action: {action.name}") + +class BurrBridge: + """ + Bridge class to integrate Burr into ScrapeGraphAI graphs. + + Args: + base_graph (BaseGraph): The base graph to convert to a Burr application. + burr_config (dict): Configuration parameters for the Burr application. + + Attributes: + base_graph (BaseGraph): The base graph to convert to a Burr application. + burr_config (dict): Configuration parameters for the Burr application. + tracker (LocalTrackingClient): The tracking client for the Burr application. + app_instance_id (str): The instance ID for the Burr application. + burr_inputs (dict): The inputs for the Burr application. + burr_app (Application): The Burr application instance. + + Example: + >>> burr_bridge = BurrBridge(base_graph, burr_config) + >>> result = burr_bridge.execute(initial_state={"input_key": "input_value"}) + """ + + def __init__(self, base_graph, burr_config): + self.base_graph = base_graph + self.burr_config = burr_config + self.tracker = tracking.LocalTrackingClient(project="smart-scraper-graph") + self.app_instance_id = burr_config.get("app_instance_id", "default-instance") + self.burr_inputs = burr_config.get("inputs", {}) + self.burr_app = None + + def _initialize_burr_app(self, initial_state: Dict[str, Any] = {}) -> Application: + """ + Initialize a Burr application from the base graph. + + Args: + initial_state (dict): The initial state of the Burr application. + + Returns: + Application: The Burr application instance. + """ + + actions = self._create_actions() + transitions = self._create_transitions() + hooks = [PrintLnHook()] + burr_state = self._convert_state_to_burr(initial_state) + + app = ( + ApplicationBuilder() + .with_actions(**actions) + .with_transitions(*transitions) + .with_entrypoint(self.base_graph.entry_point) + .with_state(**burr_state) + .with_identifiers(app_id=self.app_instance_id) + .with_tracker(self.tracker) + .with_hooks(*hooks) + .build() + ) + return app + + def _create_actions(self) -> Dict[str, Any]: + """ + Create Burr actions from the base graph nodes. + + Returns: + dict: A dictionary of Burr actions with the node name as keys and the action functions as values. + """ + + actions = {} + for node in self.base_graph.nodes: + action_func = self._create_action(node) + actions[node.node_name] = action_func + return actions + + def _create_action(self, node) -> Any: + """ + Create a Burr action function from a base graph node. + + Args: + node (Node): The base graph node to convert to a Burr action. + + Returns: + function: The Burr action function. + """ + + @action(reads=self._parse_boolean_expression(node.input), writes=node.output) + def dynamic_action(state: State, **kwargs): + node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)} + result_state = node.execute(node_inputs, **kwargs) + return result_state, state.update(**result_state) + return dynamic_action + + def _create_transitions(self) -> List[Tuple[str, str, Any]]: + """ + Create Burr transitions from the base graph edges. + + Returns: + list: A list of tuples representing the transitions between Burr actions. + """ + + transitions = [] + for from_node, to_node in self.base_graph.edges.items(): + transitions.append((from_node, to_node, default)) + return transitions + + def _parse_boolean_expression(self, expression: str) -> List[str]: + """ + Parse a boolean expression to extract the keys used in the expression, without boolean operators. + + Args: + expression (str): The boolean expression to parse. + + Returns: + list: A list of unique keys used in the expression. + """ + + # Use regular expression to extract all unique keys + keys = re.findall(r'\w+', expression) + return list(set(keys)) # Remove duplicates + + def _convert_state_to_burr(self, state: Dict[str, Any]) -> State: + """ + Convert a dictionary state to a Burr state. + + Args: + state (dict): The dictionary state to convert. + + Returns: + State: The Burr state instance. + """ + + burr_state = State() + for key, value in state.items(): + setattr(burr_state, key, value) + return burr_state + + def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]: + """ + Convert a Burr state to a dictionary state. + + Args: + burr_state (State): The Burr state to convert. + + Returns: + dict: The dictionary state instance. + """ + + state = {} + for key in burr_state.__dict__.keys(): + state[key] = getattr(burr_state, key) + return state + + def execute(self, initial_state: Dict[str, Any] = {}) -> Dict[str, Any]: + """ + Execute the Burr application with the given initial state. + + Args: + initial_state (dict): The initial state to pass to the Burr application. + + Returns: + dict: The final state of the Burr application. + """ + + self.burr_app = self._initialize_burr_app(initial_state) + + # TODO: to fix final nodes detection + final_nodes = [self.burr_app.graph.actions[-1].name] + + # TODO: fix inputs + last_action, result, final_state = self.burr_app.run( + halt_after=final_nodes, + inputs=self.burr_inputs + ) + + return self._convert_state_from_burr(final_state) \ No newline at end of file