mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
feat(knowledgegraph): add knowledge graph node
This commit is contained in:
parent
8c33ea3fbc
commit
0196423bde
3
.gitignore
vendored
3
.gitignore
vendored
@ -32,5 +32,6 @@ examples/graph_examples/ScrapeGraphAI_generated_graph
|
|||||||
examples/**/result.csv
|
examples/**/result.csv
|
||||||
examples/**/result.json
|
examples/**/result.json
|
||||||
main.py
|
main.py
|
||||||
|
lib/
|
||||||
|
*.html
|
||||||
|
|
||||||
@ -50,6 +50,7 @@ graph_config = {
|
|||||||
"model": "gpt-4o",
|
"model": "gpt-4o",
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
},
|
},
|
||||||
|
"verbose": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# ************************************************
|
# ************************************************
|
||||||
@ -59,11 +60,9 @@ graph_config = {
|
|||||||
llm_model = OpenAI(graph_config["llm"])
|
llm_model = OpenAI(graph_config["llm"])
|
||||||
|
|
||||||
robots_node = KnowledgeGraphNode(
|
robots_node = KnowledgeGraphNode(
|
||||||
input="answer & user_prompt",
|
input="user_prompt & answer_dict",
|
||||||
output=["is_scrapable"],
|
output=["is_scrapable"],
|
||||||
node_config={"llm_model": llm_model,
|
node_config={"llm_model": llm_model}
|
||||||
"headless": False
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ************************************************
|
# ************************************************
|
||||||
@ -71,7 +70,8 @@ robots_node = KnowledgeGraphNode(
|
|||||||
# ************************************************
|
# ************************************************
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
"url": "https://twitter.com/home"
|
"user_prompt": "What are the job postings?",
|
||||||
|
"answer_dict": job_postings
|
||||||
}
|
}
|
||||||
|
|
||||||
result = robots_node.execute(state)
|
result = robots_node.execute(state)
|
||||||
|
|||||||
@ -30,6 +30,8 @@ dependencies = [
|
|||||||
"playwright==1.43.0",
|
"playwright==1.43.0",
|
||||||
"google==3.0.0",
|
"google==3.0.0",
|
||||||
"yahoo-search-py==0.3",
|
"yahoo-search-py==0.3",
|
||||||
|
"networkx==3.3",
|
||||||
|
"pyvis==0.3.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|||||||
@ -22,6 +22,8 @@ anyio==4.3.0
|
|||||||
# via groq
|
# via groq
|
||||||
# via httpx
|
# via httpx
|
||||||
# via openai
|
# via openai
|
||||||
|
asttokens==2.4.1
|
||||||
|
# via stack-data
|
||||||
async-timeout==4.0.3
|
async-timeout==4.0.3
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
# via langchain
|
# via langchain
|
||||||
@ -43,9 +45,15 @@ certifi==2024.2.2
|
|||||||
# via requests
|
# via requests
|
||||||
charset-normalizer==3.3.2
|
charset-normalizer==3.3.2
|
||||||
# via requests
|
# via requests
|
||||||
|
colorama==0.4.6
|
||||||
|
# via ipython
|
||||||
|
# via pytest
|
||||||
|
# via tqdm
|
||||||
dataclasses-json==0.6.6
|
dataclasses-json==0.6.6
|
||||||
# via langchain
|
# via langchain
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
|
decorator==5.1.1
|
||||||
|
# via ipython
|
||||||
defusedxml==0.7.1
|
defusedxml==0.7.1
|
||||||
# via langchain-anthropic
|
# via langchain-anthropic
|
||||||
distro==1.9.0
|
distro==1.9.0
|
||||||
@ -54,7 +62,10 @@ distro==1.9.0
|
|||||||
# via openai
|
# via openai
|
||||||
exceptiongroup==1.2.1
|
exceptiongroup==1.2.1
|
||||||
# via anyio
|
# via anyio
|
||||||
|
# via ipython
|
||||||
# via pytest
|
# via pytest
|
||||||
|
executing==2.0.1
|
||||||
|
# via stack-data
|
||||||
faiss-cpu==1.8.0
|
faiss-cpu==1.8.0
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
filelock==3.14.0
|
filelock==3.14.0
|
||||||
@ -93,6 +104,7 @@ graphviz==0.20.3
|
|||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
greenlet==3.0.3
|
greenlet==3.0.3
|
||||||
# via playwright
|
# via playwright
|
||||||
|
# via sqlalchemy
|
||||||
groq==0.5.0
|
groq==0.5.0
|
||||||
# via langchain-groq
|
# via langchain-groq
|
||||||
grpcio==1.63.0
|
grpcio==1.63.0
|
||||||
@ -123,12 +135,20 @@ idna==3.7
|
|||||||
# via yarl
|
# via yarl
|
||||||
iniconfig==2.0.0
|
iniconfig==2.0.0
|
||||||
# via pytest
|
# via pytest
|
||||||
|
ipython==8.24.0
|
||||||
|
# via pyvis
|
||||||
|
jedi==0.19.1
|
||||||
|
# via ipython
|
||||||
|
jinja2==3.1.4
|
||||||
|
# via pyvis
|
||||||
jmespath==1.0.1
|
jmespath==1.0.1
|
||||||
# via boto3
|
# via boto3
|
||||||
# via botocore
|
# via botocore
|
||||||
jsonpatch==1.33
|
jsonpatch==1.33
|
||||||
# via langchain
|
# via langchain
|
||||||
# via langchain-core
|
# via langchain-core
|
||||||
|
jsonpickle==3.0.4
|
||||||
|
# via pyvis
|
||||||
jsonpointer==2.4
|
jsonpointer==2.4
|
||||||
# via jsonpatch
|
# via jsonpatch
|
||||||
langchain==0.1.15
|
langchain==0.1.15
|
||||||
@ -162,8 +182,12 @@ langsmith==0.1.58
|
|||||||
# via langchain-core
|
# via langchain-core
|
||||||
lxml==5.2.2
|
lxml==5.2.2
|
||||||
# via free-proxy
|
# via free-proxy
|
||||||
|
markupsafe==2.1.5
|
||||||
|
# via jinja2
|
||||||
marshmallow==3.21.2
|
marshmallow==3.21.2
|
||||||
# via dataclasses-json
|
# via dataclasses-json
|
||||||
|
matplotlib-inline==0.1.7
|
||||||
|
# via ipython
|
||||||
minify-html==0.15.0
|
minify-html==0.15.0
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
multidict==6.0.5
|
multidict==6.0.5
|
||||||
@ -171,6 +195,9 @@ multidict==6.0.5
|
|||||||
# via yarl
|
# via yarl
|
||||||
mypy-extensions==1.0.0
|
mypy-extensions==1.0.0
|
||||||
# via typing-inspect
|
# via typing-inspect
|
||||||
|
networkx==3.3
|
||||||
|
# via pyvis
|
||||||
|
# via scrapegraphai
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
# via faiss-cpu
|
# via faiss-cpu
|
||||||
# via langchain
|
# via langchain
|
||||||
@ -188,10 +215,14 @@ packaging==23.2
|
|||||||
# via pytest
|
# via pytest
|
||||||
pandas==2.2.2
|
pandas==2.2.2
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
|
parso==0.8.4
|
||||||
|
# via jedi
|
||||||
playwright==1.43.0
|
playwright==1.43.0
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
pluggy==1.5.0
|
pluggy==1.5.0
|
||||||
# via pytest
|
# via pytest
|
||||||
|
prompt-toolkit==3.0.43
|
||||||
|
# via ipython
|
||||||
proto-plus==1.23.0
|
proto-plus==1.23.0
|
||||||
# via google-ai-generativelanguage
|
# via google-ai-generativelanguage
|
||||||
# via google-api-core
|
# via google-api-core
|
||||||
@ -202,6 +233,8 @@ protobuf==4.25.3
|
|||||||
# via googleapis-common-protos
|
# via googleapis-common-protos
|
||||||
# via grpcio-status
|
# via grpcio-status
|
||||||
# via proto-plus
|
# via proto-plus
|
||||||
|
pure-eval==0.2.2
|
||||||
|
# via stack-data
|
||||||
pyasn1==0.6.0
|
pyasn1==0.6.0
|
||||||
# via pyasn1-modules
|
# via pyasn1-modules
|
||||||
# via rsa
|
# via rsa
|
||||||
@ -220,6 +253,8 @@ pydantic-core==2.18.2
|
|||||||
# via pydantic
|
# via pydantic
|
||||||
pyee==11.1.0
|
pyee==11.1.0
|
||||||
# via playwright
|
# via playwright
|
||||||
|
pygments==2.18.0
|
||||||
|
# via ipython
|
||||||
pyparsing==3.1.2
|
pyparsing==3.1.2
|
||||||
# via httplib2
|
# via httplib2
|
||||||
pytest==8.0.0
|
pytest==8.0.0
|
||||||
@ -232,6 +267,8 @@ python-dotenv==1.0.1
|
|||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
pytz==2024.1
|
pytz==2024.1
|
||||||
# via pandas
|
# via pandas
|
||||||
|
pyvis==0.3.2
|
||||||
|
# via scrapegraphai
|
||||||
pyyaml==6.0.1
|
pyyaml==6.0.1
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via langchain
|
# via langchain
|
||||||
@ -254,6 +291,7 @@ s3transfer==0.10.1
|
|||||||
selectolax==0.3.21
|
selectolax==0.3.21
|
||||||
# via yahoo-search-py
|
# via yahoo-search-py
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
|
# via asttokens
|
||||||
# via python-dateutil
|
# via python-dateutil
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
# via anthropic
|
# via anthropic
|
||||||
@ -266,6 +304,8 @@ soupsieve==2.5
|
|||||||
sqlalchemy==2.0.30
|
sqlalchemy==2.0.30
|
||||||
# via langchain
|
# via langchain
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
|
stack-data==0.6.3
|
||||||
|
# via ipython
|
||||||
tenacity==8.3.0
|
tenacity==8.3.0
|
||||||
# via langchain
|
# via langchain
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
@ -282,12 +322,16 @@ tqdm==4.66.4
|
|||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via openai
|
# via openai
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
|
traitlets==5.14.3
|
||||||
|
# via ipython
|
||||||
|
# via matplotlib-inline
|
||||||
typing-extensions==4.11.0
|
typing-extensions==4.11.0
|
||||||
# via anthropic
|
# via anthropic
|
||||||
# via anyio
|
# via anyio
|
||||||
# via google-generativeai
|
# via google-generativeai
|
||||||
# via groq
|
# via groq
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
|
# via ipython
|
||||||
# via openai
|
# via openai
|
||||||
# via pydantic
|
# via pydantic
|
||||||
# via pydantic-core
|
# via pydantic-core
|
||||||
@ -304,6 +348,8 @@ urllib3==2.2.1
|
|||||||
# via botocore
|
# via botocore
|
||||||
# via requests
|
# via requests
|
||||||
# via yahoo-search-py
|
# via yahoo-search-py
|
||||||
|
wcwidth==0.2.13
|
||||||
|
# via prompt-toolkit
|
||||||
yahoo-search-py==0.3
|
yahoo-search-py==0.3
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
yarl==1.9.4
|
yarl==1.9.4
|
||||||
|
|||||||
@ -22,6 +22,8 @@ anyio==4.3.0
|
|||||||
# via groq
|
# via groq
|
||||||
# via httpx
|
# via httpx
|
||||||
# via openai
|
# via openai
|
||||||
|
asttokens==2.4.1
|
||||||
|
# via stack-data
|
||||||
async-timeout==4.0.3
|
async-timeout==4.0.3
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
# via langchain
|
# via langchain
|
||||||
@ -43,9 +45,14 @@ certifi==2024.2.2
|
|||||||
# via requests
|
# via requests
|
||||||
charset-normalizer==3.3.2
|
charset-normalizer==3.3.2
|
||||||
# via requests
|
# via requests
|
||||||
|
colorama==0.4.6
|
||||||
|
# via ipython
|
||||||
|
# via tqdm
|
||||||
dataclasses-json==0.6.6
|
dataclasses-json==0.6.6
|
||||||
# via langchain
|
# via langchain
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
|
decorator==5.1.1
|
||||||
|
# via ipython
|
||||||
defusedxml==0.7.1
|
defusedxml==0.7.1
|
||||||
# via langchain-anthropic
|
# via langchain-anthropic
|
||||||
distro==1.9.0
|
distro==1.9.0
|
||||||
@ -54,6 +61,9 @@ distro==1.9.0
|
|||||||
# via openai
|
# via openai
|
||||||
exceptiongroup==1.2.1
|
exceptiongroup==1.2.1
|
||||||
# via anyio
|
# via anyio
|
||||||
|
# via ipython
|
||||||
|
executing==2.0.1
|
||||||
|
# via stack-data
|
||||||
faiss-cpu==1.8.0
|
faiss-cpu==1.8.0
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
filelock==3.14.0
|
filelock==3.14.0
|
||||||
@ -92,6 +102,7 @@ graphviz==0.20.3
|
|||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
greenlet==3.0.3
|
greenlet==3.0.3
|
||||||
# via playwright
|
# via playwright
|
||||||
|
# via sqlalchemy
|
||||||
groq==0.5.0
|
groq==0.5.0
|
||||||
# via langchain-groq
|
# via langchain-groq
|
||||||
grpcio==1.63.0
|
grpcio==1.63.0
|
||||||
@ -120,12 +131,20 @@ idna==3.7
|
|||||||
# via httpx
|
# via httpx
|
||||||
# via requests
|
# via requests
|
||||||
# via yarl
|
# via yarl
|
||||||
|
ipython==8.24.0
|
||||||
|
# via pyvis
|
||||||
|
jedi==0.19.1
|
||||||
|
# via ipython
|
||||||
|
jinja2==3.1.4
|
||||||
|
# via pyvis
|
||||||
jmespath==1.0.1
|
jmespath==1.0.1
|
||||||
# via boto3
|
# via boto3
|
||||||
# via botocore
|
# via botocore
|
||||||
jsonpatch==1.33
|
jsonpatch==1.33
|
||||||
# via langchain
|
# via langchain
|
||||||
# via langchain-core
|
# via langchain-core
|
||||||
|
jsonpickle==3.0.4
|
||||||
|
# via pyvis
|
||||||
jsonpointer==2.4
|
jsonpointer==2.4
|
||||||
# via jsonpatch
|
# via jsonpatch
|
||||||
langchain==0.1.15
|
langchain==0.1.15
|
||||||
@ -159,8 +178,12 @@ langsmith==0.1.58
|
|||||||
# via langchain-core
|
# via langchain-core
|
||||||
lxml==5.2.2
|
lxml==5.2.2
|
||||||
# via free-proxy
|
# via free-proxy
|
||||||
|
markupsafe==2.1.5
|
||||||
|
# via jinja2
|
||||||
marshmallow==3.21.2
|
marshmallow==3.21.2
|
||||||
# via dataclasses-json
|
# via dataclasses-json
|
||||||
|
matplotlib-inline==0.1.7
|
||||||
|
# via ipython
|
||||||
minify-html==0.15.0
|
minify-html==0.15.0
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
multidict==6.0.5
|
multidict==6.0.5
|
||||||
@ -168,6 +191,9 @@ multidict==6.0.5
|
|||||||
# via yarl
|
# via yarl
|
||||||
mypy-extensions==1.0.0
|
mypy-extensions==1.0.0
|
||||||
# via typing-inspect
|
# via typing-inspect
|
||||||
|
networkx==3.3
|
||||||
|
# via pyvis
|
||||||
|
# via scrapegraphai
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
# via faiss-cpu
|
# via faiss-cpu
|
||||||
# via langchain
|
# via langchain
|
||||||
@ -184,8 +210,12 @@ packaging==23.2
|
|||||||
# via marshmallow
|
# via marshmallow
|
||||||
pandas==2.2.2
|
pandas==2.2.2
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
|
parso==0.8.4
|
||||||
|
# via jedi
|
||||||
playwright==1.43.0
|
playwright==1.43.0
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
|
prompt-toolkit==3.0.43
|
||||||
|
# via ipython
|
||||||
proto-plus==1.23.0
|
proto-plus==1.23.0
|
||||||
# via google-ai-generativelanguage
|
# via google-ai-generativelanguage
|
||||||
# via google-api-core
|
# via google-api-core
|
||||||
@ -196,6 +226,8 @@ protobuf==4.25.3
|
|||||||
# via googleapis-common-protos
|
# via googleapis-common-protos
|
||||||
# via grpcio-status
|
# via grpcio-status
|
||||||
# via proto-plus
|
# via proto-plus
|
||||||
|
pure-eval==0.2.2
|
||||||
|
# via stack-data
|
||||||
pyasn1==0.6.0
|
pyasn1==0.6.0
|
||||||
# via pyasn1-modules
|
# via pyasn1-modules
|
||||||
# via rsa
|
# via rsa
|
||||||
@ -214,6 +246,8 @@ pydantic-core==2.18.2
|
|||||||
# via pydantic
|
# via pydantic
|
||||||
pyee==11.1.0
|
pyee==11.1.0
|
||||||
# via playwright
|
# via playwright
|
||||||
|
pygments==2.18.0
|
||||||
|
# via ipython
|
||||||
pyparsing==3.1.2
|
pyparsing==3.1.2
|
||||||
# via httplib2
|
# via httplib2
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
@ -223,6 +257,8 @@ python-dotenv==1.0.1
|
|||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
pytz==2024.1
|
pytz==2024.1
|
||||||
# via pandas
|
# via pandas
|
||||||
|
pyvis==0.3.2
|
||||||
|
# via scrapegraphai
|
||||||
pyyaml==6.0.1
|
pyyaml==6.0.1
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via langchain
|
# via langchain
|
||||||
@ -245,6 +281,7 @@ s3transfer==0.10.1
|
|||||||
selectolax==0.3.21
|
selectolax==0.3.21
|
||||||
# via yahoo-search-py
|
# via yahoo-search-py
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
|
# via asttokens
|
||||||
# via python-dateutil
|
# via python-dateutil
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
# via anthropic
|
# via anthropic
|
||||||
@ -257,6 +294,8 @@ soupsieve==2.5
|
|||||||
sqlalchemy==2.0.30
|
sqlalchemy==2.0.30
|
||||||
# via langchain
|
# via langchain
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
|
stack-data==0.6.3
|
||||||
|
# via ipython
|
||||||
tenacity==8.3.0
|
tenacity==8.3.0
|
||||||
# via langchain
|
# via langchain
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
@ -271,12 +310,16 @@ tqdm==4.66.4
|
|||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via openai
|
# via openai
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
|
traitlets==5.14.3
|
||||||
|
# via ipython
|
||||||
|
# via matplotlib-inline
|
||||||
typing-extensions==4.11.0
|
typing-extensions==4.11.0
|
||||||
# via anthropic
|
# via anthropic
|
||||||
# via anyio
|
# via anyio
|
||||||
# via google-generativeai
|
# via google-generativeai
|
||||||
# via groq
|
# via groq
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
|
# via ipython
|
||||||
# via openai
|
# via openai
|
||||||
# via pydantic
|
# via pydantic
|
||||||
# via pydantic-core
|
# via pydantic-core
|
||||||
@ -293,6 +336,8 @@ urllib3==2.2.1
|
|||||||
# via botocore
|
# via botocore
|
||||||
# via requests
|
# via requests
|
||||||
# via yahoo-search-py
|
# via yahoo-search-py
|
||||||
|
wcwidth==0.2.13
|
||||||
|
# via prompt-toolkit
|
||||||
yahoo-search-py==0.3
|
yahoo-search-py==0.3
|
||||||
# via scrapegraphai
|
# via scrapegraphai
|
||||||
yarl==1.9.4
|
yarl==1.9.4
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import time
|
|||||||
import warnings
|
import warnings
|
||||||
from langchain_community.callbacks import get_openai_callback
|
from langchain_community.callbacks import get_openai_callback
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
|
|
||||||
class BaseGraph:
|
class BaseGraph:
|
||||||
@ -27,8 +26,6 @@ class BaseGraph:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Warning: If the entry point node is not the first node in the list.
|
Warning: If the entry point node is not the first node in the list.
|
||||||
ValueError: If conditional_node does not have exactly two outgoing edges
|
|
||||||
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> BaseGraph(
|
>>> BaseGraph(
|
||||||
@ -51,7 +48,7 @@ class BaseGraph:
|
|||||||
|
|
||||||
self.nodes = nodes
|
self.nodes = nodes
|
||||||
self.edges = self._create_edges({e for e in edges})
|
self.edges = self._create_edges({e for e in edges})
|
||||||
self.entry_point = entry_point
|
self.entry_point = entry_point.node_name
|
||||||
|
|
||||||
if nodes[0].node_name != entry_point.node_name:
|
if nodes[0].node_name != entry_point.node_name:
|
||||||
# raise a warning if the entry point is not the first node in the list
|
# raise a warning if the entry point is not the first node in the list
|
||||||
@ -71,16 +68,13 @@ class BaseGraph:
|
|||||||
|
|
||||||
edge_dict = {}
|
edge_dict = {}
|
||||||
for from_node, to_node in edges:
|
for from_node, to_node in edges:
|
||||||
if from_node in edge_dict:
|
edge_dict[from_node.node_name] = to_node.node_name
|
||||||
edge_dict[from_node].append(to_node)
|
|
||||||
else:
|
|
||||||
edge_dict[from_node] = [to_node]
|
|
||||||
return edge_dict
|
return edge_dict
|
||||||
|
|
||||||
def execute(self, initial_state: dict) -> Tuple[dict, list]:
|
def execute(self, initial_state: dict) -> Tuple[dict, list]:
|
||||||
"""
|
"""
|
||||||
Executes the graph by traversing nodes in breadth-first order starting from the entry point.
|
Executes the graph by traversing nodes starting from the entry point. The execution
|
||||||
The execution follows the edges based on the result of each node's execution and continues until
|
follows the edges based on the result of each node's execution and continues until
|
||||||
it reaches a node with no outgoing edges.
|
it reaches a node with no outgoing edges.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -90,6 +84,7 @@ class BaseGraph:
|
|||||||
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
|
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
current_node_name = self.nodes[0]
|
||||||
state = initial_state
|
state = initial_state
|
||||||
|
|
||||||
# variables for tracking execution info
|
# variables for tracking execution info
|
||||||
@ -103,22 +98,23 @@ class BaseGraph:
|
|||||||
"total_cost_USD": 0.0,
|
"total_cost_USD": 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
queue = deque([self.entry_point])
|
for index in self.nodes:
|
||||||
while queue:
|
|
||||||
current_node = queue.popleft()
|
|
||||||
curr_time = time.time()
|
curr_time = time.time()
|
||||||
with get_openai_callback() as callback:
|
current_node = index
|
||||||
|
|
||||||
|
with get_openai_callback() as cb:
|
||||||
result = current_node.execute(state)
|
result = current_node.execute(state)
|
||||||
node_exec_time = time.time() - curr_time
|
node_exec_time = time.time() - curr_time
|
||||||
total_exec_time += node_exec_time
|
total_exec_time += node_exec_time
|
||||||
|
|
||||||
cb = {
|
cb = {
|
||||||
"node_name": current_node.node_name,
|
"node_name": index.node_name,
|
||||||
"total_tokens": callback.total_tokens,
|
"total_tokens": cb.total_tokens,
|
||||||
"prompt_tokens": callback.prompt_tokens,
|
"prompt_tokens": cb.prompt_tokens,
|
||||||
"completion_tokens": callback.completion_tokens,
|
"completion_tokens": cb.completion_tokens,
|
||||||
"successful_requests": callback.successful_requests,
|
"successful_requests": cb.successful_requests,
|
||||||
"total_cost_USD": callback.total_cost,
|
"total_cost_USD": cb.total_cost,
|
||||||
"exec_time": node_exec_time,
|
"exec_time": node_exec_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,31 +128,21 @@ class BaseGraph:
|
|||||||
cb_total["successful_requests"] += cb["successful_requests"]
|
cb_total["successful_requests"] += cb["successful_requests"]
|
||||||
cb_total["total_cost_USD"] += cb["total_cost_USD"]
|
cb_total["total_cost_USD"] += cb["total_cost_USD"]
|
||||||
|
|
||||||
|
if current_node.node_type == "conditional_node":
|
||||||
|
current_node_name = result
|
||||||
current_node_connections = self.edges[current_node]
|
elif current_node_name in self.edges:
|
||||||
if current_node.node_type == 'conditional_node':
|
current_node_name = self.edges[current_node_name]
|
||||||
# Assert that there are exactly two out edges from the conditional node
|
else:
|
||||||
if len(current_node_connections) != 2:
|
current_node_name = None
|
||||||
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
|
|
||||||
if result["next_node"] == 0:
|
|
||||||
queue.append(current_node_connections[0])
|
|
||||||
else:
|
|
||||||
queue.append(current_node_connections[1])
|
|
||||||
# remove the conditional node result
|
|
||||||
del result["next_node"]
|
|
||||||
else:
|
|
||||||
queue.extend(node for node in current_node_connections)
|
|
||||||
|
|
||||||
|
exec_info.append({
|
||||||
|
"node_name": "TOTAL RESULT",
|
||||||
|
"total_tokens": cb_total["total_tokens"],
|
||||||
|
"prompt_tokens": cb_total["prompt_tokens"],
|
||||||
|
"completion_tokens": cb_total["completion_tokens"],
|
||||||
|
"successful_requests": cb_total["successful_requests"],
|
||||||
|
"total_cost_USD": cb_total["total_cost_USD"],
|
||||||
|
"exec_time": total_exec_time,
|
||||||
|
})
|
||||||
|
|
||||||
exec_info.append({
|
return state, exec_info
|
||||||
"node_name": "TOTAL RESULT",
|
|
||||||
"total_tokens": cb_total["total_tokens"],
|
|
||||||
"prompt_tokens": cb_total["prompt_tokens"],
|
|
||||||
"completion_tokens": cb_total["completion_tokens"],
|
|
||||||
"successful_requests": cb_total["successful_requests"],
|
|
||||||
"total_cost_USD": cb_total["total_cost_USD"],
|
|
||||||
"exec_time": total_exec_time,
|
|
||||||
})
|
|
||||||
|
|
||||||
return state, exec_info
|
|
||||||
@ -12,7 +12,7 @@ from langchain_core.output_parsers import JsonOutputParser
|
|||||||
|
|
||||||
# Imports from the library
|
# Imports from the library
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
|
from ..utils import create_graph, add_customizations, create_interactive_graph
|
||||||
|
|
||||||
class KnowledgeGraphNode(BaseNode):
|
class KnowledgeGraphNode(BaseNode):
|
||||||
"""
|
"""
|
||||||
@ -65,31 +65,36 @@ class KnowledgeGraphNode(BaseNode):
|
|||||||
user_prompt = input_data[0]
|
user_prompt = input_data[0]
|
||||||
answer_dict = input_data[1]
|
answer_dict = input_data[1]
|
||||||
|
|
||||||
output_parser = JsonOutputParser()
|
# Build the graph
|
||||||
format_instructions = output_parser.get_format_instructions()
|
graph = create_graph(answer_dict)
|
||||||
|
# Create the interactive graph
|
||||||
|
create_interactive_graph(graph, output_file='knowledge_graph.html')
|
||||||
|
|
||||||
template_merge = """
|
# output_parser = JsonOutputParser()
|
||||||
You are a website scraper and you have just scraped some content from multiple websites.\n
|
# format_instructions = output_parser.get_format_instructions()
|
||||||
You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
|
|
||||||
You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
|
|
||||||
The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
|
|
||||||
OUTPUT INSTRUCTIONS: {format_instructions}\n
|
|
||||||
USER PROMPT: {user_prompt}\n
|
|
||||||
WEBSITE CONTENT: {website_content}
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompt_template = PromptTemplate(
|
# template_merge = """
|
||||||
template=template_merge,
|
# You are a website scraper and you have just scraped some content from multiple websites.\n
|
||||||
input_variables=["user_prompt"],
|
# You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
|
||||||
partial_variables={
|
# You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
|
||||||
"format_instructions": format_instructions,
|
# The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
|
||||||
"website_content": answers_str,
|
# OUTPUT INSTRUCTIONS: {format_instructions}\n
|
||||||
},
|
# USER PROMPT: {user_prompt}\n
|
||||||
)
|
# WEBSITE CONTENT: {website_content}
|
||||||
|
# """
|
||||||
|
|
||||||
merge_chain = prompt_template | self.llm_model | output_parser
|
# prompt_template = PromptTemplate(
|
||||||
answer = merge_chain.invoke({"user_prompt": user_prompt})
|
# template=template_merge,
|
||||||
|
# input_variables=["user_prompt"],
|
||||||
|
# partial_variables={
|
||||||
|
# "format_instructions": format_instructions,
|
||||||
|
# "website_content": answers_str,
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
|
||||||
|
# merge_chain = prompt_template | self.llm_model | output_parser
|
||||||
|
# answer = merge_chain.invoke({"user_prompt": user_prompt})
|
||||||
|
|
||||||
# Update the state with the generated answer
|
# Update the state with the generated answer
|
||||||
state.update({self.output[0]: answer})
|
state.update({self.output[0]: graph})
|
||||||
return state
|
return state
|
||||||
|
|||||||
@ -9,3 +9,4 @@ from .proxy_rotation import Proxy, parse_or_search_proxy, search_proxy_servers
|
|||||||
from .save_audio_from_bytes import save_audio_from_bytes
|
from .save_audio_from_bytes import save_audio_from_bytes
|
||||||
from .sys_dynamic_import import dynamic_import, srcfile_import
|
from .sys_dynamic_import import dynamic_import, srcfile_import
|
||||||
from .cleanup_html import cleanup_html
|
from .cleanup_html import cleanup_html
|
||||||
|
from .knowledge_graph import create_graph, add_customizations, create_interactive_graph
|
||||||
81
scrapegraphai/utils/knowledge_graph.py
Normal file
81
scrapegraphai/utils/knowledge_graph.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import networkx as nx
|
||||||
|
from pyvis.network import Network
|
||||||
|
import webbrowser
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Create and visualize graph
|
||||||
|
def create_graph(job_postings):
|
||||||
|
graph = nx.DiGraph()
|
||||||
|
|
||||||
|
# Add the main "Job Postings" node
|
||||||
|
graph.add_node("Job Postings")
|
||||||
|
|
||||||
|
for company, jobs in job_postings["Job Postings"].items():
|
||||||
|
# Add company node
|
||||||
|
graph.add_node(company)
|
||||||
|
graph.add_edge("Job Postings", company)
|
||||||
|
|
||||||
|
# Add job nodes and their details
|
||||||
|
for idx, job in enumerate(jobs, start=1):
|
||||||
|
job_id = f"{company}-Job{idx}"
|
||||||
|
graph.add_node(job_id)
|
||||||
|
graph.add_edge(company, job_id)
|
||||||
|
|
||||||
|
for key, value in job.items():
|
||||||
|
if isinstance(value, list):
|
||||||
|
list_node_id = f"{job_id}-{key}"
|
||||||
|
graph.add_node(list_node_id, label=key)
|
||||||
|
graph.add_edge(job_id, list_node_id)
|
||||||
|
for item in value:
|
||||||
|
detail_id = f"{list_node_id}-{item}"
|
||||||
|
graph.add_node(detail_id, label=item, title=item)
|
||||||
|
graph.add_edge(list_node_id, detail_id)
|
||||||
|
else:
|
||||||
|
detail_id = f"{job_id}-{key}"
|
||||||
|
graph.add_node(detail_id, label=key, title=f"{key}: {value}")
|
||||||
|
graph.add_edge(job_id, detail_id)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
# Add customizations to the network
|
||||||
|
def add_customizations(net, graph):
|
||||||
|
node_colors = {}
|
||||||
|
node_sizes = {}
|
||||||
|
|
||||||
|
# Custom colors and sizes for nodes
|
||||||
|
node_colors["Job Postings"] = '#8470FF'
|
||||||
|
node_sizes["Job Postings"] = 50
|
||||||
|
|
||||||
|
for node in graph.nodes:
|
||||||
|
if node in node_colors:
|
||||||
|
continue
|
||||||
|
if '-' not in node: # Company nodes
|
||||||
|
node_colors[node] = '#3CB371'
|
||||||
|
node_sizes[node] = 30
|
||||||
|
elif '-' in node and node.count('-') == 1: # Job nodes
|
||||||
|
node_colors[node] = '#FFA07A'
|
||||||
|
node_sizes[node] = 20
|
||||||
|
else: # Job detail nodes
|
||||||
|
node_colors[node] = '#B0C4DE'
|
||||||
|
node_sizes[node] = 10
|
||||||
|
|
||||||
|
# Add nodes and edges to the network with customized styles
|
||||||
|
for node in graph.nodes:
|
||||||
|
net.add_node(node,
|
||||||
|
label=graph.nodes[node].get('label', node.split('-')[-1]),
|
||||||
|
color=node_colors.get(node, 'lightgray'),
|
||||||
|
size=node_sizes.get(node, 15),
|
||||||
|
title=graph.nodes[node].get('title', ''))
|
||||||
|
for edge in graph.edges:
|
||||||
|
net.add_edge(edge[0], edge[1])
|
||||||
|
return net
|
||||||
|
|
||||||
|
# Create interactive graph
|
||||||
|
def create_interactive_graph(graph, output_file='interactive_graph.html'):
|
||||||
|
net = Network(notebook=False, height='1000px', width='100%', bgcolor='white', font_color='black')
|
||||||
|
net = add_customizations(net, graph)
|
||||||
|
net.save_graph(output_file)
|
||||||
|
|
||||||
|
# Automatically open the generated HTML file in the default web browser
|
||||||
|
webbrowser.open(f"file://{os.path.realpath(output_file)}")
|
||||||
|
|
||||||
Loading…
Reference in New Issue
Block a user