mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
feat: multiple graph instances
This commit is contained in:
parent
1c4ba91620
commit
dbb614a8dd
@ -6,8 +6,8 @@ import os
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
from scrapegraphai.models import OpenAI
|
from scrapegraphai.models import OpenAI
|
||||||
from scrapegraphai.graphs import BaseGraph
|
from scrapegraphai.graphs import BaseGraph, SmartScraperGraph
|
||||||
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode, SearchInternetNode
|
from scrapegraphai.nodes import SearchInternetNode, GraphIteratorNode, MergeAnswersNode
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# ************************************************
|
# ************************************************
|
||||||
@ -23,6 +23,16 @@ graph_config = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ************************************************
|
||||||
|
# Create a SmartScraperGraph instance
|
||||||
|
# ************************************************
|
||||||
|
|
||||||
|
smart_scraper_graph = SmartScraperGraph(
|
||||||
|
prompt="",
|
||||||
|
source="",
|
||||||
|
config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
# ************************************************
|
# ************************************************
|
||||||
# Define the graph nodes
|
# Define the graph nodes
|
||||||
# ************************************************
|
# ************************************************
|
||||||
@ -32,38 +42,24 @@ embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)
|
|||||||
|
|
||||||
search_internet_node = SearchInternetNode(
|
search_internet_node = SearchInternetNode(
|
||||||
input="user_prompt",
|
input="user_prompt",
|
||||||
output=["url"],
|
output=["urls"],
|
||||||
node_config={
|
|
||||||
"llm_model": llm_model
|
|
||||||
}
|
|
||||||
)
|
|
||||||
fetch_node = FetchNode(
|
|
||||||
input="url | local_dir",
|
|
||||||
output=["doc"],
|
|
||||||
node_config={
|
|
||||||
"verbose": True,
|
|
||||||
"headless": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
parse_node = ParseNode(
|
|
||||||
input="doc",
|
|
||||||
output=["parsed_doc"],
|
|
||||||
node_config={
|
|
||||||
"chunk_size": 4096,
|
|
||||||
"verbose": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
rag_node = RAGNode(
|
|
||||||
input="user_prompt & (parsed_doc | doc)",
|
|
||||||
output=["relevant_chunks"],
|
|
||||||
node_config={
|
node_config={
|
||||||
"llm_model": llm_model,
|
"llm_model": llm_model,
|
||||||
"embedder_model": embedder,
|
|
||||||
"verbose": True,
|
"verbose": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
generate_answer_node = GenerateAnswerNode(
|
|
||||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
graph_iterator_node = GraphIteratorNode(
|
||||||
|
input="user_prompt & urls",
|
||||||
|
output=["results"],
|
||||||
|
node_config={
|
||||||
|
"graph_instance": smart_scraper_graph,
|
||||||
|
"verbose": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
merge_answers_node = MergeAnswersNode(
|
||||||
|
input="user_prompt & results",
|
||||||
output=["answer"],
|
output=["answer"],
|
||||||
node_config={
|
node_config={
|
||||||
"llm_model": llm_model,
|
"llm_model": llm_model,
|
||||||
@ -78,16 +74,12 @@ generate_answer_node = GenerateAnswerNode(
|
|||||||
graph = BaseGraph(
|
graph = BaseGraph(
|
||||||
nodes=[
|
nodes=[
|
||||||
search_internet_node,
|
search_internet_node,
|
||||||
fetch_node,
|
graph_iterator_node,
|
||||||
parse_node,
|
merge_answers_node
|
||||||
rag_node,
|
|
||||||
generate_answer_node,
|
|
||||||
],
|
],
|
||||||
edges=[
|
edges=[
|
||||||
(search_internet_node, fetch_node),
|
(search_internet_node, graph_iterator_node),
|
||||||
(fetch_node, parse_node),
|
(graph_iterator_node, merge_answers_node)
|
||||||
(parse_node, rag_node),
|
|
||||||
(rag_node, generate_answer_node)
|
|
||||||
],
|
],
|
||||||
entry_point=search_internet_node
|
entry_point=search_internet_node
|
||||||
)
|
)
|
||||||
@ -17,3 +17,5 @@ from .search_link_node import SearchLinkNode
|
|||||||
from .robots_node import RobotsNode
|
from .robots_node import RobotsNode
|
||||||
from .generate_answer_csv_node import GenerateAnswerCSVNode
|
from .generate_answer_csv_node import GenerateAnswerCSVNode
|
||||||
from .generate_answer_pdf_node import GenerateAnswerPDFNode
|
from .generate_answer_pdf_node import GenerateAnswerPDFNode
|
||||||
|
from .graph_iterator_node import GraphIteratorNode
|
||||||
|
from .merge_answers_node import MergeAnswersNode
|
||||||
@ -2,7 +2,7 @@
|
|||||||
Module for generating the answer node
|
Module for generating the answer node
|
||||||
"""
|
"""
|
||||||
# Imports from standard library
|
# Imports from standard library
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# Imports from Langchain
|
# Imports from Langchain
|
||||||
@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
|
|||||||
updating the state with the generated answer under the 'answer' key.
|
updating the state with the generated answer under the 'answer' key.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
|
||||||
node_name: str = "GenerateAnswer"):
|
node_name: str = "GenerateAnswer"):
|
||||||
"""
|
"""
|
||||||
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
|
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
|
||||||
|
|||||||
@ -3,7 +3,7 @@ GenerateAnswerNode Module
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Imports from standard library
|
# Imports from standard library
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# Imports from Langchain
|
# Imports from Langchain
|
||||||
@ -33,7 +33,7 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
|
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
|
||||||
node_name: str = "GenerateAnswer"):
|
node_name: str = "GenerateAnswer"):
|
||||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
Module for generating the answer node
|
Module for generating the answer node
|
||||||
"""
|
"""
|
||||||
# Imports from standard library
|
# Imports from standard library
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# Imports from Langchain
|
# Imports from Langchain
|
||||||
@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
|
|||||||
updating the state with the generated answer under the 'answer' key.
|
updating the state with the generated answer under the 'answer' key.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
|
||||||
node_name: str = "GenerateAnswer"):
|
node_name: str = "GenerateAnswer"):
|
||||||
"""
|
"""
|
||||||
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
|
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
Module for generating the answer node
|
Module for generating the answer node
|
||||||
"""
|
"""
|
||||||
# Imports from standard library
|
# Imports from standard library
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# Imports from Langchain
|
# Imports from Langchain
|
||||||
@ -39,7 +39,7 @@ class GenerateAnswerPDFNode(BaseNode):
|
|||||||
updating the state with the generated answer under the 'answer' key.
|
updating the state with the generated answer under the 'answer' key.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
|
||||||
node_name: str = "GenerateAnswer"):
|
node_name: str = "GenerateAnswer"):
|
||||||
"""
|
"""
|
||||||
Initializes the GenerateAnswerNodePDF with a language model client and a node name.
|
Initializes the GenerateAnswerNodePDF with a language model client and a node name.
|
||||||
|
|||||||
@ -3,7 +3,7 @@ GenerateScraperNode Module
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Imports from standard library
|
# Imports from standard library
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# Imports from Langchain
|
# Imports from Langchain
|
||||||
@ -36,8 +36,8 @@ class GenerateScraperNode(BaseNode):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
def __init__(self, input: str, output: List[str], library: str, website: str,
|
||||||
library: str, website: str, node_name: str = "GenerateAnswer"):
|
node_config: Optional[dict]=None, node_name: str = "GenerateAnswer"):
|
||||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||||
|
|
||||||
self.llm_model = node_config["llm_model"]
|
self.llm_model = node_config["llm_model"]
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
GetProbableTagsNode Module
|
GetProbableTagsNode Module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from langchain.output_parsers import CommaSeparatedListOutputParser
|
from langchain.output_parsers import CommaSeparatedListOutputParser
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
|
|||||||
83
scrapegraphai/nodes/graph_iterator_node.py
Normal file
83
scrapegraphai/nodes/graph_iterator_node.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
GraphIterator Module
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
import copy
|
||||||
|
from tqdm import tqdm
|
||||||
|
from .base_node import BaseNode
|
||||||
|
|
||||||
|
|
||||||
|
class GraphIteratorNode(BaseNode):
    """
    A node that runs one copy of a template graph for every URL found in the state.

    The node reads a user prompt and a collection of URLs from the state, makes a
    shallow copy of the graph given as ``node_config["graph_instance"]`` for each URL,
    assigns the prompt and the URL to that copy, runs it, and stores the list of
    results under the first output key.

    Attributes:
        verbose (bool): A flag indicating whether to show print statements during execution.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Additional configuration for the node. ``graph_instance``
            is read at execution time; ``verbose`` is optional and defaults to False.
        node_name (str): The unique identifier name for the node, defaulting to
            "GraphIterator".
    """

    def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
                 node_name: str = "GraphIterator"):
        super().__init__(node_name, "node", input, output, 2, node_config)

        self.verbose = False if node_config is None else node_config.get("verbose", False)

    def execute(self, state: dict) -> dict:
        """
        Runs a copy of the configured graph for each URL and collects the results.

        Args:
            state (dict): The current state of the graph. The input keys will be used to
                fetch the correct data from the state; the first resolved key is used as
                the user prompt and the second as the URLs.

        Returns:
            dict: The same state object, updated so that the first output key holds the
                list of values returned by each graph's ``run()``, in URL order.

        Raises:
            KeyError: If a resolved input key is not present in the state.
            ValueError: If no ``graph_instance`` was provided in the node configuration.
        """

        if self.verbose:
            print(f"--- Executing {self.node_name} Node ---")

        # Interpret input keys based on the provided input expression
        input_keys = self.get_input_keys(state)

        # Fetching data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

        user_prompt = input_data[0]
        urls = input_data[1]

        # A lone URL string would otherwise be iterated character by character.
        if isinstance(urls, str):
            urls = [urls]

        # node_config may legitimately be None (see __init__), so guard the lookup.
        graph_instance = (self.node_config or {}).get("graph_instance")
        if graph_instance is None:
            raise ValueError("Graph instance is required for graph iteration.")

        # Build one graph per URL. The prompt is set on the copy rather than on the
        # template so the caller's graph instance is left untouched.
        # NOTE(review): copy.copy is shallow, so attributes other than the two
        # reassigned below are shared between the copies - confirm that is acceptable.
        graphs_instances = []
        for url in urls:
            copy_graph_instance = copy.copy(graph_instance)
            copy_graph_instance.prompt = user_prompt
            copy_graph_instance.source = url
            graphs_instances.append(copy_graph_instance)

        # Run the graph for each URL, showing a progress bar only in verbose mode.
        progress = tqdm(
            graphs_instances,
            desc="Processing Graph Instances",
            disable=not self.verbose,
        )
        graphs_answers = [graph.run() for graph in progress]

        state.update({self.output[0]: graphs_answers})

        return state
|
||||||
@ -2,7 +2,7 @@
|
|||||||
ImageToTextNode Module
|
ImageToTextNode Module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
|
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ class ImageToTextNode(BaseNode):
|
|||||||
node_name (str): The unique identifier name for the node, defaulting to "ImageToText".
|
node_name (str): The unique identifier name for the node, defaulting to "ImageToText".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
|
||||||
node_name: str = "ImageToText"):
|
node_name: str = "ImageToText"):
|
||||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||||
|
|
||||||
|
|||||||
104
scrapegraphai/nodes/merge_answers_node.py
Normal file
104
scrapegraphai/nodes/merge_answers_node.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
"""
|
||||||
|
MergeAnswersNode Module
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Imports from standard library
|
||||||
|
from typing import List, Optional
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Imports from Langchain
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
|
from langchain_core.runnables import RunnableParallel
|
||||||
|
|
||||||
|
# Imports from the library
|
||||||
|
from .base_node import BaseNode
|
||||||
|
|
||||||
|
|
||||||
|
class MergeAnswersNode(BaseNode):
    """
    A node that merges several previously generated answers into a single answer
    using a large language model (LLM). It labels each answer, builds a prompt from
    the user's input and the labelled answers, feeds it to the LLM, and parses the
    LLM's response as JSON.

    Attributes:
        llm_model: An instance of a language model client, taken from
            ``node_config["llm_model"]``.
        verbose (bool): A flag indicating whether to show print statements during execution.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Configuration for the node. Must contain ``llm_model``;
            ``verbose`` is optional and defaults to False.
        node_name (str): The unique identifier name for the node, defaulting to
            "MergeAnswers".

    Raises:
        ValueError: If ``node_config`` is None.
        KeyError: If ``node_config`` has no ``llm_model`` entry.
    """

    def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
                 node_name: str = "MergeAnswers"):
        super().__init__(node_name, "node", input, output, 2, node_config)

        # The signature allows None for parity with the other nodes, but this node
        # cannot work without a model, so fail with an explicit message instead of
        # an opaque "'NoneType' object is not subscriptable".
        if node_config is None:
            raise ValueError(
                "MergeAnswersNode requires a node_config containing 'llm_model'."
            )

        self.llm_model = node_config["llm_model"]
        self.verbose = node_config.get("verbose", False)

    def execute(self, state: dict) -> dict:
        """
        Merges the answers found in the state into one answer by querying the
        language model and parsing its response.

        Args:
            state (dict): The current state of the graph. The input keys will be used
                to fetch the correct data from the state; the first resolved key is
                used as the user prompt and the second as the answers to merge.

        Returns:
            dict: The same state object, updated so that the first output key holds
                the parsed merged answer.

        Raises:
            KeyError: If a resolved input key is not present in the state.
        """

        if self.verbose:
            print(f"--- Executing {self.node_name} Node ---")

        # Interpret input keys based on the provided input expression
        input_keys = self.get_input_keys(state)

        # Fetching data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

        user_prompt = input_data[0]
        answers = input_data[1]

        # Label each answer with its 1-based position and join them in one string
        answers_str = "".join(
            f"CONTENT WEBSITE {position}: {answer}\n"
            for position, answer in enumerate(answers, start=1)
        )

        output_parser = JsonOutputParser()
        format_instructions = output_parser.get_format_instructions()

        template_merge = """
        You are a website scraper and you have just scraped some content from multiple websites.\n
        You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
        You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
        The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
        OUTPUT INSTRUCTIONS: {format_instructions}\n
        USER PROMPT: {user_prompt}\n
        {website_content}
        """

        # The answers and format instructions are bound as partial variables, so
        # only the user prompt has to be supplied when the chain is invoked.
        prompt_template = PromptTemplate(
            template=template_merge,
            input_variables=["user_prompt"],
            partial_variables={
                "format_instructions": format_instructions,
                "website_content": answers_str,
            },
        )

        merge_chain = prompt_template | self.llm_model | output_parser
        answer = merge_chain.invoke({"user_prompt": user_prompt})

        # Update the state with the generated answer
        state.update({self.output[0]: answer})
        return state
|
||||||
@ -2,7 +2,7 @@
|
|||||||
ParseNode Module
|
ParseNode Module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
from langchain_community.document_transformers import Html2TextTransformer
|
from langchain_community.document_transformers import Html2TextTransformer
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
@ -26,7 +26,7 @@ class ParseNode(BaseNode):
|
|||||||
node_name (str): The unique identifier name for the node, defaulting to "Parse".
|
node_name (str): The unique identifier name for the node, defaulting to "Parse".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "Parse"):
|
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "Parse"):
|
||||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||||
|
|
||||||
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
RAGNode Module
|
RAGNode Module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.retrievers import ContextualCompressionRetriever
|
from langchain.retrievers import ContextualCompressionRetriever
|
||||||
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
|
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
|
||||||
@ -31,7 +31,7 @@ class RAGNode(BaseNode):
|
|||||||
node_name (str): The unique identifier name for the node, defaulting to "Parse".
|
node_name (str): The unique identifier name for the node, defaulting to "Parse".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "RAG"):
|
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "RAG"):
|
||||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||||
|
|
||||||
self.llm_model = node_config["llm_model"]
|
self.llm_model = node_config["llm_model"]
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
SearchInternetNode Module
|
SearchInternetNode Module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from langchain.output_parsers import CommaSeparatedListOutputParser
|
from langchain.output_parsers import CommaSeparatedListOutputParser
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from ..utils.research_web import search_on_web
|
from ..utils.research_web import search_on_web
|
||||||
@ -27,12 +27,13 @@ class SearchInternetNode(BaseNode):
|
|||||||
node_name (str): The unique identifier name for the node, defaulting to "SearchInternet".
|
node_name (str): The unique identifier name for the node, defaulting to "SearchInternet".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
|
||||||
node_name: str = "SearchInternet"):
|
node_name: str = "SearchInternet"):
|
||||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||||
|
|
||||||
self.llm_model = node_config["llm_model"]
|
self.llm_model = node_config["llm_model"]
|
||||||
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
||||||
|
self.max_results = node_config.get("max_results", 3)
|
||||||
|
|
||||||
def execute(self, state: dict) -> dict:
|
def execute(self, state: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
@ -85,8 +86,11 @@ class SearchInternetNode(BaseNode):
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"Search Query: {search_query}")
|
print(f"Search Query: {search_query}")
|
||||||
|
|
||||||
# TODO: handle multiple URLs
|
answer = search_on_web(query=search_query, max_results=self.max_results)
|
||||||
answer = search_on_web(query=search_query, max_results=1)[0]
|
|
||||||
|
if len(answer) == 0:
|
||||||
|
# raise an exception if no answer is found
|
||||||
|
raise ValueError("Zero results found for the search query.")
|
||||||
|
|
||||||
# Update the state with the generated answer
|
# Update the state with the generated answer
|
||||||
state.update({self.output[0]: answer})
|
state.update({self.output[0]: answer})
|
||||||
|
|||||||
@ -3,7 +3,7 @@ SearchLinkNode Module
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Imports from standard library
|
# Imports from standard library
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ class SearchLinkNode(BaseNode):
|
|||||||
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
|
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str], node_config: dict,
|
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
|
||||||
node_name: str = "GenerateLinks"):
|
node_name: str = "GenerateLinks"):
|
||||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
TextToSpeechNode Module
|
TextToSpeechNode Module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
|
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ class TextToSpeechNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input: str, output: List[str],
|
def __init__(self, input: str, output: List[str],
|
||||||
node_config: dict, node_name: str = "TextToSpeech"):
|
node_config: Optional[dict]=None, node_name: str = "TextToSpeech"):
|
||||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||||
|
|
||||||
self.tts_model = node_config["tts_model"]
|
self.tts_model = node_config["tts_model"]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user