feat(knowledgegraph): add knowledge graph node

This commit is contained in:
Marco Perini 2024-05-17 23:41:44 +02:00
parent 8c33ea3fbc
commit 0196423bde
9 changed files with 242 additions and 75 deletions

3
.gitignore vendored
View File

@ -32,5 +32,6 @@ examples/graph_examples/ScrapeGraphAI_generated_graph
examples/**/result.csv
examples/**/result.json
main.py
lib/
*.html

View File

@ -50,6 +50,7 @@ graph_config = {
"model": "gpt-4o",
"temperature": 0,
},
"verbose": True,
}
# ************************************************
@ -59,11 +60,9 @@ graph_config = {
llm_model = OpenAI(graph_config["llm"])
robots_node = KnowledgeGraphNode(
input="answer & user_prompt",
input="user_prompt & answer_dict",
output=["is_scrapable"],
node_config={"llm_model": llm_model,
"headless": False
}
node_config={"llm_model": llm_model}
)
# ************************************************
@ -71,7 +70,8 @@ robots_node = KnowledgeGraphNode(
# ************************************************
state = {
"url": "https://twitter.com/home"
"user_prompt": "What are the job postings?",
"answer_dict": job_postings
}
result = robots_node.execute(state)

View File

@ -30,6 +30,8 @@ dependencies = [
"playwright==1.43.0",
"google==3.0.0",
"yahoo-search-py==0.3",
"networkx==3.3",
"pyvis==0.3.2",
]
license = "MIT"

View File

@ -22,6 +22,8 @@ anyio==4.3.0
# via groq
# via httpx
# via openai
asttokens==2.4.1
# via stack-data
async-timeout==4.0.3
# via aiohttp
# via langchain
@ -43,9 +45,15 @@ certifi==2024.2.2
# via requests
charset-normalizer==3.3.2
# via requests
colorama==0.4.6
# via ipython
# via pytest
# via tqdm
dataclasses-json==0.6.6
# via langchain
# via langchain-community
decorator==5.1.1
# via ipython
defusedxml==0.7.1
# via langchain-anthropic
distro==1.9.0
@ -54,7 +62,10 @@ distro==1.9.0
# via openai
exceptiongroup==1.2.1
# via anyio
# via ipython
# via pytest
executing==2.0.1
# via stack-data
faiss-cpu==1.8.0
# via scrapegraphai
filelock==3.14.0
@ -93,6 +104,7 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
# via sqlalchemy
groq==0.5.0
# via langchain-groq
grpcio==1.63.0
@ -123,12 +135,20 @@ idna==3.7
# via yarl
iniconfig==2.0.0
# via pytest
ipython==8.24.0
# via pyvis
jedi==0.19.1
# via ipython
jinja2==3.1.4
# via pyvis
jmespath==1.0.1
# via boto3
# via botocore
jsonpatch==1.33
# via langchain
# via langchain-core
jsonpickle==3.0.4
# via pyvis
jsonpointer==2.4
# via jsonpatch
langchain==0.1.15
@ -162,8 +182,12 @@ langsmith==0.1.58
# via langchain-core
lxml==5.2.2
# via free-proxy
markupsafe==2.1.5
# via jinja2
marshmallow==3.21.2
# via dataclasses-json
matplotlib-inline==0.1.7
# via ipython
minify-html==0.15.0
# via scrapegraphai
multidict==6.0.5
@ -171,6 +195,9 @@ multidict==6.0.5
# via yarl
mypy-extensions==1.0.0
# via typing-inspect
networkx==3.3
# via pyvis
# via scrapegraphai
numpy==1.26.4
# via faiss-cpu
# via langchain
@ -188,10 +215,14 @@ packaging==23.2
# via pytest
pandas==2.2.2
# via scrapegraphai
parso==0.8.4
# via jedi
playwright==1.43.0
# via scrapegraphai
pluggy==1.5.0
# via pytest
prompt-toolkit==3.0.43
# via ipython
proto-plus==1.23.0
# via google-ai-generativelanguage
# via google-api-core
@ -202,6 +233,8 @@ protobuf==4.25.3
# via googleapis-common-protos
# via grpcio-status
# via proto-plus
pure-eval==0.2.2
# via stack-data
pyasn1==0.6.0
# via pyasn1-modules
# via rsa
@ -220,6 +253,8 @@ pydantic-core==2.18.2
# via pydantic
pyee==11.1.0
# via playwright
pygments==2.18.0
# via ipython
pyparsing==3.1.2
# via httplib2
pytest==8.0.0
@ -232,6 +267,8 @@ python-dotenv==1.0.1
# via scrapegraphai
pytz==2024.1
# via pandas
pyvis==0.3.2
# via scrapegraphai
pyyaml==6.0.1
# via huggingface-hub
# via langchain
@ -254,6 +291,7 @@ s3transfer==0.10.1
selectolax==0.3.21
# via yahoo-search-py
six==1.16.0
# via asttokens
# via python-dateutil
sniffio==1.3.1
# via anthropic
@ -266,6 +304,8 @@ soupsieve==2.5
sqlalchemy==2.0.30
# via langchain
# via langchain-community
stack-data==0.6.3
# via ipython
tenacity==8.3.0
# via langchain
# via langchain-community
@ -282,12 +322,16 @@ tqdm==4.66.4
# via huggingface-hub
# via openai
# via scrapegraphai
traitlets==5.14.3
# via ipython
# via matplotlib-inline
typing-extensions==4.11.0
# via anthropic
# via anyio
# via google-generativeai
# via groq
# via huggingface-hub
# via ipython
# via openai
# via pydantic
# via pydantic-core
@ -304,6 +348,8 @@ urllib3==2.2.1
# via botocore
# via requests
# via yahoo-search-py
wcwidth==0.2.13
# via prompt-toolkit
yahoo-search-py==0.3
# via scrapegraphai
yarl==1.9.4

View File

@ -22,6 +22,8 @@ anyio==4.3.0
# via groq
# via httpx
# via openai
asttokens==2.4.1
# via stack-data
async-timeout==4.0.3
# via aiohttp
# via langchain
@ -43,9 +45,14 @@ certifi==2024.2.2
# via requests
charset-normalizer==3.3.2
# via requests
colorama==0.4.6
# via ipython
# via tqdm
dataclasses-json==0.6.6
# via langchain
# via langchain-community
decorator==5.1.1
# via ipython
defusedxml==0.7.1
# via langchain-anthropic
distro==1.9.0
@ -54,6 +61,9 @@ distro==1.9.0
# via openai
exceptiongroup==1.2.1
# via anyio
# via ipython
executing==2.0.1
# via stack-data
faiss-cpu==1.8.0
# via scrapegraphai
filelock==3.14.0
@ -92,6 +102,7 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
# via sqlalchemy
groq==0.5.0
# via langchain-groq
grpcio==1.63.0
@ -120,12 +131,20 @@ idna==3.7
# via httpx
# via requests
# via yarl
ipython==8.24.0
# via pyvis
jedi==0.19.1
# via ipython
jinja2==3.1.4
# via pyvis
jmespath==1.0.1
# via boto3
# via botocore
jsonpatch==1.33
# via langchain
# via langchain-core
jsonpickle==3.0.4
# via pyvis
jsonpointer==2.4
# via jsonpatch
langchain==0.1.15
@ -159,8 +178,12 @@ langsmith==0.1.58
# via langchain-core
lxml==5.2.2
# via free-proxy
markupsafe==2.1.5
# via jinja2
marshmallow==3.21.2
# via dataclasses-json
matplotlib-inline==0.1.7
# via ipython
minify-html==0.15.0
# via scrapegraphai
multidict==6.0.5
@ -168,6 +191,9 @@ multidict==6.0.5
# via yarl
mypy-extensions==1.0.0
# via typing-inspect
networkx==3.3
# via pyvis
# via scrapegraphai
numpy==1.26.4
# via faiss-cpu
# via langchain
@ -184,8 +210,12 @@ packaging==23.2
# via marshmallow
pandas==2.2.2
# via scrapegraphai
parso==0.8.4
# via jedi
playwright==1.43.0
# via scrapegraphai
prompt-toolkit==3.0.43
# via ipython
proto-plus==1.23.0
# via google-ai-generativelanguage
# via google-api-core
@ -196,6 +226,8 @@ protobuf==4.25.3
# via googleapis-common-protos
# via grpcio-status
# via proto-plus
pure-eval==0.2.2
# via stack-data
pyasn1==0.6.0
# via pyasn1-modules
# via rsa
@ -214,6 +246,8 @@ pydantic-core==2.18.2
# via pydantic
pyee==11.1.0
# via playwright
pygments==2.18.0
# via ipython
pyparsing==3.1.2
# via httplib2
python-dateutil==2.9.0.post0
@ -223,6 +257,8 @@ python-dotenv==1.0.1
# via scrapegraphai
pytz==2024.1
# via pandas
pyvis==0.3.2
# via scrapegraphai
pyyaml==6.0.1
# via huggingface-hub
# via langchain
@ -245,6 +281,7 @@ s3transfer==0.10.1
selectolax==0.3.21
# via yahoo-search-py
six==1.16.0
# via asttokens
# via python-dateutil
sniffio==1.3.1
# via anthropic
@ -257,6 +294,8 @@ soupsieve==2.5
sqlalchemy==2.0.30
# via langchain
# via langchain-community
stack-data==0.6.3
# via ipython
tenacity==8.3.0
# via langchain
# via langchain-community
@ -271,12 +310,16 @@ tqdm==4.66.4
# via huggingface-hub
# via openai
# via scrapegraphai
traitlets==5.14.3
# via ipython
# via matplotlib-inline
typing-extensions==4.11.0
# via anthropic
# via anyio
# via google-generativeai
# via groq
# via huggingface-hub
# via ipython
# via openai
# via pydantic
# via pydantic-core
@ -293,6 +336,8 @@ urllib3==2.2.1
# via botocore
# via requests
# via yahoo-search-py
wcwidth==0.2.13
# via prompt-toolkit
yahoo-search-py==0.3
# via scrapegraphai
yarl==1.9.4

View File

@ -6,7 +6,6 @@ import time
import warnings
from langchain_community.callbacks import get_openai_callback
from typing import Tuple
from collections import deque
class BaseGraph:
@ -27,8 +26,6 @@ class BaseGraph:
Raises:
Warning: If the entry point node is not the first node in the list.
ValueError: If conditional_node does not have exactly two outgoing edges
Example:
>>> BaseGraph(
@ -51,7 +48,7 @@ class BaseGraph:
self.nodes = nodes
self.edges = self._create_edges({e for e in edges})
self.entry_point = entry_point
self.entry_point = entry_point.node_name
if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list
@ -71,16 +68,13 @@ class BaseGraph:
edge_dict = {}
for from_node, to_node in edges:
if from_node in edge_dict:
edge_dict[from_node].append(to_node)
else:
edge_dict[from_node] = [to_node]
edge_dict[from_node.node_name] = to_node.node_name
return edge_dict
def execute(self, initial_state: dict) -> Tuple[dict, list]:
"""
Executes the graph by traversing nodes in breadth-first order starting from the entry point.
The execution follows the edges based on the result of each node's execution and continues until
Executes the graph by traversing nodes starting from the entry point. The execution
follows the edges based on the result of each node's execution and continues until
it reaches a node with no outgoing edges.
Args:
@ -90,6 +84,7 @@ class BaseGraph:
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
"""
current_node_name = self.nodes[0]
state = initial_state
# variables for tracking execution info
@ -103,22 +98,23 @@ class BaseGraph:
"total_cost_USD": 0.0,
}
queue = deque([self.entry_point])
while queue:
current_node = queue.popleft()
for index in self.nodes:
curr_time = time.time()
with get_openai_callback() as callback:
current_node = index
with get_openai_callback() as cb:
result = current_node.execute(state)
node_exec_time = time.time() - curr_time
total_exec_time += node_exec_time
cb = {
"node_name": current_node.node_name,
"total_tokens": callback.total_tokens,
"prompt_tokens": callback.prompt_tokens,
"completion_tokens": callback.completion_tokens,
"successful_requests": callback.successful_requests,
"total_cost_USD": callback.total_cost,
"node_name": index.node_name,
"total_tokens": cb.total_tokens,
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"successful_requests": cb.successful_requests,
"total_cost_USD": cb.total_cost,
"exec_time": node_exec_time,
}
@ -132,31 +128,21 @@ class BaseGraph:
cb_total["successful_requests"] += cb["successful_requests"]
cb_total["total_cost_USD"] += cb["total_cost_USD"]
current_node_connections = self.edges[current_node]
if current_node.node_type == 'conditional_node':
# Assert that there are exactly two out edges from the conditional node
if len(current_node_connections) != 2:
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
if result["next_node"] == 0:
queue.append(current_node_connections[0])
else:
queue.append(current_node_connections[1])
# remove the conditional node result
del result["next_node"]
else:
queue.extend(node for node in current_node_connections)
if current_node.node_type == "conditional_node":
current_node_name = result
elif current_node_name in self.edges:
current_node_name = self.edges[current_node_name]
else:
current_node_name = None
exec_info.append({
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
exec_info.append({
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
return state, exec_info
return state, exec_info

View File

@ -12,7 +12,7 @@ from langchain_core.output_parsers import JsonOutputParser
# Imports from the library
from .base_node import BaseNode
from ..utils import create_graph, add_customizations, create_interactive_graph
class KnowledgeGraphNode(BaseNode):
"""
@ -65,31 +65,36 @@ class KnowledgeGraphNode(BaseNode):
user_prompt = input_data[0]
answer_dict = input_data[1]
output_parser = JsonOutputParser()
format_instructions = output_parser.get_format_instructions()
# Build the graph
graph = create_graph(answer_dict)
# Create the interactive graph
create_interactive_graph(graph, output_file='knowledge_graph.html')
template_merge = """
You are a website scraper and you have just scraped some content from multiple websites.\n
You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
OUTPUT INSTRUCTIONS: {format_instructions}\n
USER PROMPT: {user_prompt}\n
WEBSITE CONTENT: {website_content}
"""
# output_parser = JsonOutputParser()
# format_instructions = output_parser.get_format_instructions()
prompt_template = PromptTemplate(
template=template_merge,
input_variables=["user_prompt"],
partial_variables={
"format_instructions": format_instructions,
"website_content": answers_str,
},
)
# template_merge = """
# You are a website scraper and you have just scraped some content from multiple websites.\n
# You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
# You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
# The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
# OUTPUT INSTRUCTIONS: {format_instructions}\n
# USER PROMPT: {user_prompt}\n
# WEBSITE CONTENT: {website_content}
# """
merge_chain = prompt_template | self.llm_model | output_parser
answer = merge_chain.invoke({"user_prompt": user_prompt})
# prompt_template = PromptTemplate(
# template=template_merge,
# input_variables=["user_prompt"],
# partial_variables={
# "format_instructions": format_instructions,
# "website_content": answers_str,
# },
# )
# merge_chain = prompt_template | self.llm_model | output_parser
# answer = merge_chain.invoke({"user_prompt": user_prompt})
# Update the state with the generated answer
state.update({self.output[0]: answer})
state.update({self.output[0]: graph})
return state

View File

@ -9,3 +9,4 @@ from .proxy_rotation import Proxy, parse_or_search_proxy, search_proxy_servers
from .save_audio_from_bytes import save_audio_from_bytes
from .sys_dynamic_import import dynamic_import, srcfile_import
from .cleanup_html import cleanup_html
from .knowledge_graph import create_graph, add_customizations, create_interactive_graph

View File

@ -0,0 +1,81 @@
import networkx as nx
from pyvis.network import Network
import webbrowser
import os
# Create and visualize graph
def create_graph(job_postings):
graph = nx.DiGraph()
# Add the main "Job Postings" node
graph.add_node("Job Postings")
for company, jobs in job_postings["Job Postings"].items():
# Add company node
graph.add_node(company)
graph.add_edge("Job Postings", company)
# Add job nodes and their details
for idx, job in enumerate(jobs, start=1):
job_id = f"{company}-Job{idx}"
graph.add_node(job_id)
graph.add_edge(company, job_id)
for key, value in job.items():
if isinstance(value, list):
list_node_id = f"{job_id}-{key}"
graph.add_node(list_node_id, label=key)
graph.add_edge(job_id, list_node_id)
for item in value:
detail_id = f"{list_node_id}-{item}"
graph.add_node(detail_id, label=item, title=item)
graph.add_edge(list_node_id, detail_id)
else:
detail_id = f"{job_id}-{key}"
graph.add_node(detail_id, label=key, title=f"{key}: {value}")
graph.add_edge(job_id, detail_id)
return graph
# Add customizations to the network
def add_customizations(net, graph):
node_colors = {}
node_sizes = {}
# Custom colors and sizes for nodes
node_colors["Job Postings"] = '#8470FF'
node_sizes["Job Postings"] = 50
for node in graph.nodes:
if node in node_colors:
continue
if '-' not in node: # Company nodes
node_colors[node] = '#3CB371'
node_sizes[node] = 30
elif '-' in node and node.count('-') == 1: # Job nodes
node_colors[node] = '#FFA07A'
node_sizes[node] = 20
else: # Job detail nodes
node_colors[node] = '#B0C4DE'
node_sizes[node] = 10
# Add nodes and edges to the network with customized styles
for node in graph.nodes:
net.add_node(node,
label=graph.nodes[node].get('label', node.split('-')[-1]),
color=node_colors.get(node, 'lightgray'),
size=node_sizes.get(node, 15),
title=graph.nodes[node].get('title', ''))
for edge in graph.edges:
net.add_edge(edge[0], edge[1])
return net
# Create interactive graph
def create_interactive_graph(graph, output_file='interactive_graph.html'):
net = Network(notebook=False, height='1000px', width='100%', bgcolor='white', font_color='black')
net = add_customizations(net, graph)
net.save_graph(output_file)
# Automatically open the generated HTML file in the default web browser
webbrowser.open(f"file://{os.path.realpath(output_file)}")