mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
feat: multiple graph instances
This commit is contained in:
parent
1c4ba91620
commit
dbb614a8dd
@ -6,8 +6,8 @@ import os
|
||||
from dotenv import load_dotenv
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from scrapegraphai.models import OpenAI
|
||||
from scrapegraphai.graphs import BaseGraph
|
||||
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode, SearchInternetNode
|
||||
from scrapegraphai.graphs import BaseGraph, SmartScraperGraph
|
||||
from scrapegraphai.nodes import SearchInternetNode, GraphIteratorNode, MergeAnswersNode
|
||||
load_dotenv()
|
||||
|
||||
# ************************************************
|
||||
@ -23,6 +23,16 @@ graph_config = {
|
||||
},
|
||||
}
|
||||
|
||||
# ************************************************
|
||||
# Create a SmartScraperGraph instance
|
||||
# ************************************************
|
||||
|
||||
smart_scraper_graph = SmartScraperGraph(
|
||||
prompt="",
|
||||
source="",
|
||||
config=graph_config
|
||||
)
|
||||
|
||||
# ************************************************
|
||||
# Define the graph nodes
|
||||
# ************************************************
|
||||
@ -32,38 +42,24 @@ embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)
|
||||
|
||||
search_internet_node = SearchInternetNode(
|
||||
input="user_prompt",
|
||||
output=["url"],
|
||||
node_config={
|
||||
"llm_model": llm_model
|
||||
}
|
||||
)
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
node_config={
|
||||
"verbose": True,
|
||||
"headless": True,
|
||||
}
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": 4096,
|
||||
"verbose": True,
|
||||
}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
output=["urls"],
|
||||
node_config={
|
||||
"llm_model": llm_model,
|
||||
"embedder_model": embedder,
|
||||
"verbose": True,
|
||||
}
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
|
||||
graph_iterator_node = GraphIteratorNode(
|
||||
input="user_prompt & urls",
|
||||
output=["results"],
|
||||
node_config={
|
||||
"graph_instance": smart_scraper_graph,
|
||||
"verbose": True,
|
||||
}
|
||||
)
|
||||
|
||||
merge_answers_node = MergeAnswersNode(
|
||||
input="user_prompt & results",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": llm_model,
|
||||
@ -78,16 +74,12 @@ generate_answer_node = GenerateAnswerNode(
|
||||
graph = BaseGraph(
|
||||
nodes=[
|
||||
search_internet_node,
|
||||
fetch_node,
|
||||
parse_node,
|
||||
rag_node,
|
||||
generate_answer_node,
|
||||
graph_iterator_node,
|
||||
merge_answers_node
|
||||
],
|
||||
edges=[
|
||||
(search_internet_node, fetch_node),
|
||||
(fetch_node, parse_node),
|
||||
(parse_node, rag_node),
|
||||
(rag_node, generate_answer_node)
|
||||
(search_internet_node, graph_iterator_node),
|
||||
(graph_iterator_node, merge_answers_node)
|
||||
],
|
||||
entry_point=search_internet_node
|
||||
)
|
||||
@ -17,3 +17,5 @@ from .search_link_node import SearchLinkNode
|
||||
from .robots_node import RobotsNode
|
||||
from .generate_answer_csv_node import GenerateAnswerCSVNode
|
||||
from .generate_answer_pdf_node import GenerateAnswerPDFNode
|
||||
from .graph_iterator_node import GraphIteratorNode
|
||||
from .merge_answers_node import MergeAnswersNode
|
||||
@ -2,7 +2,7 @@
|
||||
Module for generating the answer node
|
||||
"""
|
||||
# Imports from standard library
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from tqdm import tqdm
|
||||
|
||||
# Imports from Langchain
|
||||
@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
updating the state with the generated answer under the 'answer' key.
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
|
||||
node_name: str = "GenerateAnswer"):
|
||||
"""
|
||||
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
|
||||
|
||||
@ -3,7 +3,7 @@ GenerateAnswerNode Module
|
||||
"""
|
||||
|
||||
# Imports from standard library
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from tqdm import tqdm
|
||||
|
||||
# Imports from Langchain
|
||||
@ -33,7 +33,7 @@ class GenerateAnswerNode(BaseNode):
|
||||
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
|
||||
node_name: str = "GenerateAnswer"):
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Module for generating the answer node
|
||||
"""
|
||||
# Imports from standard library
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from tqdm import tqdm
|
||||
|
||||
# Imports from Langchain
|
||||
@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
updating the state with the generated answer under the 'answer' key.
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
|
||||
node_name: str = "GenerateAnswer"):
|
||||
"""
|
||||
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Module for generating the answer node
|
||||
"""
|
||||
# Imports from standard library
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from tqdm import tqdm
|
||||
|
||||
# Imports from Langchain
|
||||
@ -39,7 +39,7 @@ class GenerateAnswerPDFNode(BaseNode):
|
||||
updating the state with the generated answer under the 'answer' key.
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
|
||||
node_name: str = "GenerateAnswer"):
|
||||
"""
|
||||
Initializes the GenerateAnswerNodePDF with a language model client and a node name.
|
||||
|
||||
@ -3,7 +3,7 @@ GenerateScraperNode Module
|
||||
"""
|
||||
|
||||
# Imports from standard library
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from tqdm import tqdm
|
||||
|
||||
# Imports from Langchain
|
||||
@ -36,8 +36,8 @@ class GenerateScraperNode(BaseNode):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
||||
library: str, website: str, node_name: str = "GenerateAnswer"):
|
||||
def __init__(self, input: str, output: List[str], library: str, website: str,
|
||||
node_config: Optional[dict]=None, node_name: str = "GenerateAnswer"):
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
|
||||
self.llm_model = node_config["llm_model"]
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
GetProbableTagsNode Module
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from langchain.output_parsers import CommaSeparatedListOutputParser
|
||||
from langchain.prompts import PromptTemplate
|
||||
from .base_node import BaseNode
|
||||
|
||||
83
scrapegraphai/nodes/graph_iterator_node.py
Normal file
83
scrapegraphai/nodes/graph_iterator_node.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""
|
||||
GraphIterator Module
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
import copy
|
||||
from tqdm import tqdm
|
||||
from .base_node import BaseNode
|
||||
|
||||
|
||||
class GraphIteratorNode(BaseNode):
    """
    A node that runs a configured graph once per URL and collects the results.

    The graph to run is read from ``node_config["graph_instance"]``. For every
    URL taken from the state, the node makes a shallow copy of that graph, sets
    the copy's ``source`` to the URL and calls ``run()`` on it. The list of
    results is stored in the state under the first output key.

    Attributes:
        verbose (bool): A flag indicating whether to show print statements during execution.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Additional configuration for the node. ``execute`` requires
            a "graph_instance" entry; "verbose" is optional and defaults to False.
        node_name (str): The unique identifier name for the node, defaulting to "GraphIterator".
    """

    def __init__(self, input: str, output: List[str],
                 node_config: Optional[dict]=None, node_name: str = "GraphIterator"):
        super().__init__(node_name, "node", input, output, 2, node_config)

        self.verbose = False if node_config is None else node_config.get("verbose", False)

    def execute(self, state: dict) -> dict:
        """
        Runs a copy of the configured graph for each URL and stores all results.

        Args:
            state (dict): The current state of the graph. The input keys will be used to fetch
                            the correct data from the state; the first resolved key is used
                            as the user prompt and the second as the iterable of URLs.

        Returns:
            dict: The updated state, with the first output key holding the list of
                            results returned by each graph run, in URL order.

        Raises:
            KeyError: If a resolved input key is missing from the state.
            ValueError: If no "graph_instance" is available in the node configuration.
        """

        if self.verbose:
            print(f"--- Executing {self.node_name} Node ---")

        # Interpret input keys based on the provided input expression
        input_keys = self.get_input_keys(state)

        # Fetching data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

        user_prompt = input_data[0]
        urls = input_data[1]

        # node_config is optional in the constructor, so it may be None here;
        # fall back to an empty mapping so the missing graph is reported clearly.
        node_config = self.node_config or {}
        graph_instance = node_config.get("graph_instance")
        if graph_instance is None:
            raise ValueError("Graph instance is required for graph iteration.")

        # The prompt is set once on the configured instance so that every copy
        # made below carries it.
        graph_instance.prompt = user_prompt

        # NOTE(review): the copies are shallow, so any nested objects held by
        # the graph are shared between copies - confirm this is safe for the
        # graph type in use.
        graphs_instances = []
        for url in urls:
            copy_graph_instance = copy.copy(graph_instance)
            copy_graph_instance.source = url
            graphs_instances.append(copy_graph_instance)

        # Run the graph for each url; the progress bar is shown only in verbose mode
        graphs_answers = [
            graph.run()
            for graph in tqdm(graphs_instances, desc="Processing Graph Instances",
                              disable=not self.verbose)
        ]

        state.update({self.output[0]: graphs_answers})

        return state
|
||||
@ -2,7 +2,7 @@
|
||||
ImageToTextNode Module
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from .base_node import BaseNode
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ class ImageToTextNode(BaseNode):
|
||||
node_name (str): The unique identifier name for the node, defaulting to "ImageToText".
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
|
||||
node_name: str = "ImageToText"):
|
||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||
|
||||
|
||||
104
scrapegraphai/nodes/merge_answers_node.py
Normal file
104
scrapegraphai/nodes/merge_answers_node.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""
|
||||
MergeAnswersNode Module
|
||||
"""
|
||||
|
||||
# Imports from standard library
|
||||
from typing import List, Optional
|
||||
from tqdm import tqdm
|
||||
|
||||
# Imports from Langchain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
|
||||
# Imports from the library
|
||||
from .base_node import BaseNode
|
||||
|
||||
|
||||
class MergeAnswersNode(BaseNode):
    """
    A node that merges the answers obtained from several websites into a single answer
    using a large language model (LLM). It builds a prompt from the user's input and the
    individual answers, feeds it to the LLM, and parses the response as JSON.

    Attributes:
        llm_model: An instance of a language model client, configured for generating answers.
        verbose (bool): A flag indicating whether to show print statements during execution.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Additional configuration for the node. Must contain
            "llm_model"; "verbose" is optional and defaults to False.
        node_name (str): The unique identifier name for the node, defaulting to "MergeAnswers".

    Raises:
        ValueError: If node_config is None.
        KeyError: If node_config does not contain "llm_model".
    """

    def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
                 node_name: str = "MergeAnswers"):
        super().__init__(node_name, "node", input, output, 2, node_config)

        # The signature allows node_config to be omitted, but the node cannot
        # work without a language model: fail with an explicit message instead
        # of an opaque "'NoneType' object is not subscriptable".
        if node_config is None:
            raise ValueError(
                "MergeAnswersNode requires a node_config containing 'llm_model'."
            )

        self.llm_model = node_config["llm_model"]
        self.verbose = node_config.get("verbose", False)

    def execute(self, state: dict) -> dict:
        """
        Merges the per-website answers into one answer by constructing a prompt from the
        user's input and the answers, querying the language model, and parsing its response.

        Args:
            state (dict): The current state of the graph. The input keys will be used
                            to fetch the correct data from the state; the first resolved key
                            is used as the user prompt and the second as the list of answers.

        Returns:
            dict: The updated state with the first output key containing the merged answer.

        Raises:
            KeyError: If a resolved input key is missing from the state.
        """

        if self.verbose:
            print(f"--- Executing {self.node_name} Node ---")

        # Interpret input keys based on the provided input expression
        input_keys = self.get_input_keys(state)

        # Fetching data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

        user_prompt = input_data[0]
        answers = input_data[1]

        # Merge the answers into one string, one numbered entry per website
        answers_str = "".join(
            f"CONTENT WEBSITE {i}: {answer}\n"
            for i, answer in enumerate(answers, start=1)
        )

        output_parser = JsonOutputParser()
        format_instructions = output_parser.get_format_instructions()

        template_merge = """
        You are a website scraper and you have just scraped some content from multiple websites.\n
        You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
        You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
        The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
        OUTPUT INSTRUCTIONS: {format_instructions}\n
        USER PROMPT: {user_prompt}\n
        {website_content}
        """

        # The format instructions and website content are known up front, so
        # they are bound as partial variables; only the user prompt is supplied
        # at invocation time.
        prompt_template = PromptTemplate(
            template=template_merge,
            input_variables=["user_prompt"],
            partial_variables={
                "format_instructions": format_instructions,
                "website_content": answers_str,
            },
        )

        merge_chain = prompt_template | self.llm_model | output_parser
        answer = merge_chain.invoke({"user_prompt": user_prompt})

        # Update the state with the generated answer
        state.update({self.output[0]: answer})
        return state
|
||||
@ -2,7 +2,7 @@
|
||||
ParseNode Module
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.document_transformers import Html2TextTransformer
|
||||
from .base_node import BaseNode
|
||||
@ -26,7 +26,7 @@ class ParseNode(BaseNode):
|
||||
node_name (str): The unique identifier name for the node, defaulting to "Parse".
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "Parse"):
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "Parse"):
|
||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||
|
||||
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
RAGNode Module
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
|
||||
@ -31,7 +31,7 @@ class RAGNode(BaseNode):
|
||||
node_name (str): The unique identifier name for the node, defaulting to "RAG".
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "RAG"):
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "RAG"):
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
|
||||
self.llm_model = node_config["llm_model"]
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
SearchInternetNode Module
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from langchain.output_parsers import CommaSeparatedListOutputParser
|
||||
from langchain.prompts import PromptTemplate
|
||||
from ..utils.research_web import search_on_web
|
||||
@ -27,12 +27,13 @@ class SearchInternetNode(BaseNode):
|
||||
node_name (str): The unique identifier name for the node, defaulting to "SearchInternet".
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
|
||||
node_name: str = "SearchInternet"):
|
||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||
|
||||
self.llm_model = node_config["llm_model"]
|
||||
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
||||
self.max_results = node_config.get("max_results", 3)
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
"""
|
||||
@ -85,8 +86,11 @@ class SearchInternetNode(BaseNode):
|
||||
if self.verbose:
|
||||
print(f"Search Query: {search_query}")
|
||||
|
||||
# TODO: handle multiple URLs
|
||||
answer = search_on_web(query=search_query, max_results=1)[0]
|
||||
answer = search_on_web(query=search_query, max_results=self.max_results)
|
||||
|
||||
if len(answer) == 0:
|
||||
# raise an exception if no answer is found
|
||||
raise ValueError("Zero results found for the search query.")
|
||||
|
||||
# Update the state with the generated answer
|
||||
state.update({self.output[0]: answer})
|
||||
|
||||
@ -3,7 +3,7 @@ SearchLinkNode Module
|
||||
"""
|
||||
|
||||
# Imports from standard library
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from tqdm import tqdm
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
@ -33,7 +33,7 @@ class SearchLinkNode(BaseNode):
|
||||
node_name (str): The unique identifier name for the node, defaulting to "GenerateLinks".
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
|
||||
node_name: str = "GenerateLinks"):
|
||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
TextToSpeechNode Module
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from .base_node import BaseNode
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ class TextToSpeechNode(BaseNode):
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str],
|
||||
node_config: dict, node_name: str = "TextToSpeech"):
|
||||
node_config: Optional[dict]=None, node_name: str = "TextToSpeech"):
|
||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||
|
||||
self.tts_model = node_config["tts_model"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user