mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
feat(burr-bridge): BurrBridge class to integrate inside BaseGraph
This commit is contained in:
parent
0b5cdd48fd
commit
6cbd84f254
@ -1 +0,0 @@
|
||||
3.9.19
|
||||
109
examples/openai/burr_integration_openai.py
Normal file
109
examples/openai/burr_integration_openai.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""
|
||||
Example of custom graph using existing nodes
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from scrapegraphai.models import OpenAI
|
||||
from scrapegraphai.graphs import BaseGraph
|
||||
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
|
||||
load_dotenv()
|
||||
|
||||
# ************************************************
|
||||
# Define the configuration for the graph
|
||||
# ************************************************
|
||||
|
||||
openai_key = os.getenv("OPENAI_APIKEY")
|
||||
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"api_key": openai_key,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"temperature": 0,
|
||||
"streaming": False
|
||||
},
|
||||
}
|
||||
|
||||
# ************************************************
|
||||
# Define the graph nodes
|
||||
# ************************************************
|
||||
|
||||
llm_model = OpenAI(graph_config["llm"])
|
||||
embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)
|
||||
|
||||
# define the nodes for the graph
|
||||
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc", "link_urls", "img_urls"],
|
||||
node_config={
|
||||
"verbose": True,
|
||||
"headless": True,
|
||||
}
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": 4096,
|
||||
"verbose": True,
|
||||
}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={
|
||||
"llm_model": llm_model,
|
||||
"embedder_model": embedder,
|
||||
"verbose": True,
|
||||
}
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": llm_model,
|
||||
"verbose": True,
|
||||
}
|
||||
)
|
||||
|
||||
# ************************************************
|
||||
# Create the graph by defining the connections
|
||||
# ************************************************
|
||||
|
||||
graph = BaseGraph(
|
||||
nodes=[
|
||||
fetch_node,
|
||||
parse_node,
|
||||
rag_node,
|
||||
generate_answer_node,
|
||||
],
|
||||
edges=[
|
||||
(fetch_node, parse_node),
|
||||
(parse_node, rag_node),
|
||||
(rag_node, generate_answer_node)
|
||||
],
|
||||
entry_point=fetch_node,
|
||||
use_burr=True,
|
||||
burr_config={
|
||||
"app_instance_id": "custom_graph_openai",
|
||||
"inputs": {
|
||||
"llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# ************************************************
|
||||
# Execute the graph
|
||||
# ************************************************
|
||||
|
||||
result, execution_info = graph.execute({
|
||||
"user_prompt": "Describe the content",
|
||||
"url": "https://example.com/"
|
||||
})
|
||||
|
||||
# get the answer from the result
|
||||
result = result.get("answer", "No answer found.")
|
||||
print(result)
|
||||
@ -29,6 +29,7 @@ dependencies = [
|
||||
"playwright==1.43.0",
|
||||
"google==3.0.0",
|
||||
"yahoo-search-py==0.3",
|
||||
"burr[start]"
|
||||
]
|
||||
|
||||
license = "MIT"
|
||||
|
||||
@ -8,11 +8,15 @@
|
||||
# with-sources: false
|
||||
|
||||
-e file:.
|
||||
aiofiles==23.2.1
|
||||
# via burr
|
||||
aiohttp==3.9.5
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
altair==5.3.0
|
||||
# via streamlit
|
||||
annotated-types==0.6.0
|
||||
# via pydantic
|
||||
anthropic==0.25.9
|
||||
@ -22,27 +26,51 @@ anyio==4.3.0
|
||||
# via groq
|
||||
# via httpx
|
||||
# via openai
|
||||
# via starlette
|
||||
# via watchfiles
|
||||
async-timeout==4.0.3
|
||||
# via aiohttp
|
||||
# via langchain
|
||||
attrs==23.2.0
|
||||
# via aiohttp
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
beautifulsoup4==4.12.3
|
||||
# via google
|
||||
# via scrapegraphai
|
||||
blinker==1.8.2
|
||||
# via streamlit
|
||||
boto3==1.34.105
|
||||
# via langchain-aws
|
||||
botocore==1.34.105
|
||||
# via boto3
|
||||
# via s3transfer
|
||||
burr==0.17.1
|
||||
# via scrapegraphai
|
||||
cachetools==5.3.3
|
||||
# via google-auth
|
||||
# via streamlit
|
||||
certifi==2024.2.2
|
||||
# via httpcore
|
||||
# via httpx
|
||||
# via requests
|
||||
charset-normalizer==3.3.2
|
||||
# via requests
|
||||
click==8.1.7
|
||||
# via burr
|
||||
# via streamlit
|
||||
# via typer
|
||||
# via uvicorn
|
||||
colorama==0.4.6
|
||||
# via click
|
||||
# via loguru
|
||||
# via pytest
|
||||
# via tqdm
|
||||
# via uvicorn
|
||||
contourpy==1.2.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
dataclasses-json==0.6.6
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
@ -52,13 +80,26 @@ distro==1.9.0
|
||||
# via anthropic
|
||||
# via groq
|
||||
# via openai
|
||||
dnspython==2.6.1
|
||||
# via email-validator
|
||||
email-validator==2.1.1
|
||||
# via fastapi
|
||||
exceptiongroup==1.2.1
|
||||
# via anyio
|
||||
# via pytest
|
||||
faiss-cpu==1.8.0
|
||||
# via scrapegraphai
|
||||
fastapi==0.111.0
|
||||
# via burr
|
||||
# via fastapi-pagination
|
||||
fastapi-cli==0.0.4
|
||||
# via fastapi
|
||||
fastapi-pagination==0.12.24
|
||||
# via burr
|
||||
filelock==3.14.0
|
||||
# via huggingface-hub
|
||||
fonttools==4.51.0
|
||||
# via matplotlib
|
||||
free-proxy==1.1.1
|
||||
# via scrapegraphai
|
||||
frozenlist==1.4.1
|
||||
@ -66,6 +107,10 @@ frozenlist==1.4.1
|
||||
# via aiosignal
|
||||
fsspec==2024.3.1
|
||||
# via huggingface-hub
|
||||
gitdb==4.0.11
|
||||
# via gitpython
|
||||
gitpython==3.1.43
|
||||
# via streamlit
|
||||
google==3.0.0
|
||||
# via scrapegraphai
|
||||
google-ai-generativelanguage==0.6.3
|
||||
@ -90,6 +135,7 @@ googleapis-common-protos==1.63.0
|
||||
# via google-api-core
|
||||
# via grpcio-status
|
||||
graphviz==0.20.3
|
||||
# via burr
|
||||
# via scrapegraphai
|
||||
greenlet==3.0.3
|
||||
# via playwright
|
||||
@ -103,6 +149,7 @@ grpcio-status==1.62.2
|
||||
# via google-api-core
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
# via uvicorn
|
||||
html2text==2024.2.26
|
||||
# via scrapegraphai
|
||||
httpcore==1.0.5
|
||||
@ -110,8 +157,11 @@ httpcore==1.0.5
|
||||
httplib2==0.22.0
|
||||
# via google-api-python-client
|
||||
# via google-auth-httplib2
|
||||
httptools==0.6.1
|
||||
# via uvicorn
|
||||
httpx==0.27.0
|
||||
# via anthropic
|
||||
# via fastapi
|
||||
# via groq
|
||||
# via openai
|
||||
# via yahoo-search-py
|
||||
@ -119,11 +169,17 @@ huggingface-hub==0.23.0
|
||||
# via tokenizers
|
||||
idna==3.7
|
||||
# via anyio
|
||||
# via email-validator
|
||||
# via httpx
|
||||
# via requests
|
||||
# via yarl
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
jinja2==3.1.4
|
||||
# via altair
|
||||
# via burr
|
||||
# via fastapi
|
||||
# via pydeck
|
||||
jmespath==1.0.1
|
||||
# via boto3
|
||||
# via botocore
|
||||
@ -132,6 +188,12 @@ jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpointer==2.4
|
||||
# via jsonpatch
|
||||
jsonschema==4.22.0
|
||||
# via altair
|
||||
jsonschema-specifications==2023.12.1
|
||||
# via jsonschema
|
||||
kiwisolver==1.4.5
|
||||
# via matplotlib
|
||||
langchain==0.1.15
|
||||
# via scrapegraphai
|
||||
langchain-anthropic==0.1.11
|
||||
@ -161,10 +223,20 @@ langsmith==0.1.57
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
loguru==0.7.2
|
||||
# via burr
|
||||
lxml==5.2.2
|
||||
# via free-proxy
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
marshmallow==3.21.2
|
||||
# via dataclasses-json
|
||||
matplotlib==3.9.0
|
||||
# via burr
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
minify-html==0.15.0
|
||||
# via scrapegraphai
|
||||
multidict==6.0.5
|
||||
@ -173,22 +245,40 @@ multidict==6.0.5
|
||||
mypy-extensions==1.0.0
|
||||
# via typing-inspect
|
||||
numpy==1.26.4
|
||||
# via altair
|
||||
# via contourpy
|
||||
# via faiss-cpu
|
||||
# via langchain
|
||||
# via langchain-aws
|
||||
# via langchain-community
|
||||
# via matplotlib
|
||||
# via pandas
|
||||
# via pyarrow
|
||||
# via pydeck
|
||||
# via sf-hamilton
|
||||
# via streamlit
|
||||
openai==1.30.1
|
||||
# via burr
|
||||
# via langchain-openai
|
||||
orjson==3.10.3
|
||||
# via fastapi
|
||||
# via langsmith
|
||||
packaging==23.2
|
||||
# via altair
|
||||
# via huggingface-hub
|
||||
# via langchain-core
|
||||
# via marshmallow
|
||||
# via matplotlib
|
||||
# via pytest
|
||||
# via streamlit
|
||||
pandas==2.2.2
|
||||
# via altair
|
||||
# via scrapegraphai
|
||||
# via sf-hamilton
|
||||
# via streamlit
|
||||
pillow==10.3.0
|
||||
# via matplotlib
|
||||
# via streamlit
|
||||
playwright==1.43.0
|
||||
# via scrapegraphai
|
||||
pluggy==1.5.0
|
||||
@ -203,6 +293,9 @@ protobuf==4.25.3
|
||||
# via googleapis-common-protos
|
||||
# via grpcio-status
|
||||
# via proto-plus
|
||||
# via streamlit
|
||||
pyarrow==16.1.0
|
||||
# via streamlit
|
||||
pyasn1==0.6.0
|
||||
# via pyasn1-modules
|
||||
# via rsa
|
||||
@ -210,6 +303,9 @@ pyasn1-modules==0.4.0
|
||||
# via google-auth
|
||||
pydantic==2.7.1
|
||||
# via anthropic
|
||||
# via burr
|
||||
# via fastapi
|
||||
# via fastapi-pagination
|
||||
# via google-generativeai
|
||||
# via groq
|
||||
# via langchain
|
||||
@ -219,18 +315,27 @@ pydantic==2.7.1
|
||||
# via yahoo-search-py
|
||||
pydantic-core==2.18.2
|
||||
# via pydantic
|
||||
pydeck==0.9.1
|
||||
# via streamlit
|
||||
pyee==11.1.0
|
||||
# via playwright
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyparsing==3.1.2
|
||||
# via httplib2
|
||||
# via matplotlib
|
||||
pytest==8.0.0
|
||||
# via pytest-mock
|
||||
pytest-mock==3.14.0
|
||||
python-dateutil==2.9.0.post0
|
||||
# via botocore
|
||||
# via matplotlib
|
||||
# via pandas
|
||||
python-dotenv==1.0.1
|
||||
# via scrapegraphai
|
||||
# via uvicorn
|
||||
python-multipart==0.0.9
|
||||
# via fastapi
|
||||
pytz==2024.1
|
||||
# via pandas
|
||||
pyyaml==6.0.1
|
||||
@ -238,24 +343,42 @@ pyyaml==6.0.1
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via uvicorn
|
||||
referencing==0.35.1
|
||||
# via jsonschema
|
||||
# via jsonschema-specifications
|
||||
regex==2024.5.10
|
||||
# via tiktoken
|
||||
requests==2.31.0
|
||||
# via burr
|
||||
# via free-proxy
|
||||
# via google-api-core
|
||||
# via huggingface-hub
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langsmith
|
||||
# via streamlit
|
||||
# via tiktoken
|
||||
rich==13.7.1
|
||||
# via streamlit
|
||||
# via typer
|
||||
rpds-py==0.18.1
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
rsa==4.9
|
||||
# via google-auth
|
||||
s3transfer==0.10.1
|
||||
# via boto3
|
||||
selectolax==0.3.21
|
||||
# via yahoo-search-py
|
||||
sf-hamilton==1.62.0
|
||||
# via burr
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
smmap==5.0.1
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anthropic
|
||||
# via anyio
|
||||
@ -267,25 +390,41 @@ soupsieve==2.5
|
||||
sqlalchemy==2.0.30
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
starlette==0.37.2
|
||||
# via fastapi
|
||||
streamlit==1.34.0
|
||||
# via burr
|
||||
tenacity==8.3.0
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via streamlit
|
||||
tiktoken==0.6.0
|
||||
# via langchain-openai
|
||||
# via scrapegraphai
|
||||
tokenizers==0.19.1
|
||||
# via anthropic
|
||||
toml==0.10.2
|
||||
# via streamlit
|
||||
tomli==2.0.1
|
||||
# via pytest
|
||||
toolz==0.12.1
|
||||
# via altair
|
||||
tornado==6.4
|
||||
# via streamlit
|
||||
tqdm==4.66.4
|
||||
# via google-generativeai
|
||||
# via huggingface-hub
|
||||
# via openai
|
||||
# via scrapegraphai
|
||||
typer==0.12.3
|
||||
# via fastapi-cli
|
||||
typing-extensions==4.11.0
|
||||
# via altair
|
||||
# via anthropic
|
||||
# via anyio
|
||||
# via fastapi
|
||||
# via fastapi-pagination
|
||||
# via google-generativeai
|
||||
# via groq
|
||||
# via huggingface-hub
|
||||
@ -293,18 +432,36 @@ typing-extensions==4.11.0
|
||||
# via pydantic
|
||||
# via pydantic-core
|
||||
# via pyee
|
||||
# via sf-hamilton
|
||||
# via sqlalchemy
|
||||
# via streamlit
|
||||
# via typer
|
||||
# via typing-inspect
|
||||
# via uvicorn
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
# via sf-hamilton
|
||||
tzdata==2024.1
|
||||
# via pandas
|
||||
ujson==5.10.0
|
||||
# via fastapi
|
||||
uritemplate==4.1.1
|
||||
# via google-api-python-client
|
||||
urllib3==1.26.18
|
||||
# via botocore
|
||||
# via requests
|
||||
# via yahoo-search-py
|
||||
uvicorn==0.29.0
|
||||
# via burr
|
||||
# via fastapi
|
||||
watchdog==4.0.0
|
||||
# via streamlit
|
||||
watchfiles==0.21.0
|
||||
# via uvicorn
|
||||
websockets==12.0
|
||||
# via uvicorn
|
||||
win32-setctime==1.1.0
|
||||
# via loguru
|
||||
yahoo-search-py==0.3
|
||||
# via scrapegraphai
|
||||
yarl==1.9.4
|
||||
|
||||
@ -8,11 +8,15 @@
|
||||
# with-sources: false
|
||||
|
||||
-e file:.
|
||||
aiofiles==23.2.1
|
||||
# via burr
|
||||
aiohttp==3.9.5
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
altair==5.3.0
|
||||
# via streamlit
|
||||
annotated-types==0.6.0
|
||||
# via pydantic
|
||||
anthropic==0.25.9
|
||||
@ -22,27 +26,50 @@ anyio==4.3.0
|
||||
# via groq
|
||||
# via httpx
|
||||
# via openai
|
||||
# via starlette
|
||||
# via watchfiles
|
||||
async-timeout==4.0.3
|
||||
# via aiohttp
|
||||
# via langchain
|
||||
attrs==23.2.0
|
||||
# via aiohttp
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
beautifulsoup4==4.12.3
|
||||
# via google
|
||||
# via scrapegraphai
|
||||
blinker==1.8.2
|
||||
# via streamlit
|
||||
boto3==1.34.105
|
||||
# via langchain-aws
|
||||
botocore==1.34.105
|
||||
# via boto3
|
||||
# via s3transfer
|
||||
burr==0.17.1
|
||||
# via scrapegraphai
|
||||
cachetools==5.3.3
|
||||
# via google-auth
|
||||
# via streamlit
|
||||
certifi==2024.2.2
|
||||
# via httpcore
|
||||
# via httpx
|
||||
# via requests
|
||||
charset-normalizer==3.3.2
|
||||
# via requests
|
||||
click==8.1.7
|
||||
# via burr
|
||||
# via streamlit
|
||||
# via typer
|
||||
# via uvicorn
|
||||
colorama==0.4.6
|
||||
# via click
|
||||
# via loguru
|
||||
# via tqdm
|
||||
# via uvicorn
|
||||
contourpy==1.2.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
dataclasses-json==0.6.6
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
@ -52,12 +79,25 @@ distro==1.9.0
|
||||
# via anthropic
|
||||
# via groq
|
||||
# via openai
|
||||
dnspython==2.6.1
|
||||
# via email-validator
|
||||
email-validator==2.1.1
|
||||
# via fastapi
|
||||
exceptiongroup==1.2.1
|
||||
# via anyio
|
||||
faiss-cpu==1.8.0
|
||||
# via scrapegraphai
|
||||
fastapi==0.111.0
|
||||
# via burr
|
||||
# via fastapi-pagination
|
||||
fastapi-cli==0.0.4
|
||||
# via fastapi
|
||||
fastapi-pagination==0.12.24
|
||||
# via burr
|
||||
filelock==3.14.0
|
||||
# via huggingface-hub
|
||||
fonttools==4.51.0
|
||||
# via matplotlib
|
||||
free-proxy==1.1.1
|
||||
# via scrapegraphai
|
||||
frozenlist==1.4.1
|
||||
@ -65,6 +105,10 @@ frozenlist==1.4.1
|
||||
# via aiosignal
|
||||
fsspec==2024.3.1
|
||||
# via huggingface-hub
|
||||
gitdb==4.0.11
|
||||
# via gitpython
|
||||
gitpython==3.1.43
|
||||
# via streamlit
|
||||
google==3.0.0
|
||||
# via scrapegraphai
|
||||
google-ai-generativelanguage==0.6.3
|
||||
@ -89,6 +133,7 @@ googleapis-common-protos==1.63.0
|
||||
# via google-api-core
|
||||
# via grpcio-status
|
||||
graphviz==0.20.3
|
||||
# via burr
|
||||
# via scrapegraphai
|
||||
greenlet==3.0.3
|
||||
# via playwright
|
||||
@ -102,6 +147,7 @@ grpcio-status==1.62.2
|
||||
# via google-api-core
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
# via uvicorn
|
||||
html2text==2024.2.26
|
||||
# via scrapegraphai
|
||||
httpcore==1.0.5
|
||||
@ -109,8 +155,11 @@ httpcore==1.0.5
|
||||
httplib2==0.22.0
|
||||
# via google-api-python-client
|
||||
# via google-auth-httplib2
|
||||
httptools==0.6.1
|
||||
# via uvicorn
|
||||
httpx==0.27.0
|
||||
# via anthropic
|
||||
# via fastapi
|
||||
# via groq
|
||||
# via openai
|
||||
# via yahoo-search-py
|
||||
@ -118,9 +167,15 @@ huggingface-hub==0.23.0
|
||||
# via tokenizers
|
||||
idna==3.7
|
||||
# via anyio
|
||||
# via email-validator
|
||||
# via httpx
|
||||
# via requests
|
||||
# via yarl
|
||||
jinja2==3.1.4
|
||||
# via altair
|
||||
# via burr
|
||||
# via fastapi
|
||||
# via pydeck
|
||||
jmespath==1.0.1
|
||||
# via boto3
|
||||
# via botocore
|
||||
@ -129,6 +184,12 @@ jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpointer==2.4
|
||||
# via jsonpatch
|
||||
jsonschema==4.22.0
|
||||
# via altair
|
||||
jsonschema-specifications==2023.12.1
|
||||
# via jsonschema
|
||||
kiwisolver==1.4.5
|
||||
# via matplotlib
|
||||
langchain==0.1.15
|
||||
# via scrapegraphai
|
||||
langchain-anthropic==0.1.11
|
||||
@ -158,10 +219,20 @@ langsmith==0.1.57
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
loguru==0.7.2
|
||||
# via burr
|
||||
lxml==5.2.2
|
||||
# via free-proxy
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
marshmallow==3.21.2
|
||||
# via dataclasses-json
|
||||
matplotlib==3.9.0
|
||||
# via burr
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
minify-html==0.15.0
|
||||
# via scrapegraphai
|
||||
multidict==6.0.5
|
||||
@ -170,21 +241,39 @@ multidict==6.0.5
|
||||
mypy-extensions==1.0.0
|
||||
# via typing-inspect
|
||||
numpy==1.26.4
|
||||
# via altair
|
||||
# via contourpy
|
||||
# via faiss-cpu
|
||||
# via langchain
|
||||
# via langchain-aws
|
||||
# via langchain-community
|
||||
# via matplotlib
|
||||
# via pandas
|
||||
# via pyarrow
|
||||
# via pydeck
|
||||
# via sf-hamilton
|
||||
# via streamlit
|
||||
openai==1.30.1
|
||||
# via burr
|
||||
# via langchain-openai
|
||||
orjson==3.10.3
|
||||
# via fastapi
|
||||
# via langsmith
|
||||
packaging==23.2
|
||||
# via altair
|
||||
# via huggingface-hub
|
||||
# via langchain-core
|
||||
# via marshmallow
|
||||
# via matplotlib
|
||||
# via streamlit
|
||||
pandas==2.2.2
|
||||
# via altair
|
||||
# via scrapegraphai
|
||||
# via sf-hamilton
|
||||
# via streamlit
|
||||
pillow==10.3.0
|
||||
# via matplotlib
|
||||
# via streamlit
|
||||
playwright==1.43.0
|
||||
# via scrapegraphai
|
||||
proto-plus==1.23.0
|
||||
@ -197,6 +286,9 @@ protobuf==4.25.3
|
||||
# via googleapis-common-protos
|
||||
# via grpcio-status
|
||||
# via proto-plus
|
||||
# via streamlit
|
||||
pyarrow==16.1.0
|
||||
# via streamlit
|
||||
pyasn1==0.6.0
|
||||
# via pyasn1-modules
|
||||
# via rsa
|
||||
@ -204,6 +296,9 @@ pyasn1-modules==0.4.0
|
||||
# via google-auth
|
||||
pydantic==2.7.1
|
||||
# via anthropic
|
||||
# via burr
|
||||
# via fastapi
|
||||
# via fastapi-pagination
|
||||
# via google-generativeai
|
||||
# via groq
|
||||
# via langchain
|
||||
@ -213,15 +308,24 @@ pydantic==2.7.1
|
||||
# via yahoo-search-py
|
||||
pydantic-core==2.18.2
|
||||
# via pydantic
|
||||
pydeck==0.9.1
|
||||
# via streamlit
|
||||
pyee==11.1.0
|
||||
# via playwright
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyparsing==3.1.2
|
||||
# via httplib2
|
||||
# via matplotlib
|
||||
python-dateutil==2.9.0.post0
|
||||
# via botocore
|
||||
# via matplotlib
|
||||
# via pandas
|
||||
python-dotenv==1.0.1
|
||||
# via scrapegraphai
|
||||
# via uvicorn
|
||||
python-multipart==0.0.9
|
||||
# via fastapi
|
||||
pytz==2024.1
|
||||
# via pandas
|
||||
pyyaml==6.0.1
|
||||
@ -229,24 +333,42 @@ pyyaml==6.0.1
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via uvicorn
|
||||
referencing==0.35.1
|
||||
# via jsonschema
|
||||
# via jsonschema-specifications
|
||||
regex==2024.5.10
|
||||
# via tiktoken
|
||||
requests==2.31.0
|
||||
# via burr
|
||||
# via free-proxy
|
||||
# via google-api-core
|
||||
# via huggingface-hub
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langsmith
|
||||
# via streamlit
|
||||
# via tiktoken
|
||||
rich==13.7.1
|
||||
# via streamlit
|
||||
# via typer
|
||||
rpds-py==0.18.1
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
rsa==4.9
|
||||
# via google-auth
|
||||
s3transfer==0.10.1
|
||||
# via boto3
|
||||
selectolax==0.3.21
|
||||
# via yahoo-search-py
|
||||
sf-hamilton==1.62.0
|
||||
# via burr
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
smmap==5.0.1
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via anthropic
|
||||
# via anyio
|
||||
@ -258,23 +380,39 @@ soupsieve==2.5
|
||||
sqlalchemy==2.0.30
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
starlette==0.37.2
|
||||
# via fastapi
|
||||
streamlit==1.34.0
|
||||
# via burr
|
||||
tenacity==8.3.0
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via streamlit
|
||||
tiktoken==0.6.0
|
||||
# via langchain-openai
|
||||
# via scrapegraphai
|
||||
tokenizers==0.19.1
|
||||
# via anthropic
|
||||
toml==0.10.2
|
||||
# via streamlit
|
||||
toolz==0.12.1
|
||||
# via altair
|
||||
tornado==6.4
|
||||
# via streamlit
|
||||
tqdm==4.66.4
|
||||
# via google-generativeai
|
||||
# via huggingface-hub
|
||||
# via openai
|
||||
# via scrapegraphai
|
||||
typer==0.12.3
|
||||
# via fastapi-cli
|
||||
typing-extensions==4.11.0
|
||||
# via altair
|
||||
# via anthropic
|
||||
# via anyio
|
||||
# via fastapi
|
||||
# via fastapi-pagination
|
||||
# via google-generativeai
|
||||
# via groq
|
||||
# via huggingface-hub
|
||||
@ -282,18 +420,36 @@ typing-extensions==4.11.0
|
||||
# via pydantic
|
||||
# via pydantic-core
|
||||
# via pyee
|
||||
# via sf-hamilton
|
||||
# via sqlalchemy
|
||||
# via streamlit
|
||||
# via typer
|
||||
# via typing-inspect
|
||||
# via uvicorn
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
# via sf-hamilton
|
||||
tzdata==2024.1
|
||||
# via pandas
|
||||
ujson==5.10.0
|
||||
# via fastapi
|
||||
uritemplate==4.1.1
|
||||
# via google-api-python-client
|
||||
urllib3==1.26.18
|
||||
# via botocore
|
||||
# via requests
|
||||
# via yahoo-search-py
|
||||
uvicorn==0.29.0
|
||||
# via burr
|
||||
# via fastapi
|
||||
watchdog==4.0.0
|
||||
# via streamlit
|
||||
watchfiles==0.21.0
|
||||
# via uvicorn
|
||||
websockets==12.0
|
||||
# via uvicorn
|
||||
win32-setctime==1.1.0
|
||||
# via loguru
|
||||
yahoo-search-py==0.3
|
||||
# via scrapegraphai
|
||||
yarl==1.9.4
|
||||
|
||||
@ -15,4 +15,3 @@ from .csv_scraper_graph import CSVScraperGraph
|
||||
from .pdf_scraper_graph import PDFScraperGraph
|
||||
from .omni_scraper_graph import OmniScraperGraph
|
||||
from .omni_search_graph import OmniSearchGraph
|
||||
from .turbo_scraper import TurboScraperGraph
|
||||
|
||||
@ -7,6 +7,8 @@ import warnings
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from typing import Tuple
|
||||
|
||||
from ..integrations import BurrBridge
|
||||
|
||||
|
||||
class BaseGraph:
|
||||
"""
|
||||
@ -40,20 +42,27 @@ class BaseGraph:
|
||||
... (parse_node, rag_node),
|
||||
... (rag_node, generate_answer_node)
|
||||
... ],
|
||||
... entry_point=fetch_node
|
||||
... entry_point=fetch_node,
|
||||
... use_burr=True,
|
||||
... burr_config={"app_instance_id": "example-instance"}
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, nodes: list, edges: list, entry_point: str):
|
||||
def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool = False, burr_config: dict = None):
|
||||
|
||||
self.nodes = nodes
|
||||
self.edges = self._create_edges({e for e in edges})
|
||||
self.entry_point = entry_point.node_name
|
||||
self.initial_state = {}
|
||||
|
||||
if nodes[0].node_name != entry_point.node_name:
|
||||
# raise a warning if the entry point is not the first node in the list
|
||||
warnings.warn(
|
||||
"Careful! The entry point node is different from the first node if the graph.")
|
||||
|
||||
# Burr configuration
|
||||
self.use_burr = use_burr
|
||||
self.burr_config = burr_config or {}
|
||||
|
||||
def _create_edges(self, edges: list) -> dict:
|
||||
"""
|
||||
@ -71,11 +80,9 @@ class BaseGraph:
|
||||
edge_dict[from_node.node_name] = to_node.node_name
|
||||
return edge_dict
|
||||
|
||||
def execute(self, initial_state: dict) -> Tuple[dict, list]:
|
||||
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
|
||||
"""
|
||||
Executes the graph by traversing nodes starting from the entry point. The execution
|
||||
follows the edges based on the result of each node's execution and continues until
|
||||
it reaches a node with no outgoing edges.
|
||||
Executes the graph by traversing nodes starting from the entry point using the standard method.
|
||||
|
||||
Args:
|
||||
initial_state (dict): The initial state to pass to the entry point node.
|
||||
@ -83,8 +90,7 @@ class BaseGraph:
|
||||
Returns:
|
||||
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
|
||||
"""
|
||||
|
||||
current_node_name = self.nodes[0]
|
||||
current_node_name = self.entry_point
|
||||
state = initial_state
|
||||
|
||||
# variables for tracking execution info
|
||||
@ -98,18 +104,17 @@ class BaseGraph:
|
||||
"total_cost_USD": 0.0,
|
||||
}
|
||||
|
||||
for index in self.nodes:
|
||||
|
||||
while current_node_name:
|
||||
curr_time = time.time()
|
||||
current_node = index
|
||||
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
result = current_node.execute(state)
|
||||
node_exec_time = time.time() - curr_time
|
||||
total_exec_time += node_exec_time
|
||||
|
||||
cb = {
|
||||
"node_name": index.node_name,
|
||||
cb_data = {
|
||||
"node_name": current_node.node_name,
|
||||
"total_tokens": cb.total_tokens,
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
@ -118,15 +123,13 @@ class BaseGraph:
|
||||
"exec_time": node_exec_time,
|
||||
}
|
||||
|
||||
exec_info.append(
|
||||
cb
|
||||
)
|
||||
exec_info.append(cb_data)
|
||||
|
||||
cb_total["total_tokens"] += cb["total_tokens"]
|
||||
cb_total["prompt_tokens"] += cb["prompt_tokens"]
|
||||
cb_total["completion_tokens"] += cb["completion_tokens"]
|
||||
cb_total["successful_requests"] += cb["successful_requests"]
|
||||
cb_total["total_cost_USD"] += cb["total_cost_USD"]
|
||||
cb_total["total_tokens"] += cb_data["total_tokens"]
|
||||
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
|
||||
cb_total["completion_tokens"] += cb_data["completion_tokens"]
|
||||
cb_total["successful_requests"] += cb_data["successful_requests"]
|
||||
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
|
||||
|
||||
if current_node.node_type == "conditional_node":
|
||||
current_node_name = result
|
||||
@ -137,12 +140,30 @@ class BaseGraph:
|
||||
|
||||
exec_info.append({
|
||||
"node_name": "TOTAL RESULT",
|
||||
"total_tokens": cb_total["total_tokens"],
|
||||
"prompt_tokens": cb_total["prompt_tokens"],
|
||||
"total_tokens": cb_total["total_tokens"],
|
||||
"prompt_tokens": cb_total["prompt_tokens"],
|
||||
"completion_tokens": cb_total["completion_tokens"],
|
||||
"successful_requests": cb_total["successful_requests"],
|
||||
"total_cost_USD": cb_total["total_cost_USD"],
|
||||
"total_cost_USD": cb_total["total_cost_USD"],
|
||||
"exec_time": total_exec_time,
|
||||
})
|
||||
|
||||
return state, exec_info
|
||||
|
||||
def execute(self, initial_state: dict) -> Tuple[dict, list]:
|
||||
"""
|
||||
Executes the graph by either using BurrBridge or the standard method.
|
||||
|
||||
Args:
|
||||
initial_state (dict): The initial state to pass to the entry point node.
|
||||
|
||||
Returns:
|
||||
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
|
||||
"""
|
||||
|
||||
self.initial_state = initial_state
|
||||
if self.use_burr:
|
||||
bridge = BurrBridge(self, self.burr_config)
|
||||
return bridge.execute(initial_state)
|
||||
else:
|
||||
return self._execute_standard(initial_state)
|
||||
@ -25,7 +25,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
if __name__ == '__main__':
|
||||
from scrapegraphai.utils.remover import remover
|
||||
from scrapegraphai.utils import cleanup_html
|
||||
else:
|
||||
from ..utils.remover import remover
|
||||
|
||||
|
||||
@ -1,146 +0,0 @@
|
||||
"""
|
||||
SmartScraperGraph Module
|
||||
"""
|
||||
|
||||
from .base_graph import BaseGraph
|
||||
from ..nodes import (
|
||||
FetchNode,
|
||||
ParseNode,
|
||||
RAGNode,
|
||||
SearchLinksWithContext,
|
||||
GraphIteratorNode,
|
||||
MergeAnswersNode
|
||||
)
|
||||
from .search_graph import SearchGraph
|
||||
from .abstract_graph import AbstractGraph
|
||||
|
||||
|
||||
class SmartScraperGraph(AbstractGraph):
|
||||
"""
|
||||
SmartScraper is a scraping pipeline that automates the process of
|
||||
extracting information from web pages
|
||||
using a natural language model to interpret and answer prompts.
|
||||
|
||||
Attributes:
|
||||
prompt (str): The prompt for the graph.
|
||||
source (str): The source of the graph.
|
||||
config (dict): Configuration parameters for the graph.
|
||||
llm_model: An instance of a language model client, configured for generating answers.
|
||||
embedder_model: An instance of an embedding model client,
|
||||
configured for generating embeddings.
|
||||
verbose (bool): A flag indicating whether to show print statements during execution.
|
||||
headless (bool): A flag indicating whether to run the graph in headless mode.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt for the graph.
|
||||
source (str): The source of the graph.
|
||||
config (dict): Configuration parameters for the graph.
|
||||
|
||||
Example:
|
||||
>>> smart_scraper = SmartScraperGraph(
|
||||
... "List me all the attractions in Chioggia.",
|
||||
... "https://en.wikipedia.org/wiki/Chioggia",
|
||||
... {"llm": {"model": "gpt-3.5-turbo"}}
|
||||
... )
|
||||
>>> result = smart_scraper.run()
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, source: str, config: dict):
|
||||
super().__init__(prompt, config, source)
|
||||
|
||||
self.input_key = "url" if source.startswith("http") else "local_dir"
|
||||
|
||||
def _create_graph(self) -> BaseGraph:
|
||||
"""
|
||||
Creates the graph of nodes representing the workflow for web scraping.
|
||||
|
||||
Returns:
|
||||
BaseGraph: A graph instance representing the web scraping workflow.
|
||||
"""
|
||||
smart_scraper_graph = SmartScraperGraph(
|
||||
prompt="",
|
||||
source="",
|
||||
config=self.llm_model
|
||||
)
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"]
|
||||
)
|
||||
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": self.model_token
|
||||
}
|
||||
)
|
||||
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model
|
||||
}
|
||||
)
|
||||
|
||||
search_link_with_context_node = SearchLinksWithContext(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model
|
||||
}
|
||||
)
|
||||
|
||||
graph_iterator_node = GraphIteratorNode(
|
||||
input="user_prompt & urls",
|
||||
output=["results"],
|
||||
node_config={
|
||||
"graph_instance": smart_scraper_graph,
|
||||
"verbose": True,
|
||||
}
|
||||
)
|
||||
|
||||
merge_answers_node = MergeAnswersNode(
|
||||
input="user_prompt & results",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"verbose": True,
|
||||
}
|
||||
)
|
||||
|
||||
return BaseGraph(
|
||||
nodes=[
|
||||
fetch_node,
|
||||
parse_node,
|
||||
rag_node,
|
||||
search_link_with_context_node,
|
||||
graph_iterator_node,
|
||||
merge_answers_node
|
||||
|
||||
],
|
||||
edges=[
|
||||
(fetch_node, parse_node),
|
||||
(parse_node, rag_node),
|
||||
(rag_node, search_link_with_context_node),
|
||||
(search_link_with_context_node, graph_iterator_node),
|
||||
(graph_iterator_node, merge_answers_node),
|
||||
|
||||
],
|
||||
entry_point=fetch_node
|
||||
)
|
||||
|
||||
def run(self) -> str:
|
||||
"""
|
||||
Executes the scraping process and returns the answer to the prompt.
|
||||
|
||||
Returns:
|
||||
str: The answer to the prompt.
|
||||
"""
|
||||
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
1
scrapegraphai/integrations/__init__.py
Normal file
1
scrapegraphai/integrations/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .burr_bridge import BurrBridge
|
||||
198
scrapegraphai/integrations/burr_bridge.py
Normal file
198
scrapegraphai/integrations/burr_bridge.py
Normal file
@ -0,0 +1,198 @@
|
||||
"""
|
||||
Bridge class to integrate Burr into ScrapeGraphAI graphs
|
||||
[Burr](https://github.com/DAGWorks-Inc/burr)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from burr import tracking
|
||||
from burr.core import Application, ApplicationBuilder, State, Action, default
|
||||
from burr.core.action import action
|
||||
from burr.lifecycle import PostRunStepHook, PreRunStepHook
|
||||
|
||||
|
||||
class PrintLnHook(PostRunStepHook, PreRunStepHook):
|
||||
"""
|
||||
Hook to print the action name before and after it is executed.
|
||||
"""
|
||||
|
||||
def pre_run_step(self, *, state: "State", action: "Action", **future_kwargs: Any):
|
||||
print(f"Starting action: {action.name}")
|
||||
|
||||
def post_run_step(self, *, state: "State", action: "Action", **future_kwargs: Any):
|
||||
print(f"Finishing action: {action.name}")
|
||||
|
||||
class BurrBridge:
|
||||
"""
|
||||
Bridge class to integrate Burr into ScrapeGraphAI graphs.
|
||||
|
||||
Args:
|
||||
base_graph (BaseGraph): The base graph to convert to a Burr application.
|
||||
burr_config (dict): Configuration parameters for the Burr application.
|
||||
|
||||
Attributes:
|
||||
base_graph (BaseGraph): The base graph to convert to a Burr application.
|
||||
burr_config (dict): Configuration parameters for the Burr application.
|
||||
tracker (LocalTrackingClient): The tracking client for the Burr application.
|
||||
app_instance_id (str): The instance ID for the Burr application.
|
||||
burr_inputs (dict): The inputs for the Burr application.
|
||||
burr_app (Application): The Burr application instance.
|
||||
|
||||
Example:
|
||||
>>> burr_bridge = BurrBridge(base_graph, burr_config)
|
||||
>>> result = burr_bridge.execute(initial_state={"input_key": "input_value"})
|
||||
"""
|
||||
|
||||
def __init__(self, base_graph, burr_config):
|
||||
self.base_graph = base_graph
|
||||
self.burr_config = burr_config
|
||||
self.tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
|
||||
self.app_instance_id = burr_config.get("app_instance_id", "default-instance")
|
||||
self.burr_inputs = burr_config.get("inputs", {})
|
||||
self.burr_app = None
|
||||
|
||||
def _initialize_burr_app(self, initial_state: Dict[str, Any] = {}) -> Application:
|
||||
"""
|
||||
Initialize a Burr application from the base graph.
|
||||
|
||||
Args:
|
||||
initial_state (dict): The initial state of the Burr application.
|
||||
|
||||
Returns:
|
||||
Application: The Burr application instance.
|
||||
"""
|
||||
|
||||
actions = self._create_actions()
|
||||
transitions = self._create_transitions()
|
||||
hooks = [PrintLnHook()]
|
||||
burr_state = self._convert_state_to_burr(initial_state)
|
||||
|
||||
app = (
|
||||
ApplicationBuilder()
|
||||
.with_actions(**actions)
|
||||
.with_transitions(*transitions)
|
||||
.with_entrypoint(self.base_graph.entry_point)
|
||||
.with_state(**burr_state)
|
||||
.with_identifiers(app_id=self.app_instance_id)
|
||||
.with_tracker(self.tracker)
|
||||
.with_hooks(*hooks)
|
||||
.build()
|
||||
)
|
||||
return app
|
||||
|
||||
def _create_actions(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Create Burr actions from the base graph nodes.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of Burr actions with the node name as keys and the action functions as values.
|
||||
"""
|
||||
|
||||
actions = {}
|
||||
for node in self.base_graph.nodes:
|
||||
action_func = self._create_action(node)
|
||||
actions[node.node_name] = action_func
|
||||
return actions
|
||||
|
||||
def _create_action(self, node) -> Any:
|
||||
"""
|
||||
Create a Burr action function from a base graph node.
|
||||
|
||||
Args:
|
||||
node (Node): The base graph node to convert to a Burr action.
|
||||
|
||||
Returns:
|
||||
function: The Burr action function.
|
||||
"""
|
||||
|
||||
@action(reads=self._parse_boolean_expression(node.input), writes=node.output)
|
||||
def dynamic_action(state: State, **kwargs):
|
||||
node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)}
|
||||
result_state = node.execute(node_inputs, **kwargs)
|
||||
return result_state, state.update(**result_state)
|
||||
return dynamic_action
|
||||
|
||||
def _create_transitions(self) -> List[Tuple[str, str, Any]]:
|
||||
"""
|
||||
Create Burr transitions from the base graph edges.
|
||||
|
||||
Returns:
|
||||
list: A list of tuples representing the transitions between Burr actions.
|
||||
"""
|
||||
|
||||
transitions = []
|
||||
for from_node, to_node in self.base_graph.edges.items():
|
||||
transitions.append((from_node, to_node, default))
|
||||
return transitions
|
||||
|
||||
def _parse_boolean_expression(self, expression: str) -> List[str]:
|
||||
"""
|
||||
Parse a boolean expression to extract the keys used in the expression, without boolean operators.
|
||||
|
||||
Args:
|
||||
expression (str): The boolean expression to parse.
|
||||
|
||||
Returns:
|
||||
list: A list of unique keys used in the expression.
|
||||
"""
|
||||
|
||||
# Use regular expression to extract all unique keys
|
||||
keys = re.findall(r'\w+', expression)
|
||||
return list(set(keys)) # Remove duplicates
|
||||
|
||||
def _convert_state_to_burr(self, state: Dict[str, Any]) -> State:
|
||||
"""
|
||||
Convert a dictionary state to a Burr state.
|
||||
|
||||
Args:
|
||||
state (dict): The dictionary state to convert.
|
||||
|
||||
Returns:
|
||||
State: The Burr state instance.
|
||||
"""
|
||||
|
||||
burr_state = State()
|
||||
for key, value in state.items():
|
||||
setattr(burr_state, key, value)
|
||||
return burr_state
|
||||
|
||||
def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Burr state to a dictionary state.
|
||||
|
||||
Args:
|
||||
burr_state (State): The Burr state to convert.
|
||||
|
||||
Returns:
|
||||
dict: The dictionary state instance.
|
||||
"""
|
||||
|
||||
state = {}
|
||||
for key in burr_state.__dict__.keys():
|
||||
state[key] = getattr(burr_state, key)
|
||||
return state
|
||||
|
||||
def execute(self, initial_state: Dict[str, Any] = {}) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute the Burr application with the given initial state.
|
||||
|
||||
Args:
|
||||
initial_state (dict): The initial state to pass to the Burr application.
|
||||
|
||||
Returns:
|
||||
dict: The final state of the Burr application.
|
||||
"""
|
||||
|
||||
self.burr_app = self._initialize_burr_app(initial_state)
|
||||
|
||||
# TODO: to fix final nodes detection
|
||||
final_nodes = [self.burr_app.graph.actions[-1].name]
|
||||
|
||||
# TODO: fix inputs
|
||||
last_action, result, final_state = self.burr_app.run(
|
||||
halt_after=final_nodes,
|
||||
inputs=self.burr_inputs
|
||||
)
|
||||
|
||||
return self._convert_state_from_burr(final_state)
|
||||
Loading…
Reference in New Issue
Block a user