feat(burr-bridge): BurrBridge class to integrate inside BaseGraph

This commit is contained in:
Marco Perini 2024-05-21 19:57:10 +02:00
parent 0b5cdd48fd
commit 6cbd84f254
11 changed files with 668 additions and 173 deletions

View File

@ -1 +0,0 @@
3.9.19

View File

@ -0,0 +1,109 @@
"""
Example of custom graph using existing nodes
"""
import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
load_dotenv()

# ************************************************
# Configuration: settings for the OpenAI chat model
# ************************************************

# The API key comes from the environment (optionally populated from a .env file above).
llm_settings = {
    "api_key": os.getenv("OPENAI_APIKEY"),
    "model": "gpt-3.5-turbo",
    "temperature": 0,
    "streaming": False,
}
graph_config = {"llm": llm_settings}

# ************************************************
# Models and nodes used by the graph
# ************************************************

llm_model = OpenAI(llm_settings)
embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)

fetch_node = FetchNode(
    input="url | local_dir",
    output=["doc", "link_urls", "img_urls"],
    node_config={"verbose": True, "headless": True},
)

parse_node = ParseNode(
    input="doc",
    output=["parsed_doc"],
    node_config={"chunk_size": 4096, "verbose": True},
)

rag_node = RAGNode(
    input="user_prompt & (parsed_doc | doc)",
    output=["relevant_chunks"],
    node_config={
        "llm_model": llm_model,
        "embedder_model": embedder,
        "verbose": True,
    },
)

generate_answer_node = GenerateAnswerNode(
    input="user_prompt & (relevant_chunks | parsed_doc | doc)",
    output=["answer"],
    node_config={"llm_model": llm_model, "verbose": True},
)

# ************************************************
# Graph wiring: a linear chain, executed through Burr
# ************************************************

# Nodes in execution order; every node is connected to the one that follows it.
pipeline = [fetch_node, parse_node, rag_node, generate_answer_node]

graph = BaseGraph(
    nodes=pipeline,
    edges=list(zip(pipeline, pipeline[1:])),
    entry_point=pipeline[0],
    use_burr=True,
    burr_config={
        "app_instance_id": "custom_graph_openai",
        "inputs": {
            "llm_model": llm_settings.get("model", "gpt-3.5-turbo"),
        },
    },
)

# ************************************************
# Run the graph and show the answer
# ************************************************

final_state, execution_info = graph.execute(
    {
        "user_prompt": "Describe the content",
        "url": "https://example.com/",
    }
)

# Fall back to a fixed message when the final state carries no "answer" key.
answer = final_state.get("answer", "No answer found.")
print(answer)

View File

@ -29,6 +29,7 @@ dependencies = [
"playwright==1.43.0",
"google==3.0.0",
"yahoo-search-py==0.3",
"burr[start]"
]
license = "MIT"

View File

@ -8,11 +8,15 @@
# with-sources: false
-e file:.
aiofiles==23.2.1
# via burr
aiohttp==3.9.5
# via langchain
# via langchain-community
aiosignal==1.3.1
# via aiohttp
altair==5.3.0
# via streamlit
annotated-types==0.6.0
# via pydantic
anthropic==0.25.9
@ -22,27 +26,51 @@ anyio==4.3.0
# via groq
# via httpx
# via openai
# via starlette
# via watchfiles
async-timeout==4.0.3
# via aiohttp
# via langchain
attrs==23.2.0
# via aiohttp
# via jsonschema
# via referencing
beautifulsoup4==4.12.3
# via google
# via scrapegraphai
blinker==1.8.2
# via streamlit
boto3==1.34.105
# via langchain-aws
botocore==1.34.105
# via boto3
# via s3transfer
burr==0.17.1
# via scrapegraphai
cachetools==5.3.3
# via google-auth
# via streamlit
certifi==2024.2.2
# via httpcore
# via httpx
# via requests
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via burr
# via streamlit
# via typer
# via uvicorn
colorama==0.4.6
# via click
# via loguru
# via pytest
# via tqdm
# via uvicorn
contourpy==1.2.1
# via matplotlib
cycler==0.12.1
# via matplotlib
dataclasses-json==0.6.6
# via langchain
# via langchain-community
@ -52,13 +80,26 @@ distro==1.9.0
# via anthropic
# via groq
# via openai
dnspython==2.6.1
# via email-validator
email-validator==2.1.1
# via fastapi
exceptiongroup==1.2.1
# via anyio
# via pytest
faiss-cpu==1.8.0
# via scrapegraphai
fastapi==0.111.0
# via burr
# via fastapi-pagination
fastapi-cli==0.0.4
# via fastapi
fastapi-pagination==0.12.24
# via burr
filelock==3.14.0
# via huggingface-hub
fonttools==4.51.0
# via matplotlib
free-proxy==1.1.1
# via scrapegraphai
frozenlist==1.4.1
@ -66,6 +107,10 @@ frozenlist==1.4.1
# via aiosignal
fsspec==2024.3.1
# via huggingface-hub
gitdb==4.0.11
# via gitpython
gitpython==3.1.43
# via streamlit
google==3.0.0
# via scrapegraphai
google-ai-generativelanguage==0.6.3
@ -90,6 +135,7 @@ googleapis-common-protos==1.63.0
# via google-api-core
# via grpcio-status
graphviz==0.20.3
# via burr
# via scrapegraphai
greenlet==3.0.3
# via playwright
@ -103,6 +149,7 @@ grpcio-status==1.62.2
# via google-api-core
h11==0.14.0
# via httpcore
# via uvicorn
html2text==2024.2.26
# via scrapegraphai
httpcore==1.0.5
@ -110,8 +157,11 @@ httpcore==1.0.5
httplib2==0.22.0
# via google-api-python-client
# via google-auth-httplib2
httptools==0.6.1
# via uvicorn
httpx==0.27.0
# via anthropic
# via fastapi
# via groq
# via openai
# via yahoo-search-py
@ -119,11 +169,17 @@ huggingface-hub==0.23.0
# via tokenizers
idna==3.7
# via anyio
# via email-validator
# via httpx
# via requests
# via yarl
iniconfig==2.0.0
# via pytest
jinja2==3.1.4
# via altair
# via burr
# via fastapi
# via pydeck
jmespath==1.0.1
# via boto3
# via botocore
@ -132,6 +188,12 @@ jsonpatch==1.33
# via langchain-core
jsonpointer==2.4
# via jsonpatch
jsonschema==4.22.0
# via altair
jsonschema-specifications==2023.12.1
# via jsonschema
kiwisolver==1.4.5
# via matplotlib
langchain==0.1.15
# via scrapegraphai
langchain-anthropic==0.1.11
@ -161,10 +223,20 @@ langsmith==0.1.57
# via langchain
# via langchain-community
# via langchain-core
loguru==0.7.2
# via burr
lxml==5.2.2
# via free-proxy
markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
# via jinja2
marshmallow==3.21.2
# via dataclasses-json
matplotlib==3.9.0
# via burr
mdurl==0.1.2
# via markdown-it-py
minify-html==0.15.0
# via scrapegraphai
multidict==6.0.5
@ -173,22 +245,40 @@ multidict==6.0.5
mypy-extensions==1.0.0
# via typing-inspect
numpy==1.26.4
# via altair
# via contourpy
# via faiss-cpu
# via langchain
# via langchain-aws
# via langchain-community
# via matplotlib
# via pandas
# via pyarrow
# via pydeck
# via sf-hamilton
# via streamlit
openai==1.30.1
# via burr
# via langchain-openai
orjson==3.10.3
# via fastapi
# via langsmith
packaging==23.2
# via altair
# via huggingface-hub
# via langchain-core
# via marshmallow
# via matplotlib
# via pytest
# via streamlit
pandas==2.2.2
# via altair
# via scrapegraphai
# via sf-hamilton
# via streamlit
pillow==10.3.0
# via matplotlib
# via streamlit
playwright==1.43.0
# via scrapegraphai
pluggy==1.5.0
@ -203,6 +293,9 @@ protobuf==4.25.3
# via googleapis-common-protos
# via grpcio-status
# via proto-plus
# via streamlit
pyarrow==16.1.0
# via streamlit
pyasn1==0.6.0
# via pyasn1-modules
# via rsa
@ -210,6 +303,9 @@ pyasn1-modules==0.4.0
# via google-auth
pydantic==2.7.1
# via anthropic
# via burr
# via fastapi
# via fastapi-pagination
# via google-generativeai
# via groq
# via langchain
@ -219,18 +315,27 @@ pydantic==2.7.1
# via yahoo-search-py
pydantic-core==2.18.2
# via pydantic
pydeck==0.9.1
# via streamlit
pyee==11.1.0
# via playwright
pygments==2.18.0
# via rich
pyparsing==3.1.2
# via httplib2
# via matplotlib
pytest==8.0.0
# via pytest-mock
pytest-mock==3.14.0
python-dateutil==2.9.0.post0
# via botocore
# via matplotlib
# via pandas
python-dotenv==1.0.1
# via scrapegraphai
# via uvicorn
python-multipart==0.0.9
# via fastapi
pytz==2024.1
# via pandas
pyyaml==6.0.1
@ -238,24 +343,42 @@ pyyaml==6.0.1
# via langchain
# via langchain-community
# via langchain-core
# via uvicorn
referencing==0.35.1
# via jsonschema
# via jsonschema-specifications
regex==2024.5.10
# via tiktoken
requests==2.31.0
# via burr
# via free-proxy
# via google-api-core
# via huggingface-hub
# via langchain
# via langchain-community
# via langsmith
# via streamlit
# via tiktoken
rich==13.7.1
# via streamlit
# via typer
rpds-py==0.18.1
# via jsonschema
# via referencing
rsa==4.9
# via google-auth
s3transfer==0.10.1
# via boto3
selectolax==0.3.21
# via yahoo-search-py
sf-hamilton==1.62.0
# via burr
shellingham==1.5.4
# via typer
six==1.16.0
# via python-dateutil
smmap==5.0.1
# via gitdb
sniffio==1.3.1
# via anthropic
# via anyio
@ -267,25 +390,41 @@ soupsieve==2.5
sqlalchemy==2.0.30
# via langchain
# via langchain-community
starlette==0.37.2
# via fastapi
streamlit==1.34.0
# via burr
tenacity==8.3.0
# via langchain
# via langchain-community
# via langchain-core
# via streamlit
tiktoken==0.6.0
# via langchain-openai
# via scrapegraphai
tokenizers==0.19.1
# via anthropic
toml==0.10.2
# via streamlit
tomli==2.0.1
# via pytest
toolz==0.12.1
# via altair
tornado==6.4
# via streamlit
tqdm==4.66.4
# via google-generativeai
# via huggingface-hub
# via openai
# via scrapegraphai
typer==0.12.3
# via fastapi-cli
typing-extensions==4.11.0
# via altair
# via anthropic
# via anyio
# via fastapi
# via fastapi-pagination
# via google-generativeai
# via groq
# via huggingface-hub
@ -293,18 +432,36 @@ typing-extensions==4.11.0
# via pydantic
# via pydantic-core
# via pyee
# via sf-hamilton
# via sqlalchemy
# via streamlit
# via typer
# via typing-inspect
# via uvicorn
typing-inspect==0.9.0
# via dataclasses-json
# via sf-hamilton
tzdata==2024.1
# via pandas
ujson==5.10.0
# via fastapi
uritemplate==4.1.1
# via google-api-python-client
urllib3==1.26.18
# via botocore
# via requests
# via yahoo-search-py
uvicorn==0.29.0
# via burr
# via fastapi
watchdog==4.0.0
# via streamlit
watchfiles==0.21.0
# via uvicorn
websockets==12.0
# via uvicorn
win32-setctime==1.1.0
# via loguru
yahoo-search-py==0.3
# via scrapegraphai
yarl==1.9.4

View File

@ -8,11 +8,15 @@
# with-sources: false
-e file:.
aiofiles==23.2.1
# via burr
aiohttp==3.9.5
# via langchain
# via langchain-community
aiosignal==1.3.1
# via aiohttp
altair==5.3.0
# via streamlit
annotated-types==0.6.0
# via pydantic
anthropic==0.25.9
@ -22,27 +26,50 @@ anyio==4.3.0
# via groq
# via httpx
# via openai
# via starlette
# via watchfiles
async-timeout==4.0.3
# via aiohttp
# via langchain
attrs==23.2.0
# via aiohttp
# via jsonschema
# via referencing
beautifulsoup4==4.12.3
# via google
# via scrapegraphai
blinker==1.8.2
# via streamlit
boto3==1.34.105
# via langchain-aws
botocore==1.34.105
# via boto3
# via s3transfer
burr==0.17.1
# via scrapegraphai
cachetools==5.3.3
# via google-auth
# via streamlit
certifi==2024.2.2
# via httpcore
# via httpx
# via requests
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via burr
# via streamlit
# via typer
# via uvicorn
colorama==0.4.6
# via click
# via loguru
# via tqdm
# via uvicorn
contourpy==1.2.1
# via matplotlib
cycler==0.12.1
# via matplotlib
dataclasses-json==0.6.6
# via langchain
# via langchain-community
@ -52,12 +79,25 @@ distro==1.9.0
# via anthropic
# via groq
# via openai
dnspython==2.6.1
# via email-validator
email-validator==2.1.1
# via fastapi
exceptiongroup==1.2.1
# via anyio
faiss-cpu==1.8.0
# via scrapegraphai
fastapi==0.111.0
# via burr
# via fastapi-pagination
fastapi-cli==0.0.4
# via fastapi
fastapi-pagination==0.12.24
# via burr
filelock==3.14.0
# via huggingface-hub
fonttools==4.51.0
# via matplotlib
free-proxy==1.1.1
# via scrapegraphai
frozenlist==1.4.1
@ -65,6 +105,10 @@ frozenlist==1.4.1
# via aiosignal
fsspec==2024.3.1
# via huggingface-hub
gitdb==4.0.11
# via gitpython
gitpython==3.1.43
# via streamlit
google==3.0.0
# via scrapegraphai
google-ai-generativelanguage==0.6.3
@ -89,6 +133,7 @@ googleapis-common-protos==1.63.0
# via google-api-core
# via grpcio-status
graphviz==0.20.3
# via burr
# via scrapegraphai
greenlet==3.0.3
# via playwright
@ -102,6 +147,7 @@ grpcio-status==1.62.2
# via google-api-core
h11==0.14.0
# via httpcore
# via uvicorn
html2text==2024.2.26
# via scrapegraphai
httpcore==1.0.5
@ -109,8 +155,11 @@ httpcore==1.0.5
httplib2==0.22.0
# via google-api-python-client
# via google-auth-httplib2
httptools==0.6.1
# via uvicorn
httpx==0.27.0
# via anthropic
# via fastapi
# via groq
# via openai
# via yahoo-search-py
@ -118,9 +167,15 @@ huggingface-hub==0.23.0
# via tokenizers
idna==3.7
# via anyio
# via email-validator
# via httpx
# via requests
# via yarl
jinja2==3.1.4
# via altair
# via burr
# via fastapi
# via pydeck
jmespath==1.0.1
# via boto3
# via botocore
@ -129,6 +184,12 @@ jsonpatch==1.33
# via langchain-core
jsonpointer==2.4
# via jsonpatch
jsonschema==4.22.0
# via altair
jsonschema-specifications==2023.12.1
# via jsonschema
kiwisolver==1.4.5
# via matplotlib
langchain==0.1.15
# via scrapegraphai
langchain-anthropic==0.1.11
@ -158,10 +219,20 @@ langsmith==0.1.57
# via langchain
# via langchain-community
# via langchain-core
loguru==0.7.2
# via burr
lxml==5.2.2
# via free-proxy
markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
# via jinja2
marshmallow==3.21.2
# via dataclasses-json
matplotlib==3.9.0
# via burr
mdurl==0.1.2
# via markdown-it-py
minify-html==0.15.0
# via scrapegraphai
multidict==6.0.5
@ -170,21 +241,39 @@ multidict==6.0.5
mypy-extensions==1.0.0
# via typing-inspect
numpy==1.26.4
# via altair
# via contourpy
# via faiss-cpu
# via langchain
# via langchain-aws
# via langchain-community
# via matplotlib
# via pandas
# via pyarrow
# via pydeck
# via sf-hamilton
# via streamlit
openai==1.30.1
# via burr
# via langchain-openai
orjson==3.10.3
# via fastapi
# via langsmith
packaging==23.2
# via altair
# via huggingface-hub
# via langchain-core
# via marshmallow
# via matplotlib
# via streamlit
pandas==2.2.2
# via altair
# via scrapegraphai
# via sf-hamilton
# via streamlit
pillow==10.3.0
# via matplotlib
# via streamlit
playwright==1.43.0
# via scrapegraphai
proto-plus==1.23.0
@ -197,6 +286,9 @@ protobuf==4.25.3
# via googleapis-common-protos
# via grpcio-status
# via proto-plus
# via streamlit
pyarrow==16.1.0
# via streamlit
pyasn1==0.6.0
# via pyasn1-modules
# via rsa
@ -204,6 +296,9 @@ pyasn1-modules==0.4.0
# via google-auth
pydantic==2.7.1
# via anthropic
# via burr
# via fastapi
# via fastapi-pagination
# via google-generativeai
# via groq
# via langchain
@ -213,15 +308,24 @@ pydantic==2.7.1
# via yahoo-search-py
pydantic-core==2.18.2
# via pydantic
pydeck==0.9.1
# via streamlit
pyee==11.1.0
# via playwright
pygments==2.18.0
# via rich
pyparsing==3.1.2
# via httplib2
# via matplotlib
python-dateutil==2.9.0.post0
# via botocore
# via matplotlib
# via pandas
python-dotenv==1.0.1
# via scrapegraphai
# via uvicorn
python-multipart==0.0.9
# via fastapi
pytz==2024.1
# via pandas
pyyaml==6.0.1
@ -229,24 +333,42 @@ pyyaml==6.0.1
# via langchain
# via langchain-community
# via langchain-core
# via uvicorn
referencing==0.35.1
# via jsonschema
# via jsonschema-specifications
regex==2024.5.10
# via tiktoken
requests==2.31.0
# via burr
# via free-proxy
# via google-api-core
# via huggingface-hub
# via langchain
# via langchain-community
# via langsmith
# via streamlit
# via tiktoken
rich==13.7.1
# via streamlit
# via typer
rpds-py==0.18.1
# via jsonschema
# via referencing
rsa==4.9
# via google-auth
s3transfer==0.10.1
# via boto3
selectolax==0.3.21
# via yahoo-search-py
sf-hamilton==1.62.0
# via burr
shellingham==1.5.4
# via typer
six==1.16.0
# via python-dateutil
smmap==5.0.1
# via gitdb
sniffio==1.3.1
# via anthropic
# via anyio
@ -258,23 +380,39 @@ soupsieve==2.5
sqlalchemy==2.0.30
# via langchain
# via langchain-community
starlette==0.37.2
# via fastapi
streamlit==1.34.0
# via burr
tenacity==8.3.0
# via langchain
# via langchain-community
# via langchain-core
# via streamlit
tiktoken==0.6.0
# via langchain-openai
# via scrapegraphai
tokenizers==0.19.1
# via anthropic
toml==0.10.2
# via streamlit
toolz==0.12.1
# via altair
tornado==6.4
# via streamlit
tqdm==4.66.4
# via google-generativeai
# via huggingface-hub
# via openai
# via scrapegraphai
typer==0.12.3
# via fastapi-cli
typing-extensions==4.11.0
# via altair
# via anthropic
# via anyio
# via fastapi
# via fastapi-pagination
# via google-generativeai
# via groq
# via huggingface-hub
@ -282,18 +420,36 @@ typing-extensions==4.11.0
# via pydantic
# via pydantic-core
# via pyee
# via sf-hamilton
# via sqlalchemy
# via streamlit
# via typer
# via typing-inspect
# via uvicorn
typing-inspect==0.9.0
# via dataclasses-json
# via sf-hamilton
tzdata==2024.1
# via pandas
ujson==5.10.0
# via fastapi
uritemplate==4.1.1
# via google-api-python-client
urllib3==1.26.18
# via botocore
# via requests
# via yahoo-search-py
uvicorn==0.29.0
# via burr
# via fastapi
watchdog==4.0.0
# via streamlit
watchfiles==0.21.0
# via uvicorn
websockets==12.0
# via uvicorn
win32-setctime==1.1.0
# via loguru
yahoo-search-py==0.3
# via scrapegraphai
yarl==1.9.4

View File

@ -15,4 +15,3 @@ from .csv_scraper_graph import CSVScraperGraph
from .pdf_scraper_graph import PDFScraperGraph
from .omni_scraper_graph import OmniScraperGraph
from .omni_search_graph import OmniSearchGraph
from .turbo_scraper import TurboScraperGraph

View File

@ -7,6 +7,8 @@ import warnings
from langchain_community.callbacks import get_openai_callback
from typing import Tuple
from ..integrations import BurrBridge
class BaseGraph:
"""
@ -40,20 +42,27 @@ class BaseGraph:
... (parse_node, rag_node),
... (rag_node, generate_answer_node)
... ],
... entry_point=fetch_node
... entry_point=fetch_node,
... use_burr=True,
... burr_config={"app_instance_id": "example-instance"}
... )
"""
def __init__(self, nodes: list, edges: list, entry_point: str):
def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool = False, burr_config: dict = None):
self.nodes = nodes
self.edges = self._create_edges({e for e in edges})
self.entry_point = entry_point.node_name
self.initial_state = {}
if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list
warnings.warn(
"Careful! The entry point node is different from the first node if the graph.")
# Burr configuration
self.use_burr = use_burr
self.burr_config = burr_config or {}
def _create_edges(self, edges: list) -> dict:
"""
@ -71,11 +80,9 @@ class BaseGraph:
edge_dict[from_node.node_name] = to_node.node_name
return edge_dict
def execute(self, initial_state: dict) -> Tuple[dict, list]:
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
"""
Executes the graph by traversing nodes starting from the entry point. The execution
follows the edges based on the result of each node's execution and continues until
it reaches a node with no outgoing edges.
Executes the graph by traversing nodes starting from the entry point using the standard method.
Args:
initial_state (dict): The initial state to pass to the entry point node.
@ -83,8 +90,7 @@ class BaseGraph:
Returns:
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
"""
current_node_name = self.nodes[0]
current_node_name = self.entry_point
state = initial_state
# variables for tracking execution info
@ -98,18 +104,17 @@ class BaseGraph:
"total_cost_USD": 0.0,
}
for index in self.nodes:
while current_node_name:
curr_time = time.time()
current_node = index
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
with get_openai_callback() as cb:
result = current_node.execute(state)
node_exec_time = time.time() - curr_time
total_exec_time += node_exec_time
cb = {
"node_name": index.node_name,
cb_data = {
"node_name": current_node.node_name,
"total_tokens": cb.total_tokens,
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
@ -118,15 +123,13 @@ class BaseGraph:
"exec_time": node_exec_time,
}
exec_info.append(
cb
)
exec_info.append(cb_data)
cb_total["total_tokens"] += cb["total_tokens"]
cb_total["prompt_tokens"] += cb["prompt_tokens"]
cb_total["completion_tokens"] += cb["completion_tokens"]
cb_total["successful_requests"] += cb["successful_requests"]
cb_total["total_cost_USD"] += cb["total_cost_USD"]
cb_total["total_tokens"] += cb_data["total_tokens"]
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
cb_total["completion_tokens"] += cb_data["completion_tokens"]
cb_total["successful_requests"] += cb_data["successful_requests"]
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
if current_node.node_type == "conditional_node":
current_node_name = result
@ -137,12 +140,30 @@ class BaseGraph:
exec_info.append({
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
return state, exec_info
def execute(self, initial_state: dict) -> Tuple[dict, list]:
    """
    Executes the graph by either using BurrBridge or the standard method.

    The initial state is also stored on the instance as ``self.initial_state``.

    Args:
        initial_state (dict): The initial state to pass to the entry point node.

    Returns:
        Tuple[dict, list]: A tuple containing the final state and a list of execution info.
            When the graph runs through Burr the execution info list is empty, because
            the bridge only reports the final state.
    """
    self.initial_state = initial_state

    if not self.use_burr:
        return self._execute_standard(initial_state)

    bridge = BurrBridge(self, self.burr_config)
    final_state = bridge.execute(initial_state)
    # BurrBridge.execute returns a single dict; wrap it so both execution paths honour
    # the same (state, exec_info) contract and callers can always unpack two values.
    return final_state, []

View File

@ -25,7 +25,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
from tqdm import tqdm
if __name__ == '__main__':
from scrapegraphai.utils.remover import remover
from scrapegraphai.utils import cleanup_html
else:
from ..utils.remover import remover

View File

@ -1,146 +0,0 @@
"""
SmartScraperGraph Module
"""
from .base_graph import BaseGraph
from ..nodes import (
FetchNode,
ParseNode,
RAGNode,
SearchLinksWithContext,
GraphIteratorNode,
MergeAnswersNode
)
from .search_graph import SearchGraph
from .abstract_graph import AbstractGraph
class SmartScraperGraph(AbstractGraph):
    """
    SmartScraper is a scraping pipeline that automates the process of
    extracting information from web pages
    using a natural language model to interpret and answer prompts.

    The graph built here chains fetch, parse and RAG nodes, then searches links
    with context, iterates a nested graph instance over the resulting URLs and
    merges the collected results into one answer.

    Attributes:
        prompt (str): The prompt for the graph.
        source (str): The source of the graph.
        config (dict): Configuration parameters for the graph.
        llm_model: An instance of a language model client, configured for generating answers.
        embedder_model: An instance of an embedding model client,
        configured for generating embeddings.
        verbose (bool): A flag indicating whether to show print statements during execution.
        headless (bool): A flag indicating whether to run the graph in headless mode.
        input_key (str): "url" when the source starts with "http", otherwise "local_dir".

    Args:
        prompt (str): The prompt for the graph.
        source (str): The source of the graph.
        config (dict): Configuration parameters for the graph.

    Example:
        >>> smart_scraper = SmartScraperGraph(
        ...     "List me all the attractions in Chioggia.",
        ...     "https://en.wikipedia.org/wiki/Chioggia",
        ...     {"llm": {"model": "gpt-3.5-turbo"}}
        ... )
        >>> result = smart_scraper.run()
    """

    def __init__(self, prompt: str, source: str, config: dict):
        super().__init__(prompt, config, source)

        # Name of the state key under which the source is passed to the graph in run().
        self.input_key = "url" if source.startswith("http") else "local_dir"

    def _create_graph(self) -> BaseGraph:
        """
        Creates the graph of nodes representing the workflow for web scraping.

        Returns:
            BaseGraph: A graph instance representing the web scraping workflow.
        """
        # Nested graph instance handed to the GraphIteratorNode below.
        # NOTE(review): this constructs the same class from inside its own
        # _create_graph; if the base class builds the graph during __init__
        # (not visible here) this would recurse without end - confirm.
        # NOTE(review): `config` receives self.llm_model although the class
        # docstring describes config as a dict - verify this is intended.
        smart_scraper_graph = SmartScraperGraph(
            prompt="",
            source="",
            config=self.llm_model
        )

        fetch_node = FetchNode(
            input="url | local_dir",
            output=["doc"]
        )
        parse_node = ParseNode(
            input="doc",
            output=["parsed_doc"],
            node_config={
                "chunk_size": self.model_token
            }
        )
        rag_node = RAGNode(
            input="user_prompt & (parsed_doc | doc)",
            output=["relevant_chunks"],
            node_config={
                "llm_model": self.llm_model,
                "embedder_model": self.embedder_model
            }
        )
        search_link_with_context_node = SearchLinksWithContext(
            input="user_prompt & (relevant_chunks | parsed_doc | doc)",
            output=["answer"],
            node_config={
                "llm_model": self.llm_model
            }
        )
        # NOTE(review): this node reads "urls", but no node above lists "urls"
        # among its outputs - check where that state key is produced.
        graph_iterator_node = GraphIteratorNode(
            input="user_prompt & urls",
            output=["results"],
            node_config={
                "graph_instance": smart_scraper_graph,
                "verbose": True,
            }
        )
        merge_answers_node = MergeAnswersNode(
            input="user_prompt & results",
            output=["answer"],
            node_config={
                "llm_model": self.llm_model,
                "verbose": True,
            }
        )

        # Linear chain: each node is connected to the next one in the list.
        return BaseGraph(
            nodes=[
                fetch_node,
                parse_node,
                rag_node,
                search_link_with_context_node,
                graph_iterator_node,
                merge_answers_node
            ],
            edges=[
                (fetch_node, parse_node),
                (parse_node, rag_node),
                (rag_node, search_link_with_context_node),
                (search_link_with_context_node, graph_iterator_node),
                (graph_iterator_node, merge_answers_node),
            ],
            entry_point=fetch_node
        )

    def run(self) -> str:
        """
        Executes the scraping process and returns the answer to the prompt.

        The final state and the execution info returned by the graph are stored
        on the instance as ``final_state`` and ``execution_info``.

        Returns:
            str: The answer to the prompt, or "No answer found." when the final
            state has no "answer" key.
        """
        inputs = {"user_prompt": self.prompt, self.input_key: self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        return self.final_state.get("answer", "No answer found.")

View File

@ -0,0 +1 @@
from .burr_bridge import BurrBridge

View File

@ -0,0 +1,198 @@
"""
Bridge class to integrate Burr into ScrapeGraphAI graphs
[Burr](https://github.com/DAGWorks-Inc/burr)
"""
import re
from typing import Any, Dict, List, Tuple
from burr import tracking
from burr.core import Application, ApplicationBuilder, State, Action, default
from burr.core.action import action
from burr.lifecycle import PostRunStepHook, PreRunStepHook
class PrintLnHook(PostRunStepHook, PreRunStepHook):
    """
    Lifecycle hook that announces every action on stdout, once right before
    it runs and once right after it has run.
    """

    def post_run_step(
        self, *, state: "State", action: "Action", **future_kwargs: Any
    ) -> None:
        """Print the name of the action that has just been executed."""
        print(f"Finishing action: {action.name}")

    def pre_run_step(
        self, *, state: "State", action: "Action", **future_kwargs: Any
    ) -> None:
        """Print the name of the action that is about to be executed."""
        print(f"Starting action: {action.name}")
class BurrBridge:
    """
    Bridge class to integrate Burr into ScrapeGraphAI graphs.

    Every node of the base graph is wrapped into a Burr action and every edge
    becomes a default transition between the corresponding actions.

    Args:
        base_graph (BaseGraph): The base graph to convert to a Burr application.
        burr_config (dict): Configuration parameters for the Burr application.
            Recognised keys are ``app_instance_id``, ``inputs`` and ``project_name``.

    Attributes:
        base_graph (BaseGraph): The base graph to convert to a Burr application.
        burr_config (dict): Configuration parameters for the Burr application.
        project_name (str): The project name passed to the tracking client.
        tracker (LocalTrackingClient): The tracking client for the Burr application.
        app_instance_id (str): The instance ID for the Burr application.
        burr_inputs (dict): The inputs for the Burr application.
        burr_app (Application): The Burr application instance.

    Example:
        >>> burr_bridge = BurrBridge(base_graph, burr_config)
        >>> result = burr_bridge.execute(initial_state={"input_key": "input_value"})
    """

    def __init__(self, base_graph, burr_config):
        burr_config = burr_config or {}

        self.base_graph = base_graph
        self.burr_config = burr_config
        # The default keeps the project name that was previously hard-coded.
        self.project_name = burr_config.get("project_name", "smart-scraper-graph")
        self.tracker = tracking.LocalTrackingClient(project=self.project_name)
        self.app_instance_id = burr_config.get("app_instance_id", "default-instance")
        self.burr_inputs = burr_config.get("inputs", {})
        self.burr_app = None

    def _initialize_burr_app(self, initial_state: Dict[str, Any] = None) -> "Application":
        """
        Initialize a Burr application from the base graph.

        Args:
            initial_state (dict): The initial state of the Burr application.
                ``None`` is treated as an empty state.

        Returns:
            Application: The Burr application instance.
        """
        if initial_state is None:
            initial_state = {}

        actions = self._create_actions()
        transitions = self._create_transitions()
        hooks = [PrintLnHook()]
        burr_state = self._convert_state_to_burr(initial_state)

        app = (
            ApplicationBuilder()
            .with_actions(**actions)
            .with_transitions(*transitions)
            .with_entrypoint(self.base_graph.entry_point)
            .with_state(**burr_state.get_all())
            .with_identifiers(app_id=self.app_instance_id)
            .with_tracker(self.tracker)
            .with_hooks(*hooks)
            .build()
        )
        return app

    def _create_actions(self) -> Dict[str, Any]:
        """
        Create Burr actions from the base graph nodes.

        Returns:
            dict: A dictionary of Burr actions with the node name as keys
            and the action functions as values.
        """
        return {
            node.node_name: self._create_action(node)
            for node in self.base_graph.nodes
        }

    def _create_action(self, node) -> Any:
        """
        Create a Burr action function from a base graph node.

        Args:
            node (Node): The base graph node to convert to a Burr action.

        Returns:
            function: The Burr action function.
        """
        reads = self._parse_boolean_expression(node.input)

        @action(reads=reads, writes=node.output)
        def dynamic_action(state: State, **kwargs):
            # Input expressions may list alternatives ("url | local_dir"), so only
            # the keys actually present in the state are handed to the node.
            node_inputs = {key: state[key] for key in reads if key in state}
            result_state = node.execute(node_inputs, **kwargs)
            return result_state, state.update(**result_state)

        return dynamic_action

    def _create_transitions(self) -> List[Tuple[str, str, Any]]:
        """
        Create Burr transitions from the base graph edges.

        Returns:
            list: A list of tuples representing the transitions between Burr actions.
        """
        return [
            (from_node, to_node, default)
            for from_node, to_node in self.base_graph.edges.items()
        ]

    def _find_final_nodes(self) -> List[str]:
        """
        Find the names of the nodes the Burr application should halt after.

        Returns:
            list: Names of the nodes without an outgoing edge, in graph order.
            When every node has an outgoing edge, the last node of the graph
            is used instead.
        """
        nodes = self.base_graph.nodes
        edges = self.base_graph.edges

        final_nodes = [node.node_name for node in nodes if node.node_name not in edges]
        if not final_nodes and nodes:
            final_nodes = [nodes[-1].node_name]
        return final_nodes

    def _parse_boolean_expression(self, expression: str) -> List[str]:
        """
        Parse a boolean expression to extract the keys used in the expression,
        without boolean operators.

        Args:
            expression (str): The boolean expression to parse.

        Returns:
            list: A list of unique keys used in the expression, in order of
            first appearance.
        """
        keys = re.findall(r'\w+', expression)
        # dict.fromkeys removes duplicates while keeping a deterministic order,
        # unlike a set whose iteration order can change between runs.
        return list(dict.fromkeys(keys))

    def _convert_state_to_burr(self, state: Dict[str, Any]) -> "State":
        """
        Convert a dictionary state to a Burr state.

        Args:
            state (dict): The dictionary state to convert.

        Returns:
            State: The Burr state instance holding a copy of the given values.
        """
        # Values must go through the State constructor; setting them as
        # attributes would leave the Burr state itself empty.
        return State(dict(state))

    def _convert_state_from_burr(self, burr_state: "State") -> Dict[str, Any]:
        """
        Convert a Burr state to a dictionary state.

        Args:
            burr_state (State): The Burr state to convert.

        Returns:
            dict: The dictionary state instance.
        """
        # get_all() exposes the stored key/value pairs; __dict__ would expose
        # the object's internal attributes instead.
        return dict(burr_state.get_all())

    def execute(self, initial_state: Dict[str, Any] = None) -> Dict[str, Any]:
        """
        Execute the Burr application with the given initial state.

        Args:
            initial_state (dict): The initial state to pass to the Burr application.
                ``None`` is treated as an empty state.

        Returns:
            dict: The final state of the Burr application.
        """
        if initial_state is None:
            initial_state = {}

        self.burr_app = self._initialize_burr_app(initial_state)

        # TODO: fix inputs
        _, _, final_state = self.burr_app.run(
            halt_after=self._find_final_nodes(),
            inputs=self.burr_inputs
        )

        return self._convert_state_from_burr(final_state)