diff --git a/.gitignore b/.gitignore index b8ab5703..8b9c55dd 100644 --- a/.gitignore +++ b/.gitignore @@ -32,5 +32,6 @@ examples/graph_examples/ScrapeGraphAI_generated_graph examples/**/result.csv examples/**/result.json main.py - +lib/ +*.html \ No newline at end of file diff --git a/examples/single_node/kg_node.py b/examples/single_node/kg_node.py index d434b6af..a25d8eda 100644 --- a/examples/single_node/kg_node.py +++ b/examples/single_node/kg_node.py @@ -50,6 +50,7 @@ graph_config = { "model": "gpt-4o", "temperature": 0, }, + "verbose": True, } # ************************************************ @@ -59,11 +60,9 @@ graph_config = { llm_model = OpenAI(graph_config["llm"]) robots_node = KnowledgeGraphNode( - input="answer & user_prompt", + input="user_prompt & answer_dict", output=["is_scrapable"], - node_config={"llm_model": llm_model, - "headless": False - } + node_config={"llm_model": llm_model} ) # ************************************************ @@ -71,7 +70,8 @@ robots_node = KnowledgeGraphNode( # ************************************************ state = { - "url": "https://twitter.com/home" + "user_prompt": "What are the job postings?", + "answer_dict": job_postings } result = robots_node.execute(state) diff --git a/pyproject.toml b/pyproject.toml index 19c714e8..e49c6a63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ dependencies = [ "playwright==1.43.0", "google==3.0.0", "yahoo-search-py==0.3", + "networkx==3.3", + "pyvis==0.3.2", ] license = "MIT" diff --git a/requirements-dev.lock b/requirements-dev.lock index 18155637..84a8a445 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -22,6 +22,8 @@ anyio==4.3.0 # via groq # via httpx # via openai +asttokens==2.4.1 + # via stack-data async-timeout==4.0.3 # via aiohttp # via langchain @@ -43,9 +45,15 @@ certifi==2024.2.2 # via requests charset-normalizer==3.3.2 # via requests +colorama==0.4.6 + # via ipython + # via pytest + # via tqdm dataclasses-json==0.6.6 # via langchain # via langchain-community +decorator==5.1.1 + # via ipython defusedxml==0.7.1 # via langchain-anthropic distro==1.9.0 @@ -54,7 +62,10 @@ distro==1.9.0 # via openai exceptiongroup==1.2.1 # via anyio + # via ipython # via pytest +executing==2.0.1 + # via stack-data faiss-cpu==1.8.0 # via scrapegraphai filelock==3.14.0 @@ -93,6 +104,7 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright + # via sqlalchemy groq==0.5.0 # via langchain-groq grpcio==1.63.0 @@ -123,12 +135,20 @@ idna==3.7 # via yarl iniconfig==2.0.0 # via pytest +ipython==8.24.0 + # via pyvis +jedi==0.19.1 + # via ipython +jinja2==3.1.4 + # via pyvis jmespath==1.0.1 # via boto3 # via botocore jsonpatch==1.33 # via langchain # via langchain-core +jsonpickle==3.0.4 + # via pyvis jsonpointer==2.4 # via jsonpatch langchain==0.1.15 @@ -162,8 +182,12 @@ langsmith==0.1.58 # via langchain-core lxml==5.2.2 # via free-proxy +markupsafe==2.1.5 + # via jinja2 marshmallow==3.21.2 # via dataclasses-json +matplotlib-inline==0.1.7 + # via ipython minify-html==0.15.0 # via scrapegraphai multidict==6.0.5 @@ -171,6 +195,9 @@ multidict==6.0.5 # via yarl mypy-extensions==1.0.0 # via typing-inspect +networkx==3.3 + # via pyvis + # via scrapegraphai numpy==1.26.4 # via faiss-cpu # via langchain @@ -188,10 +215,14 @@ packaging==23.2 # via pytest pandas==2.2.2 # via scrapegraphai +parso==0.8.4 + # via jedi playwright==1.43.0 # via scrapegraphai pluggy==1.5.0 # via pytest +prompt-toolkit==3.0.43 + # via ipython proto-plus==1.23.0 # via google-ai-generativelanguage # via google-api-core @@ -202,6 +233,8 @@ protobuf==4.25.3 # via googleapis-common-protos # via grpcio-status # via proto-plus +pure-eval==0.2.2 + # via stack-data pyasn1==0.6.0 # via pyasn1-modules # via rsa @@ -220,6 +253,8 @@ pydantic-core==2.18.2 # via pydantic pyee==11.1.0 # via playwright +pygments==2.18.0 + # via ipython pyparsing==3.1.2 # via httplib2 pytest==8.0.0 @@ -232,6 +267,8 @@ python-dotenv==1.0.1 # via scrapegraphai pytz==2024.1 # via pandas +pyvis==0.3.2 + # via scrapegraphai pyyaml==6.0.1 # via huggingface-hub # via langchain @@ -254,6 +291,7 @@ s3transfer==0.10.1 selectolax==0.3.21 # via yahoo-search-py six==1.16.0 + # via asttokens # via python-dateutil sniffio==1.3.1 # via anthropic @@ -266,6 +304,8 @@ soupsieve==2.5 sqlalchemy==2.0.30 # via langchain # via langchain-community +stack-data==0.6.3 + # via ipython tenacity==8.3.0 # via langchain # via langchain-community @@ -282,12 +322,16 @@ tqdm==4.66.4 # via huggingface-hub # via openai # via scrapegraphai +traitlets==5.14.3 + # via ipython + # via matplotlib-inline typing-extensions==4.11.0 # via anthropic # via anyio # via google-generativeai # via groq # via huggingface-hub + # via ipython # via openai # via pydantic # via pydantic-core @@ -304,6 +348,8 @@ urllib3==2.2.1 # via botocore # via requests # via yahoo-search-py +wcwidth==0.2.13 + # via prompt-toolkit yahoo-search-py==0.3 # via scrapegraphai yarl==1.9.4 diff --git a/requirements.lock b/requirements.lock index f6381059..f33598cf 100644 --- a/requirements.lock +++ b/requirements.lock @@ -22,6 +22,8 @@ anyio==4.3.0 # via groq # via httpx # via openai +asttokens==2.4.1 + # via stack-data async-timeout==4.0.3 # via aiohttp # via langchain @@ -43,9 +45,14 @@ certifi==2024.2.2 # via requests charset-normalizer==3.3.2 # via requests +colorama==0.4.6 + # via ipython + # via tqdm dataclasses-json==0.6.6 # via langchain # via langchain-community +decorator==5.1.1 + # via ipython defusedxml==0.7.1 # via langchain-anthropic distro==1.9.0 @@ -54,6 +61,9 @@ distro==1.9.0 # via openai exceptiongroup==1.2.1 # via anyio + # via ipython +executing==2.0.1 + # via stack-data faiss-cpu==1.8.0 # via scrapegraphai filelock==3.14.0 @@ -92,6 +102,7 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright + # via sqlalchemy groq==0.5.0 # via langchain-groq grpcio==1.63.0 @@ -120,12 +131,20 @@ idna==3.7 # via httpx # via requests # via yarl +ipython==8.24.0 + # via pyvis +jedi==0.19.1 + # via ipython +jinja2==3.1.4 + # via pyvis jmespath==1.0.1 # via boto3 # via botocore jsonpatch==1.33 # via langchain # via langchain-core +jsonpickle==3.0.4 + # via pyvis jsonpointer==2.4 # via jsonpatch langchain==0.1.15 @@ -159,8 +178,12 @@ langsmith==0.1.58 # via langchain-core lxml==5.2.2 # via free-proxy +markupsafe==2.1.5 + # via jinja2 marshmallow==3.21.2 # via dataclasses-json +matplotlib-inline==0.1.7 + # via ipython minify-html==0.15.0 # via scrapegraphai multidict==6.0.5 @@ -168,6 +191,9 @@ multidict==6.0.5 # via yarl mypy-extensions==1.0.0 # via typing-inspect +networkx==3.3 + # via pyvis + # via scrapegraphai numpy==1.26.4 # via faiss-cpu # via langchain @@ -184,8 +210,12 @@ packaging==23.2 # via marshmallow pandas==2.2.2 # via scrapegraphai +parso==0.8.4 + # via jedi playwright==1.43.0 # via scrapegraphai +prompt-toolkit==3.0.43 + # via ipython proto-plus==1.23.0 # via google-ai-generativelanguage # via google-api-core @@ -196,6 +226,8 @@ protobuf==4.25.3 # via googleapis-common-protos # via grpcio-status # via proto-plus +pure-eval==0.2.2 + # via stack-data pyasn1==0.6.0 # via pyasn1-modules # via rsa @@ -214,6 +246,8 @@ pydantic-core==2.18.2 # via pydantic pyee==11.1.0 # via playwright +pygments==2.18.0 + # via ipython pyparsing==3.1.2 # via httplib2 python-dateutil==2.9.0.post0 @@ -223,6 +257,8 @@ python-dotenv==1.0.1 # via scrapegraphai pytz==2024.1 # via pandas +pyvis==0.3.2 + # via scrapegraphai pyyaml==6.0.1 # via huggingface-hub # via langchain @@ -245,6 +281,7 @@ s3transfer==0.10.1 selectolax==0.3.21 # via yahoo-search-py six==1.16.0 + # via asttokens # via python-dateutil sniffio==1.3.1 # via anthropic @@ -257,6 +294,8 @@ soupsieve==2.5 sqlalchemy==2.0.30 # via langchain # via langchain-community +stack-data==0.6.3 + # via ipython tenacity==8.3.0 # via langchain # via langchain-community @@ -271,12 +310,16 @@ tqdm==4.66.4 # via huggingface-hub # via openai # via scrapegraphai +traitlets==5.14.3 + # via ipython + # via matplotlib-inline typing-extensions==4.11.0 # via anthropic # via anyio # via google-generativeai # via groq # via huggingface-hub + # via ipython # via openai # via pydantic # via pydantic-core @@ -293,6 +336,8 @@ urllib3==2.2.1 # via botocore # via requests # via yahoo-search-py +wcwidth==0.2.13 + # via prompt-toolkit yahoo-search-py==0.3 # via scrapegraphai yarl==1.9.4 diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index ed5ba54f..7c4df3d8 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -6,7 +6,6 @@ import time import warnings from langchain_community.callbacks import get_openai_callback from typing import Tuple -from collections import deque class BaseGraph: @@ -27,8 +26,6 @@ class BaseGraph: Raises: Warning: If the entry point node is not the first node in the list. - ValueError: If conditional_node does not have exactly two outgoing edges - Example: >>> BaseGraph( @@ -51,7 +48,7 @@ class BaseGraph: self.nodes = nodes self.edges = self._create_edges({e for e in edges}) - self.entry_point = entry_point + self.entry_point = entry_point.node_name if nodes[0].node_name != entry_point.node_name: # raise a warning if the entry point is not the first node in the list @@ -71,16 +68,13 @@ class BaseGraph: edge_dict = {} for from_node, to_node in edges: - if from_node in edge_dict: - edge_dict[from_node].append(to_node) - else: - edge_dict[from_node] = [to_node] + edge_dict[from_node.node_name] = to_node.node_name return edge_dict def execute(self, initial_state: dict) -> Tuple[dict, list]: """ - Executes the graph by traversing nodes in breadth-first order starting from the entry point. - The execution follows the edges based on the result of each node's execution and continues until + Executes the graph by traversing nodes starting from the entry point. The execution + follows the edges based on the result of each node's execution and continues until it reaches a node with no outgoing edges. Args: @@ -90,6 +84,7 @@ class BaseGraph: Tuple[dict, list]: A tuple containing the final state and a list of execution info. """ + current_node_name = self.nodes[0] state = initial_state # variables for tracking execution info @@ -103,22 +98,23 @@ class BaseGraph: "total_cost_USD": 0.0, } - queue = deque([self.entry_point]) - while queue: - current_node = queue.popleft() + for index in self.nodes: + curr_time = time.time() - with get_openai_callback() as callback: + current_node = index + + with get_openai_callback() as cb: result = current_node.execute(state) node_exec_time = time.time() - curr_time total_exec_time += node_exec_time cb = { - "node_name": current_node.node_name, - "total_tokens": callback.total_tokens, - "prompt_tokens": callback.prompt_tokens, - "completion_tokens": callback.completion_tokens, - "successful_requests": callback.successful_requests, - "total_cost_USD": callback.total_cost, + "node_name": index.node_name, + "total_tokens": cb.total_tokens, + "prompt_tokens": cb.prompt_tokens, + "completion_tokens": cb.completion_tokens, + "successful_requests": cb.successful_requests, + "total_cost_USD": cb.total_cost, "exec_time": node_exec_time, } @@ -132,31 +128,21 @@ class BaseGraph: cb_total["successful_requests"] += cb["successful_requests"] cb_total["total_cost_USD"] += cb["total_cost_USD"] - - - current_node_connections = self.edges[current_node] - if current_node.node_type == 'conditional_node': - # Assert that there are exactly two out edges from the conditional node - if len(current_node_connections) != 2: - raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}") - if result["next_node"] == 0: - queue.append(current_node_connections[0]) - else: - queue.append(current_node_connections[1]) - # remove the conditional node result - del result["next_node"] - else: - queue.extend(node for node in current_node_connections) + if current_node.node_type == "conditional_node": + current_node_name = result + elif current_node_name in self.edges: + current_node_name = self.edges[current_node_name] + else: + current_node_name = None + exec_info.append({ + "node_name": "TOTAL RESULT", + "total_tokens": cb_total["total_tokens"], + "prompt_tokens": cb_total["prompt_tokens"], + "completion_tokens": cb_total["completion_tokens"], + "successful_requests": cb_total["successful_requests"], + "total_cost_USD": cb_total["total_cost_USD"], + "exec_time": total_exec_time, + }) - exec_info.append({ - "node_name": "TOTAL RESULT", - "total_tokens": cb_total["total_tokens"], - "prompt_tokens": cb_total["prompt_tokens"], - "completion_tokens": cb_total["completion_tokens"], - "successful_requests": cb_total["successful_requests"], - "total_cost_USD": cb_total["total_cost_USD"], - "exec_time": total_exec_time, - }) - - return state, exec_info + return state, exec_info \ No newline at end of file diff --git a/scrapegraphai/nodes/knowledge_graph_node.py b/scrapegraphai/nodes/knowledge_graph_node.py index 9181ca80..8f040b5e 100644 --- a/scrapegraphai/nodes/knowledge_graph_node.py +++ b/scrapegraphai/nodes/knowledge_graph_node.py @@ -12,7 +12,7 @@ from langchain_core.output_parsers import JsonOutputParser # Imports from the library from .base_node import BaseNode - +from ..utils import create_graph, add_customizations, create_interactive_graph class KnowledgeGraphNode(BaseNode): """ @@ -65,31 +65,36 @@ class KnowledgeGraphNode(BaseNode): user_prompt = input_data[0] answer_dict = input_data[1] - output_parser = JsonOutputParser() - format_instructions = output_parser.get_format_instructions() + # Build the graph + graph = create_graph(answer_dict) + # Create the interactive graph + create_interactive_graph(graph, output_file='knowledge_graph.html') - template_merge = """ - You are a website scraper and you have just scraped some content from multiple websites.\n - You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n - You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n - The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n - OUTPUT INSTRUCTIONS: {format_instructions}\n - USER PROMPT: {user_prompt}\n - WEBSITE CONTENT: {website_content} - """ + # output_parser = JsonOutputParser() + # format_instructions = output_parser.get_format_instructions() - prompt_template = PromptTemplate( - template=template_merge, - input_variables=["user_prompt"], - partial_variables={ - "format_instructions": format_instructions, - "website_content": answers_str, - }, - ) + # template_merge = """ + # You are a website scraper and you have just scraped some content from multiple websites.\n + # You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n + # You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n + # The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n + # OUTPUT INSTRUCTIONS: {format_instructions}\n + # USER PROMPT: {user_prompt}\n + # WEBSITE CONTENT: {website_content} + # """ - merge_chain = prompt_template | self.llm_model | output_parser - answer = merge_chain.invoke({"user_prompt": user_prompt}) + # prompt_template = PromptTemplate( + # template=template_merge, + # input_variables=["user_prompt"], + # partial_variables={ + # "format_instructions": format_instructions, + # "website_content": answers_str, + # }, + # ) + + # merge_chain = prompt_template | self.llm_model | output_parser + # answer = merge_chain.invoke({"user_prompt": user_prompt}) # Update the state with the generated answer - state.update({self.output[0]: answer}) + state.update({self.output[0]: graph}) return state diff --git a/scrapegraphai/utils/__init__.py b/scrapegraphai/utils/__init__.py index 72a8b96c..eced80ea 100644 --- a/scrapegraphai/utils/__init__.py +++ b/scrapegraphai/utils/__init__.py @@ -9,3 +9,4 @@ from .proxy_rotation import Proxy, parse_or_search_proxy, search_proxy_servers from .save_audio_from_bytes import save_audio_from_bytes from .sys_dynamic_import import dynamic_import, srcfile_import from .cleanup_html import cleanup_html +from .knowledge_graph import create_graph, add_customizations, create_interactive_graph \ No newline at end of file diff --git a/scrapegraphai/utils/knowledge_graph.py b/scrapegraphai/utils/knowledge_graph.py new file mode 100644 index 00000000..1b6682aa --- /dev/null +++ b/scrapegraphai/utils/knowledge_graph.py @@ -0,0 +1,81 @@ +import networkx as nx +from pyvis.network import Network +import webbrowser +import os + +# Create and visualize graph +def create_graph(job_postings): + graph = nx.DiGraph() + + # Add the main "Job Postings" node + graph.add_node("Job Postings") + + for company, jobs in job_postings["Job Postings"].items(): + # Add company node + graph.add_node(company) + graph.add_edge("Job Postings", company) + + # Add job nodes and their details + for idx, job in enumerate(jobs, start=1): + job_id = f"{company}-Job{idx}" + graph.add_node(job_id) + graph.add_edge(company, job_id) + + for key, value in job.items(): + if isinstance(value, list): + list_node_id = f"{job_id}-{key}" + graph.add_node(list_node_id, label=key) + graph.add_edge(job_id, list_node_id) + for item in value: + detail_id = f"{list_node_id}-{item}" + graph.add_node(detail_id, label=item, title=item) + graph.add_edge(list_node_id, detail_id) + else: + detail_id = f"{job_id}-{key}" + graph.add_node(detail_id, label=key, title=f"{key}: {value}") + graph.add_edge(job_id, detail_id) + + return graph + +# Add customizations to the network +def add_customizations(net, graph): + node_colors = {} + node_sizes = {} + + # Custom colors and sizes for nodes + node_colors["Job Postings"] = '#8470FF' + node_sizes["Job Postings"] = 50 + + for node in graph.nodes: + if node in node_colors: + continue + if '-' not in node: # Company nodes + node_colors[node] = '#3CB371' + node_sizes[node] = 30 + elif '-' in node and node.count('-') == 1: # Job nodes + node_colors[node] = '#FFA07A' + node_sizes[node] = 20 + else: # Job detail nodes + node_colors[node] = '#B0C4DE' + node_sizes[node] = 10 + + # Add nodes and edges to the network with customized styles + for node in graph.nodes: + net.add_node(node, + label=graph.nodes[node].get('label', node.split('-')[-1]), + color=node_colors.get(node, 'lightgray'), + size=node_sizes.get(node, 15), + title=graph.nodes[node].get('title', '')) + for edge in graph.edges: + net.add_edge(edge[0], edge[1]) + return net + +# Create interactive graph +def create_interactive_graph(graph, output_file='interactive_graph.html'): + net = Network(notebook=False, height='1000px', width='100%', bgcolor='white', font_color='black') + net = add_customizations(net, graph) + net.save_graph(output_file) + + # Automatically open the generated HTML file in the default web browser + webbrowser.open(f"file://{os.path.realpath(output_file)}") +