feat: multiple graph instances

This commit is contained in:
Marco Perini 2024-05-05 23:51:04 +02:00
parent 1c4ba91620
commit dbb614a8dd
16 changed files with 248 additions and 63 deletions

View File

@ -6,8 +6,8 @@ import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode, SearchInternetNode
from scrapegraphai.graphs import BaseGraph, SmartScraperGraph
from scrapegraphai.nodes import SearchInternetNode, GraphIteratorNode, MergeAnswersNode
load_dotenv()
# ************************************************
@ -23,6 +23,16 @@ graph_config = {
},
}
# ************************************************
# Create a SmartScraperGraph instance
# ************************************************
smart_scraper_graph = SmartScraperGraph(
prompt="",
source="",
config=graph_config
)
# ************************************************
# Define the graph nodes
# ************************************************
@ -32,38 +42,24 @@ embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)
search_internet_node = SearchInternetNode(
input="user_prompt",
output=["url"],
node_config={
"llm_model": llm_model
}
)
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={
"verbose": True,
"headless": True,
}
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": 4096,
"verbose": True,
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
output=["urls"],
node_config={
"llm_model": llm_model,
"embedder_model": embedder,
"verbose": True,
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
graph_iterator_node = GraphIteratorNode(
input="user_prompt & urls",
output=["results"],
node_config={
"graph_instance": smart_scraper_graph,
"verbose": True,
}
)
merge_answers_node = MergeAnswersNode(
input="user_prompt & results",
output=["answer"],
node_config={
"llm_model": llm_model,
@ -78,16 +74,12 @@ generate_answer_node = GenerateAnswerNode(
graph = BaseGraph(
nodes=[
search_internet_node,
fetch_node,
parse_node,
rag_node,
generate_answer_node,
graph_iterator_node,
merge_answers_node
],
edges=[
(search_internet_node, fetch_node),
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
(search_internet_node, graph_iterator_node),
(graph_iterator_node, merge_answers_node)
],
entry_point=search_internet_node
)

View File

@ -17,3 +17,5 @@ from .search_link_node import SearchLinkNode
from .robots_node import RobotsNode
from .generate_answer_csv_node import GenerateAnswerCSVNode
from .generate_answer_pdf_node import GenerateAnswerPDFNode
from .graph_iterator_node import GraphIteratorNode
from .merge_answers_node import MergeAnswersNode

View File

@ -2,7 +2,7 @@
Module for generating the answer node
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm
# Imports from Langchain
@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""
def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
node_name: str = "GenerateAnswer"):
"""
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.

View File

@ -3,7 +3,7 @@ GenerateAnswerNode Module
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm
# Imports from Langchain
@ -33,7 +33,7 @@ class GenerateAnswerNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
"""
def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
node_name: str = "GenerateAnswer"):
super().__init__(node_name, "node", input, output, 2, node_config)

View File

@ -2,7 +2,7 @@
Module for generating the answer node
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm
# Imports from Langchain
@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""
def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
node_name: str = "GenerateAnswer"):
"""
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.

View File

@ -2,7 +2,7 @@
Module for generating the answer node
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm
# Imports from Langchain
@ -39,7 +39,7 @@ class GenerateAnswerPDFNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""
def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
node_name: str = "GenerateAnswer"):
"""
Initializes the GenerateAnswerNodePDF with a language model client and a node name.

View File

@ -3,7 +3,7 @@ GenerateScraperNode Module
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm
# Imports from Langchain
@ -36,8 +36,8 @@ class GenerateScraperNode(BaseNode):
"""
def __init__(self, input: str, output: List[str], node_config: dict,
library: str, website: str, node_name: str = "GenerateAnswer"):
def __init__(self, input: str, output: List[str], library: str, website: str,
node_config: Optional[dict]=None, node_name: str = "GenerateAnswer"):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]

View File

@ -2,7 +2,7 @@
GetProbableTagsNode Module
"""
from typing import List
from typing import List, Optional
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
from .base_node import BaseNode

View File

@ -0,0 +1,83 @@
"""
GraphIterator Module
"""
from typing import List, Optional
import copy
from tqdm import tqdm
from .base_node import BaseNode
class GraphIteratorNode(BaseNode):
    """
    A node that runs one copy of a pre-configured graph for every URL found in the state
    and collects the individual results.

    The graph to run is read from ``node_config["graph_instance"]``. For each URL a shallow
    copy of that graph gets its ``source`` set to the URL and is executed with ``run()``.
    The list of results is stored in the state under the first output key.

    Attributes:
        verbose (bool): A flag indicating whether to show print statements during execution.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Additional configuration for the node; must provide
            "graph_instance" by the time ``execute`` is called.
        node_name (str): The unique identifier name for the node, defaulting to "GraphIterator".
    """

    def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "GraphIterator"):
        super().__init__(node_name, "node", input, output, 2, node_config)

        self.verbose = False if node_config is None else node_config.get("verbose", False)

    def execute(self, state: dict) -> dict:
        """
        Runs a copy of the configured graph for each URL and stores the list of results.

        Args:
            state (dict): The current state of the graph. The input keys will be used to fetch
                the user prompt (first key) and the URLs (second key) from the state.

        Returns:
            dict: The updated state with the first output key holding one result per URL,
                in the same order as the URLs.

        Raises:
            KeyError: If the input keys are not found in the state.
            ValueError: If no "graph_instance" is available in the node configuration.
        """

        if self.verbose:
            print(f"--- Executing {self.node_name} Node ---")

        # Interpret input keys based on the provided input expression
        input_keys = self.get_input_keys(state)

        # Fetching data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

        user_prompt = input_data[0]
        urls = input_data[1]

        # node_config defaults to None, so guard the lookup: a missing configuration must
        # surface as the explicit ValueError below rather than an AttributeError.
        graph_instance = (self.node_config or {}).get("graph_instance")
        if graph_instance is None:
            raise ValueError("Graph instance is required for graph iteration.")

        # A bare string would otherwise be iterated character by character.
        if isinstance(urls, str):
            urls = [urls]

        # The prompt is set once on the template graph so every copy inherits it.
        graph_instance.prompt = user_prompt

        graphs_instances = []
        for url in urls:
            # copy.copy is shallow: each copy gets its own `source` attribute, but any
            # nested objects held by the graph are shared between the copies.
            # NOTE(review): confirm the graph derives nothing from `source` at construction
            # time that would need refreshing after it is reassigned here.
            copy_graph_instance = copy.copy(graph_instance)
            copy_graph_instance.source = url
            graphs_instances.append(copy_graph_instance)

        # Run the graphs sequentially; the progress bar is only shown in verbose mode.
        graphs_answers = [
            graph.run()
            for graph in tqdm(graphs_instances, desc="Processing Graph Instances", disable=not self.verbose)
        ]

        state.update({self.output[0]: graphs_answers})
        return state

View File

@ -2,7 +2,7 @@
ImageToTextNode Module
"""
from typing import List
from typing import List, Optional
from .base_node import BaseNode
@ -21,7 +21,7 @@ class ImageToTextNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "ImageToText".
"""
def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
node_name: str = "ImageToText"):
super().__init__(node_name, "node", input, output, 1, node_config)

View File

@ -0,0 +1,104 @@
"""
MergeAnswersNode Module
"""
# Imports from standard library
from typing import List, Optional
from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
# Imports from the library
from .base_node import BaseNode
class MergeAnswersNode(BaseNode):
    """
    A node that merges the answers obtained from several websites into a single answer
    using a large language model (LLM). It concatenates the individual answers into one
    prompt together with the user's request, feeds it to the LLM, and parses the LLM's
    response as JSON.

    Attributes:
        llm_model: An instance of a language model client, configured for generating answers.
        verbose (bool): A flag indicating whether to show print statements during execution.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Additional configuration for the node; must contain "llm_model".
        node_name (str): The unique identifier name for the node, defaulting to "MergeAnswers".

    Raises:
        ValueError: If node_config is not provided.
        KeyError: If node_config has no "llm_model" entry.
    """

    def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
                 node_name: str = "MergeAnswers"):
        super().__init__(node_name, "node", input, output, 2, node_config)

        # The signature allows None, but the node cannot work without a model: fail with
        # a clear message instead of "'NoneType' object is not subscriptable".
        if node_config is None:
            raise ValueError("MergeAnswersNode requires a node_config containing 'llm_model'.")

        self.llm_model = node_config["llm_model"]
        self.verbose = node_config.get("verbose", False)

    def execute(self, state: dict) -> dict:
        """
        Merges the per-website answers into one answer by building a prompt from the
        user's input and the collected answers, querying the language model, and parsing
        its response.

        Args:
            state (dict): The current state of the graph. The input keys will be used
                to fetch the user prompt (first key) and the answers (second key).

        Returns:
            dict: The updated state with the first output key containing the merged answer.

        Raises:
            KeyError: If the input keys are not found in the state, indicating
                that the necessary information for generating an answer is missing.
        """

        if self.verbose:
            print(f"--- Executing {self.node_name} Node ---")

        # Interpret input keys based on the provided input expression
        input_keys = self.get_input_keys(state)

        # Fetching data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

        user_prompt = input_data[0]
        answers = input_data[1]

        # Label each answer with a 1-based website index and join them into one string
        answers_str = "".join(
            f"CONTENT WEBSITE {index}: {answer}\n"
            for index, answer in enumerate(answers, start=1)
        )

        output_parser = JsonOutputParser()
        format_instructions = output_parser.get_format_instructions()

        # The literal's lines are kept flush left so no indentation leaks into the prompt.
        template_merge = """
You are a website scraper and you have just scraped some content from multiple websites.\n
You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
OUTPUT INSTRUCTIONS: {format_instructions}\n
USER PROMPT: {user_prompt}\n
{website_content}
"""

        # The format instructions and the scraped content are bound up front; only the
        # user prompt is supplied when the chain is invoked.
        prompt_template = PromptTemplate(
            template=template_merge,
            input_variables=["user_prompt"],
            partial_variables={
                "format_instructions": format_instructions,
                "website_content": answers_str,
            },
        )

        merge_chain = prompt_template | self.llm_model | output_parser
        answer = merge_chain.invoke({"user_prompt": user_prompt})

        # Update the state with the generated answer
        state.update({self.output[0]: answer})
        return state

View File

@ -2,7 +2,7 @@
ParseNode Module
"""
from typing import List
from typing import List, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_transformers import Html2TextTransformer
from .base_node import BaseNode
@ -26,7 +26,7 @@ class ParseNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "Parse".
"""
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "Parse"):
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "Parse"):
super().__init__(node_name, "node", input, output, 1, node_config)
self.verbose = True if node_config is None else node_config.get("verbose", False)

View File

@ -2,7 +2,7 @@
RAGNode Module
"""
from typing import List
from typing import List, Optional
from langchain.docstore.document import Document
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
@ -31,7 +31,7 @@ class RAGNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "Parse".
"""
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "RAG"):
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "RAG"):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]

View File

@ -2,7 +2,7 @@
SearchInternetNode Module
"""
from typing import List
from typing import List, Optional
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
from ..utils.research_web import search_on_web
@ -27,12 +27,13 @@ class SearchInternetNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "SearchInternet".
"""
def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
node_name: str = "SearchInternet"):
super().__init__(node_name, "node", input, output, 1, node_config)
self.llm_model = node_config["llm_model"]
self.verbose = True if node_config is None else node_config.get("verbose", False)
self.max_results = node_config.get("max_results", 3)
def execute(self, state: dict) -> dict:
"""
@ -85,8 +86,11 @@ class SearchInternetNode(BaseNode):
if self.verbose:
print(f"Search Query: {search_query}")
# TODO: handle multiple URLs
answer = search_on_web(query=search_query, max_results=1)[0]
answer = search_on_web(query=search_query, max_results=self.max_results)
if len(answer) == 0:
# raise an exception if no answer is found
raise ValueError("Zero results found for the search query.")
# Update the state with the generated answer
state.update({self.output[0]: answer})

View File

@ -3,7 +3,7 @@ SearchLinkNode Module
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm
from bs4 import BeautifulSoup
@ -33,7 +33,7 @@ class SearchLinkNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
"""
def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
node_name: str = "GenerateLinks"):
super().__init__(node_name, "node", input, output, 1, node_config)

View File

@ -2,7 +2,7 @@
TextToSpeechNode Module
"""
from typing import List
from typing import List, Optional
from .base_node import BaseNode
@ -22,7 +22,7 @@ class TextToSpeechNode(BaseNode):
"""
def __init__(self, input: str, output: List[str],
node_config: dict, node_name: str = "TextToSpeech"):
node_config: Optional[dict]=None, node_name: str = "TextToSpeech"):
super().__init__(node_name, "node", input, output, 1, node_config)
self.tts_model = node_config["tts_model"]