feat(knowledgegraph): add knowledge graph node

This commit is contained in:
Marco Perini 2024-05-17 23:41:44 +02:00
parent 8c33ea3fbc
commit 0196423bde
9 changed files with 242 additions and 75 deletions

3
.gitignore vendored
View File

@ -32,5 +32,6 @@ examples/graph_examples/ScrapeGraphAI_generated_graph
examples/**/result.csv examples/**/result.csv
examples/**/result.json examples/**/result.json
main.py main.py
lib/
*.html

View File

@ -50,6 +50,7 @@ graph_config = {
"model": "gpt-4o", "model": "gpt-4o",
"temperature": 0, "temperature": 0,
}, },
"verbose": True,
} }
# ************************************************ # ************************************************
@ -59,11 +60,9 @@ graph_config = {
llm_model = OpenAI(graph_config["llm"]) llm_model = OpenAI(graph_config["llm"])
robots_node = KnowledgeGraphNode( robots_node = KnowledgeGraphNode(
input="answer & user_prompt", input="user_prompt & answer_dict",
output=["is_scrapable"], output=["is_scrapable"],
node_config={"llm_model": llm_model, node_config={"llm_model": llm_model}
"headless": False
}
) )
# ************************************************ # ************************************************
@ -71,7 +70,8 @@ robots_node = KnowledgeGraphNode(
# ************************************************ # ************************************************
state = { state = {
"url": "https://twitter.com/home" "user_prompt": "What are the job postings?",
"answer_dict": job_postings
} }
result = robots_node.execute(state) result = robots_node.execute(state)

View File

@ -30,6 +30,8 @@ dependencies = [
"playwright==1.43.0", "playwright==1.43.0",
"google==3.0.0", "google==3.0.0",
"yahoo-search-py==0.3", "yahoo-search-py==0.3",
"networkx==3.3",
"pyvis==0.3.2",
] ]
license = "MIT" license = "MIT"

View File

@ -22,6 +22,8 @@ anyio==4.3.0
# via groq # via groq
# via httpx # via httpx
# via openai # via openai
asttokens==2.4.1
# via stack-data
async-timeout==4.0.3 async-timeout==4.0.3
# via aiohttp # via aiohttp
# via langchain # via langchain
@ -43,9 +45,15 @@ certifi==2024.2.2
# via requests # via requests
charset-normalizer==3.3.2 charset-normalizer==3.3.2
# via requests # via requests
colorama==0.4.6
# via ipython
# via pytest
# via tqdm
dataclasses-json==0.6.6 dataclasses-json==0.6.6
# via langchain # via langchain
# via langchain-community # via langchain-community
decorator==5.1.1
# via ipython
defusedxml==0.7.1 defusedxml==0.7.1
# via langchain-anthropic # via langchain-anthropic
distro==1.9.0 distro==1.9.0
@ -54,7 +62,10 @@ distro==1.9.0
# via openai # via openai
exceptiongroup==1.2.1 exceptiongroup==1.2.1
# via anyio # via anyio
# via ipython
# via pytest # via pytest
executing==2.0.1
# via stack-data
faiss-cpu==1.8.0 faiss-cpu==1.8.0
# via scrapegraphai # via scrapegraphai
filelock==3.14.0 filelock==3.14.0
@ -93,6 +104,7 @@ graphviz==0.20.3
# via scrapegraphai # via scrapegraphai
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
groq==0.5.0 groq==0.5.0
# via langchain-groq # via langchain-groq
grpcio==1.63.0 grpcio==1.63.0
@ -123,12 +135,20 @@ idna==3.7
# via yarl # via yarl
iniconfig==2.0.0 iniconfig==2.0.0
# via pytest # via pytest
ipython==8.24.0
# via pyvis
jedi==0.19.1
# via ipython
jinja2==3.1.4
# via pyvis
jmespath==1.0.1 jmespath==1.0.1
# via boto3 # via boto3
# via botocore # via botocore
jsonpatch==1.33 jsonpatch==1.33
# via langchain # via langchain
# via langchain-core # via langchain-core
jsonpickle==3.0.4
# via pyvis
jsonpointer==2.4 jsonpointer==2.4
# via jsonpatch # via jsonpatch
langchain==0.1.15 langchain==0.1.15
@ -162,8 +182,12 @@ langsmith==0.1.58
# via langchain-core # via langchain-core
lxml==5.2.2 lxml==5.2.2
# via free-proxy # via free-proxy
markupsafe==2.1.5
# via jinja2
marshmallow==3.21.2 marshmallow==3.21.2
# via dataclasses-json # via dataclasses-json
matplotlib-inline==0.1.7
# via ipython
minify-html==0.15.0 minify-html==0.15.0
# via scrapegraphai # via scrapegraphai
multidict==6.0.5 multidict==6.0.5
@ -171,6 +195,9 @@ multidict==6.0.5
# via yarl # via yarl
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via typing-inspect # via typing-inspect
networkx==3.3
# via pyvis
# via scrapegraphai
numpy==1.26.4 numpy==1.26.4
# via faiss-cpu # via faiss-cpu
# via langchain # via langchain
@ -188,10 +215,14 @@ packaging==23.2
# via pytest # via pytest
pandas==2.2.2 pandas==2.2.2
# via scrapegraphai # via scrapegraphai
parso==0.8.4
# via jedi
playwright==1.43.0 playwright==1.43.0
# via scrapegraphai # via scrapegraphai
pluggy==1.5.0 pluggy==1.5.0
# via pytest # via pytest
prompt-toolkit==3.0.43
# via ipython
proto-plus==1.23.0 proto-plus==1.23.0
# via google-ai-generativelanguage # via google-ai-generativelanguage
# via google-api-core # via google-api-core
@ -202,6 +233,8 @@ protobuf==4.25.3
# via googleapis-common-protos # via googleapis-common-protos
# via grpcio-status # via grpcio-status
# via proto-plus # via proto-plus
pure-eval==0.2.2
# via stack-data
pyasn1==0.6.0 pyasn1==0.6.0
# via pyasn1-modules # via pyasn1-modules
# via rsa # via rsa
@ -220,6 +253,8 @@ pydantic-core==2.18.2
# via pydantic # via pydantic
pyee==11.1.0 pyee==11.1.0
# via playwright # via playwright
pygments==2.18.0
# via ipython
pyparsing==3.1.2 pyparsing==3.1.2
# via httplib2 # via httplib2
pytest==8.0.0 pytest==8.0.0
@ -232,6 +267,8 @@ python-dotenv==1.0.1
# via scrapegraphai # via scrapegraphai
pytz==2024.1 pytz==2024.1
# via pandas # via pandas
pyvis==0.3.2
# via scrapegraphai
pyyaml==6.0.1 pyyaml==6.0.1
# via huggingface-hub # via huggingface-hub
# via langchain # via langchain
@ -254,6 +291,7 @@ s3transfer==0.10.1
selectolax==0.3.21 selectolax==0.3.21
# via yahoo-search-py # via yahoo-search-py
six==1.16.0 six==1.16.0
# via asttokens
# via python-dateutil # via python-dateutil
sniffio==1.3.1 sniffio==1.3.1
# via anthropic # via anthropic
@ -266,6 +304,8 @@ soupsieve==2.5
sqlalchemy==2.0.30 sqlalchemy==2.0.30
# via langchain # via langchain
# via langchain-community # via langchain-community
stack-data==0.6.3
# via ipython
tenacity==8.3.0 tenacity==8.3.0
# via langchain # via langchain
# via langchain-community # via langchain-community
@ -282,12 +322,16 @@ tqdm==4.66.4
# via huggingface-hub # via huggingface-hub
# via openai # via openai
# via scrapegraphai # via scrapegraphai
traitlets==5.14.3
# via ipython
# via matplotlib-inline
typing-extensions==4.11.0 typing-extensions==4.11.0
# via anthropic # via anthropic
# via anyio # via anyio
# via google-generativeai # via google-generativeai
# via groq # via groq
# via huggingface-hub # via huggingface-hub
# via ipython
# via openai # via openai
# via pydantic # via pydantic
# via pydantic-core # via pydantic-core
@ -304,6 +348,8 @@ urllib3==2.2.1
# via botocore # via botocore
# via requests # via requests
# via yahoo-search-py # via yahoo-search-py
wcwidth==0.2.13
# via prompt-toolkit
yahoo-search-py==0.3 yahoo-search-py==0.3
# via scrapegraphai # via scrapegraphai
yarl==1.9.4 yarl==1.9.4

View File

@ -22,6 +22,8 @@ anyio==4.3.0
# via groq # via groq
# via httpx # via httpx
# via openai # via openai
asttokens==2.4.1
# via stack-data
async-timeout==4.0.3 async-timeout==4.0.3
# via aiohttp # via aiohttp
# via langchain # via langchain
@ -43,9 +45,14 @@ certifi==2024.2.2
# via requests # via requests
charset-normalizer==3.3.2 charset-normalizer==3.3.2
# via requests # via requests
colorama==0.4.6
# via ipython
# via tqdm
dataclasses-json==0.6.6 dataclasses-json==0.6.6
# via langchain # via langchain
# via langchain-community # via langchain-community
decorator==5.1.1
# via ipython
defusedxml==0.7.1 defusedxml==0.7.1
# via langchain-anthropic # via langchain-anthropic
distro==1.9.0 distro==1.9.0
@ -54,6 +61,9 @@ distro==1.9.0
# via openai # via openai
exceptiongroup==1.2.1 exceptiongroup==1.2.1
# via anyio # via anyio
# via ipython
executing==2.0.1
# via stack-data
faiss-cpu==1.8.0 faiss-cpu==1.8.0
# via scrapegraphai # via scrapegraphai
filelock==3.14.0 filelock==3.14.0
@ -92,6 +102,7 @@ graphviz==0.20.3
# via scrapegraphai # via scrapegraphai
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
groq==0.5.0 groq==0.5.0
# via langchain-groq # via langchain-groq
grpcio==1.63.0 grpcio==1.63.0
@ -120,12 +131,20 @@ idna==3.7
# via httpx # via httpx
# via requests # via requests
# via yarl # via yarl
ipython==8.24.0
# via pyvis
jedi==0.19.1
# via ipython
jinja2==3.1.4
# via pyvis
jmespath==1.0.1 jmespath==1.0.1
# via boto3 # via boto3
# via botocore # via botocore
jsonpatch==1.33 jsonpatch==1.33
# via langchain # via langchain
# via langchain-core # via langchain-core
jsonpickle==3.0.4
# via pyvis
jsonpointer==2.4 jsonpointer==2.4
# via jsonpatch # via jsonpatch
langchain==0.1.15 langchain==0.1.15
@ -159,8 +178,12 @@ langsmith==0.1.58
# via langchain-core # via langchain-core
lxml==5.2.2 lxml==5.2.2
# via free-proxy # via free-proxy
markupsafe==2.1.5
# via jinja2
marshmallow==3.21.2 marshmallow==3.21.2
# via dataclasses-json # via dataclasses-json
matplotlib-inline==0.1.7
# via ipython
minify-html==0.15.0 minify-html==0.15.0
# via scrapegraphai # via scrapegraphai
multidict==6.0.5 multidict==6.0.5
@ -168,6 +191,9 @@ multidict==6.0.5
# via yarl # via yarl
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via typing-inspect # via typing-inspect
networkx==3.3
# via pyvis
# via scrapegraphai
numpy==1.26.4 numpy==1.26.4
# via faiss-cpu # via faiss-cpu
# via langchain # via langchain
@ -184,8 +210,12 @@ packaging==23.2
# via marshmallow # via marshmallow
pandas==2.2.2 pandas==2.2.2
# via scrapegraphai # via scrapegraphai
parso==0.8.4
# via jedi
playwright==1.43.0 playwright==1.43.0
# via scrapegraphai # via scrapegraphai
prompt-toolkit==3.0.43
# via ipython
proto-plus==1.23.0 proto-plus==1.23.0
# via google-ai-generativelanguage # via google-ai-generativelanguage
# via google-api-core # via google-api-core
@ -196,6 +226,8 @@ protobuf==4.25.3
# via googleapis-common-protos # via googleapis-common-protos
# via grpcio-status # via grpcio-status
# via proto-plus # via proto-plus
pure-eval==0.2.2
# via stack-data
pyasn1==0.6.0 pyasn1==0.6.0
# via pyasn1-modules # via pyasn1-modules
# via rsa # via rsa
@ -214,6 +246,8 @@ pydantic-core==2.18.2
# via pydantic # via pydantic
pyee==11.1.0 pyee==11.1.0
# via playwright # via playwright
pygments==2.18.0
# via ipython
pyparsing==3.1.2 pyparsing==3.1.2
# via httplib2 # via httplib2
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
@ -223,6 +257,8 @@ python-dotenv==1.0.1
# via scrapegraphai # via scrapegraphai
pytz==2024.1 pytz==2024.1
# via pandas # via pandas
pyvis==0.3.2
# via scrapegraphai
pyyaml==6.0.1 pyyaml==6.0.1
# via huggingface-hub # via huggingface-hub
# via langchain # via langchain
@ -245,6 +281,7 @@ s3transfer==0.10.1
selectolax==0.3.21 selectolax==0.3.21
# via yahoo-search-py # via yahoo-search-py
six==1.16.0 six==1.16.0
# via asttokens
# via python-dateutil # via python-dateutil
sniffio==1.3.1 sniffio==1.3.1
# via anthropic # via anthropic
@ -257,6 +294,8 @@ soupsieve==2.5
sqlalchemy==2.0.30 sqlalchemy==2.0.30
# via langchain # via langchain
# via langchain-community # via langchain-community
stack-data==0.6.3
# via ipython
tenacity==8.3.0 tenacity==8.3.0
# via langchain # via langchain
# via langchain-community # via langchain-community
@ -271,12 +310,16 @@ tqdm==4.66.4
# via huggingface-hub # via huggingface-hub
# via openai # via openai
# via scrapegraphai # via scrapegraphai
traitlets==5.14.3
# via ipython
# via matplotlib-inline
typing-extensions==4.11.0 typing-extensions==4.11.0
# via anthropic # via anthropic
# via anyio # via anyio
# via google-generativeai # via google-generativeai
# via groq # via groq
# via huggingface-hub # via huggingface-hub
# via ipython
# via openai # via openai
# via pydantic # via pydantic
# via pydantic-core # via pydantic-core
@ -293,6 +336,8 @@ urllib3==2.2.1
# via botocore # via botocore
# via requests # via requests
# via yahoo-search-py # via yahoo-search-py
wcwidth==0.2.13
# via prompt-toolkit
yahoo-search-py==0.3 yahoo-search-py==0.3
# via scrapegraphai # via scrapegraphai
yarl==1.9.4 yarl==1.9.4

View File

@ -6,7 +6,6 @@ import time
import warnings import warnings
from langchain_community.callbacks import get_openai_callback from langchain_community.callbacks import get_openai_callback
from typing import Tuple from typing import Tuple
from collections import deque
class BaseGraph: class BaseGraph:
@ -27,8 +26,6 @@ class BaseGraph:
Raises: Raises:
Warning: If the entry point node is not the first node in the list. Warning: If the entry point node is not the first node in the list.
ValueError: If conditional_node does not have exactly two outgoing edges
Example: Example:
>>> BaseGraph( >>> BaseGraph(
@ -51,7 +48,7 @@ class BaseGraph:
self.nodes = nodes self.nodes = nodes
self.edges = self._create_edges({e for e in edges}) self.edges = self._create_edges({e for e in edges})
self.entry_point = entry_point self.entry_point = entry_point.node_name
if nodes[0].node_name != entry_point.node_name: if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list # raise a warning if the entry point is not the first node in the list
@ -71,16 +68,13 @@ class BaseGraph:
edge_dict = {} edge_dict = {}
for from_node, to_node in edges: for from_node, to_node in edges:
if from_node in edge_dict: edge_dict[from_node.node_name] = to_node.node_name
edge_dict[from_node].append(to_node)
else:
edge_dict[from_node] = [to_node]
return edge_dict return edge_dict
def execute(self, initial_state: dict) -> Tuple[dict, list]: def execute(self, initial_state: dict) -> Tuple[dict, list]:
""" """
Executes the graph by traversing nodes in breadth-first order starting from the entry point. Executes the graph by traversing nodes starting from the entry point. The execution
The execution follows the edges based on the result of each node's execution and continues until follows the edges based on the result of each node's execution and continues until
it reaches a node with no outgoing edges. it reaches a node with no outgoing edges.
Args: Args:
@ -90,6 +84,7 @@ class BaseGraph:
Tuple[dict, list]: A tuple containing the final state and a list of execution info. Tuple[dict, list]: A tuple containing the final state and a list of execution info.
""" """
current_node_name = self.nodes[0]
state = initial_state state = initial_state
# variables for tracking execution info # variables for tracking execution info
@ -103,22 +98,23 @@ class BaseGraph:
"total_cost_USD": 0.0, "total_cost_USD": 0.0,
} }
queue = deque([self.entry_point]) for index in self.nodes:
while queue:
current_node = queue.popleft()
curr_time = time.time() curr_time = time.time()
with get_openai_callback() as callback: current_node = index
with get_openai_callback() as cb:
result = current_node.execute(state) result = current_node.execute(state)
node_exec_time = time.time() - curr_time node_exec_time = time.time() - curr_time
total_exec_time += node_exec_time total_exec_time += node_exec_time
cb = { cb = {
"node_name": current_node.node_name, "node_name": index.node_name,
"total_tokens": callback.total_tokens, "total_tokens": cb.total_tokens,
"prompt_tokens": callback.prompt_tokens, "prompt_tokens": cb.prompt_tokens,
"completion_tokens": callback.completion_tokens, "completion_tokens": cb.completion_tokens,
"successful_requests": callback.successful_requests, "successful_requests": cb.successful_requests,
"total_cost_USD": callback.total_cost, "total_cost_USD": cb.total_cost,
"exec_time": node_exec_time, "exec_time": node_exec_time,
} }
@ -132,31 +128,21 @@ class BaseGraph:
cb_total["successful_requests"] += cb["successful_requests"] cb_total["successful_requests"] += cb["successful_requests"]
cb_total["total_cost_USD"] += cb["total_cost_USD"] cb_total["total_cost_USD"] += cb["total_cost_USD"]
if current_node.node_type == "conditional_node":
current_node_name = result
current_node_connections = self.edges[current_node] elif current_node_name in self.edges:
if current_node.node_type == 'conditional_node': current_node_name = self.edges[current_node_name]
# Assert that there are exactly two out edges from the conditional node else:
if len(current_node_connections) != 2: current_node_name = None
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
if result["next_node"] == 0:
queue.append(current_node_connections[0])
else:
queue.append(current_node_connections[1])
# remove the conditional node result
del result["next_node"]
else:
queue.extend(node for node in current_node_connections)
exec_info.append({
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
exec_info.append({ return state, exec_info
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
return state, exec_info

View File

@ -12,7 +12,7 @@ from langchain_core.output_parsers import JsonOutputParser
# Imports from the library # Imports from the library
from .base_node import BaseNode from .base_node import BaseNode
from ..utils import create_graph, add_customizations, create_interactive_graph
class KnowledgeGraphNode(BaseNode): class KnowledgeGraphNode(BaseNode):
""" """
@ -65,31 +65,36 @@ class KnowledgeGraphNode(BaseNode):
user_prompt = input_data[0] user_prompt = input_data[0]
answer_dict = input_data[1] answer_dict = input_data[1]
output_parser = JsonOutputParser() # Build the graph
format_instructions = output_parser.get_format_instructions() graph = create_graph(answer_dict)
# Create the interactive graph
create_interactive_graph(graph, output_file='knowledge_graph.html')
template_merge = """ # output_parser = JsonOutputParser()
You are a website scraper and you have just scraped some content from multiple websites.\n # format_instructions = output_parser.get_format_instructions()
You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
OUTPUT INSTRUCTIONS: {format_instructions}\n
USER PROMPT: {user_prompt}\n
WEBSITE CONTENT: {website_content}
"""
prompt_template = PromptTemplate( # template_merge = """
template=template_merge, # You are a website scraper and you have just scraped some content from multiple websites.\n
input_variables=["user_prompt"], # You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
partial_variables={ # You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
"format_instructions": format_instructions, # The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
"website_content": answers_str, # OUTPUT INSTRUCTIONS: {format_instructions}\n
}, # USER PROMPT: {user_prompt}\n
) # WEBSITE CONTENT: {website_content}
# """
merge_chain = prompt_template | self.llm_model | output_parser # prompt_template = PromptTemplate(
answer = merge_chain.invoke({"user_prompt": user_prompt}) # template=template_merge,
# input_variables=["user_prompt"],
# partial_variables={
# "format_instructions": format_instructions,
# "website_content": answers_str,
# },
# )
# merge_chain = prompt_template | self.llm_model | output_parser
# answer = merge_chain.invoke({"user_prompt": user_prompt})
# Update the state with the generated answer # Update the state with the generated answer
state.update({self.output[0]: answer}) state.update({self.output[0]: graph})
return state return state

View File

@ -9,3 +9,4 @@ from .proxy_rotation import Proxy, parse_or_search_proxy, search_proxy_servers
from .save_audio_from_bytes import save_audio_from_bytes from .save_audio_from_bytes import save_audio_from_bytes
from .sys_dynamic_import import dynamic_import, srcfile_import from .sys_dynamic_import import dynamic_import, srcfile_import
from .cleanup_html import cleanup_html from .cleanup_html import cleanup_html
from .knowledge_graph import create_graph, add_customizations, create_interactive_graph

View File

@ -0,0 +1,81 @@
import networkx as nx
from pyvis.network import Network
import webbrowser
import os
# Create and visualize graph
def create_graph(job_postings):
graph = nx.DiGraph()
# Add the main "Job Postings" node
graph.add_node("Job Postings")
for company, jobs in job_postings["Job Postings"].items():
# Add company node
graph.add_node(company)
graph.add_edge("Job Postings", company)
# Add job nodes and their details
for idx, job in enumerate(jobs, start=1):
job_id = f"{company}-Job{idx}"
graph.add_node(job_id)
graph.add_edge(company, job_id)
for key, value in job.items():
if isinstance(value, list):
list_node_id = f"{job_id}-{key}"
graph.add_node(list_node_id, label=key)
graph.add_edge(job_id, list_node_id)
for item in value:
detail_id = f"{list_node_id}-{item}"
graph.add_node(detail_id, label=item, title=item)
graph.add_edge(list_node_id, detail_id)
else:
detail_id = f"{job_id}-{key}"
graph.add_node(detail_id, label=key, title=f"{key}: {value}")
graph.add_edge(job_id, detail_id)
return graph
# Add customizations to the network
def add_customizations(net, graph):
node_colors = {}
node_sizes = {}
# Custom colors and sizes for nodes
node_colors["Job Postings"] = '#8470FF'
node_sizes["Job Postings"] = 50
for node in graph.nodes:
if node in node_colors:
continue
if '-' not in node: # Company nodes
node_colors[node] = '#3CB371'
node_sizes[node] = 30
elif '-' in node and node.count('-') == 1: # Job nodes
node_colors[node] = '#FFA07A'
node_sizes[node] = 20
else: # Job detail nodes
node_colors[node] = '#B0C4DE'
node_sizes[node] = 10
# Add nodes and edges to the network with customized styles
for node in graph.nodes:
net.add_node(node,
label=graph.nodes[node].get('label', node.split('-')[-1]),
color=node_colors.get(node, 'lightgray'),
size=node_sizes.get(node, 15),
title=graph.nodes[node].get('title', ''))
for edge in graph.edges:
net.add_edge(edge[0], edge[1])
return net
# Create interactive graph
def create_interactive_graph(graph, output_file='interactive_graph.html'):
net = Network(notebook=False, height='1000px', width='100%', bgcolor='white', font_color='black')
net = add_customizations(net, graph)
net.save_graph(output_file)
# Automatically open the generated HTML file in the default web browser
webbrowser.open(f"file://{os.path.realpath(output_file)}")