mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
Merge pull request #154 from epage480/pass-common-params-graph
Pass common params to nodes in graph
This commit is contained in:
commit
3bef9bb7c8
@ -34,7 +34,7 @@ llm_model = OpenAI(graph_config["llm"])
|
||||
robot_node = RobotsNode(
|
||||
input="url",
|
||||
output=["is_scrapable"],
|
||||
node_config={"llm": llm_model}
|
||||
node_config={"llm_model": llm_model}
|
||||
)
|
||||
|
||||
fetch_node = FetchNode(
|
||||
@ -50,12 +50,12 @@ parse_node = ParseNode(
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={"llm": llm_model},
|
||||
node_config={"llm_model": llm_model},
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={"llm": llm_model},
|
||||
node_config={"llm_model": llm_model},
|
||||
)
|
||||
|
||||
# ************************************************
|
||||
|
||||
@ -26,7 +26,7 @@ llm_model = Ollama(graph_config["llm"])
|
||||
robots_node = RobotsNode(
|
||||
input="url",
|
||||
output=["is_scrapable"],
|
||||
node_config={"llm": llm_model,
|
||||
node_config={"llm_model": llm_model,
|
||||
"headless": False
|
||||
}
|
||||
)
|
||||
|
||||
@ -52,16 +52,33 @@ class AbstractGraph(ABC):
|
||||
) if "embeddings" not in config else self._create_embedder(
|
||||
config["embeddings"])
|
||||
|
||||
# Set common configuration parameters
|
||||
self.verbose = True if config is None else config.get("verbose", False)
|
||||
self.headless = True if config is None else config.get(
|
||||
"headless", True)
|
||||
|
||||
# Create the graph
|
||||
self.graph = self._create_graph()
|
||||
self.final_state = None
|
||||
self.execution_info = None
|
||||
|
||||
# Set common configuration parameters
|
||||
self.verbose = True if config is None else config.get("verbose", False)
|
||||
self.headless = True if config is None else config.get(
|
||||
"headless", True)
|
||||
common_params = {"headless": self.headless,
|
||||
"verbose": self.verbose,
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model}
|
||||
self.set_common_params(common_params, overwrite=False)
|
||||
|
||||
|
||||
def set_common_params(self, params: dict, overwrite=False):
|
||||
"""
|
||||
Pass parameters to every node in the graph unless otherwise defined in the graph.
|
||||
|
||||
Args:
|
||||
params (dict): Common parameters and their values.
|
||||
"""
|
||||
|
||||
for node in self.graph.nodes:
|
||||
node.update_config(params, overwrite)
|
||||
|
||||
def _set_model_token(self, llm):
|
||||
|
||||
if 'Azure' in str(type(llm)):
|
||||
|
||||
@ -32,34 +32,27 @@ class CSVScraperGraph(AbstractGraph):
|
||||
fetch_node = FetchNode(
|
||||
input="csv_dir",
|
||||
output=["doc"],
|
||||
node_config={
|
||||
"headless": self.headless,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": self.model_token,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
)
|
||||
generate_answer_node = GenerateAnswerCSVNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model,
|
||||
}
|
||||
)
|
||||
|
||||
@ -85,4 +78,4 @@ class CSVScraperGraph(AbstractGraph):
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
@ -56,34 +56,27 @@ class JSONScraperGraph(AbstractGraph):
|
||||
fetch_node = FetchNode(
|
||||
input="json_dir",
|
||||
output=["doc"],
|
||||
node_config={
|
||||
"headless": self.headless,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": self.model_token,
|
||||
"verbose": self.verbose
|
||||
"chunk_size": self.model_token
|
||||
}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"embedder_model": self.embedder_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model
|
||||
}
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model
|
||||
}
|
||||
)
|
||||
|
||||
@ -113,4 +106,4 @@ class JSONScraperGraph(AbstractGraph):
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
@ -61,32 +61,25 @@ class ScriptCreatorGraph(AbstractGraph):
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
node_config={
|
||||
"headless": self.headless,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={"chunk_size": self.model_token,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"embedder_model": self.embedder_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model
|
||||
}
|
||||
)
|
||||
generate_scraper_node = GenerateScraperNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={"llm": self.llm_model,
|
||||
"verbose": self.verbose},
|
||||
node_config={"llm_model": self.llm_model},
|
||||
library=self.library,
|
||||
website=self.source
|
||||
)
|
||||
@ -117,4 +110,4 @@ class ScriptCreatorGraph(AbstractGraph):
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
@ -50,41 +50,33 @@ class SearchGraph(AbstractGraph):
|
||||
input="user_prompt",
|
||||
output=["url"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model
|
||||
}
|
||||
)
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
node_config={
|
||||
"headless": self.headless,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
output=["doc"]
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": self.model_token,
|
||||
"verbose": self.verbose
|
||||
"chunk_size": self.model_token
|
||||
}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"embedder_model": self.embedder_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model
|
||||
}
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model
|
||||
}
|
||||
)
|
||||
|
||||
@ -116,4 +108,4 @@ class SearchGraph(AbstractGraph):
|
||||
inputs = {"user_prompt": self.prompt}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
@ -57,35 +57,28 @@ class SmartScraperGraph(AbstractGraph):
|
||||
"""
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
node_config={
|
||||
"headless": self.headless,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
output=["doc"]
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": self.model_token,
|
||||
"verbose": self.verbose
|
||||
"chunk_size": self.model_token
|
||||
}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"embedder_model": self.embedder_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model
|
||||
}
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model
|
||||
}
|
||||
)
|
||||
|
||||
@ -115,4 +108,4 @@ class SmartScraperGraph(AbstractGraph):
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
@ -56,43 +56,34 @@ class SpeechGraph(AbstractGraph):
|
||||
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
node_config={
|
||||
"headless": self.headless,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
output=["doc"]
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": self.model_token,
|
||||
"verbose": self.verbose
|
||||
"chunk_size": self.model_token
|
||||
}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"embedder_model": self.embedder_model,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model }
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model
|
||||
}
|
||||
)
|
||||
text_to_speech_node = TextToSpeechNode(
|
||||
input="answer",
|
||||
output=["audio"],
|
||||
node_config={
|
||||
"tts_model": OpenAITextToSpeech(self.config["tts_model"]),
|
||||
"verbose": self.verbose
|
||||
"tts_model": OpenAITextToSpeech(self.config["tts_model"])
|
||||
}
|
||||
)
|
||||
|
||||
@ -131,4 +122,4 @@ class SpeechGraph(AbstractGraph):
|
||||
"output_path", "output.mp3"))
|
||||
print(f"Audio saved to {self.config.get('output_path', 'output.mp3')}")
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
@ -57,35 +57,28 @@ class XMLScraperGraph(AbstractGraph):
|
||||
|
||||
fetch_node = FetchNode(
|
||||
input="xml_dir",
|
||||
output=["doc"],
|
||||
node_config={
|
||||
"headless": self.headless,
|
||||
"verbose": self.verbose
|
||||
}
|
||||
output=["doc"]
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": self.model_token,
|
||||
"verbose": self.verbose
|
||||
"chunk_size": self.model_token
|
||||
}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"embedder_model": self.embedder_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model
|
||||
}
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm": self.llm_model,
|
||||
"verbose": self.verbose
|
||||
"llm_model": self.llm_model
|
||||
}
|
||||
)
|
||||
|
||||
@ -115,4 +108,4 @@ class XMLScraperGraph(AbstractGraph):
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
@ -68,6 +68,21 @@ class BaseNode(ABC):
|
||||
|
||||
pass
|
||||
|
||||
def update_config(self, params: dict, overwrite: bool = False):
|
||||
"""
|
||||
Updates the node_config dictionary as well as attributes with same key.
|
||||
|
||||
Args:
|
||||
param (dict): The dictionary to update node_config with.
|
||||
overwrite (bool): Flag indicating if the values of node_config should be overwritten if their value is not None.
|
||||
"""
|
||||
if self.node_config is None:
|
||||
self.node_config = {}
|
||||
for key, val in params.items():
|
||||
if hasattr(self, key) and (key not in self.node_config or overwrite):
|
||||
self.node_config[key] = val
|
||||
setattr(self, key, val)
|
||||
|
||||
def get_input_keys(self, state: dict) -> List[str]:
|
||||
"""
|
||||
Determines the necessary state keys based on the input specification.
|
||||
|
||||
@ -29,7 +29,7 @@ class FetchNode(BaseNode):
|
||||
node_name (str): The unique identifier name for the node, defaulting to "Fetch".
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict], node_name: str = "Fetch"):
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "Fetch"):
|
||||
super().__init__(node_name, "node", input, output, 1)
|
||||
|
||||
self.headless = True if node_config is None else node_config.get("headless", True)
|
||||
|
||||
@ -22,14 +22,14 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
an answer.
|
||||
|
||||
Attributes:
|
||||
llm: An instance of a language model client, configured for generating answers.
|
||||
llm_model: An instance of a language model client, configured for generating answers.
|
||||
node_name (str): The unique identifier name for the node, defaulting
|
||||
to "GenerateAnswerNodeCsv".
|
||||
node_type (str): The type of the node, set to "node" indicating a
|
||||
standard operational node.
|
||||
|
||||
Args:
|
||||
llm: An instance of the language model client (e.g., ChatOpenAI) used
|
||||
llm_model: An instance of the language model client (e.g., ChatOpenAI) used
|
||||
for generating answers.
|
||||
node_name (str, optional): The unique identifier name for the node.
|
||||
Defaults to "GenerateAnswerNodeCsv".
|
||||
@ -44,11 +44,11 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
"""
|
||||
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
|
||||
Args:
|
||||
llm: An instance of the OpenAIImageToText class.
|
||||
llm_model: An instance of the OpenAIImageToText class.
|
||||
node_name (str): name of the node
|
||||
"""
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
self.llm_model = node_config["llm"]
|
||||
self.llm_model = node_config["llm_model"]
|
||||
self.verbose = True if node_config is None else node_config.get(
|
||||
"verbose", False)
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ class GenerateAnswerNode(BaseNode):
|
||||
node_name: str = "GenerateAnswer"):
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
|
||||
self.llm_model = node_config["llm"]
|
||||
self.llm_model = node_config["llm_model"]
|
||||
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
|
||||
@ -22,14 +22,14 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
an answer.
|
||||
|
||||
Attributes:
|
||||
llm: An instance of a language model client, configured for generating answers.
|
||||
llm_model: An instance of a language model client, configured for generating answers.
|
||||
node_name (str): The unique identifier name for the node, defaulting
|
||||
to "GenerateAnswerNodeCsv".
|
||||
node_type (str): The type of the node, set to "node" indicating a
|
||||
standard operational node.
|
||||
|
||||
Args:
|
||||
llm: An instance of the language model client (e.g., ChatOpenAI) used
|
||||
llm_model: An instance of the language model client (e.g., ChatOpenAI) used
|
||||
for generating answers.
|
||||
node_name (str, optional): The unique identifier name for the node.
|
||||
Defaults to "GenerateAnswerNodeCsv".
|
||||
@ -44,11 +44,11 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
"""
|
||||
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
|
||||
Args:
|
||||
llm: An instance of the OpenAIImageToText class.
|
||||
llm_model: An instance of the OpenAIImageToText class.
|
||||
node_name (str): name of the node
|
||||
"""
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
self.llm_model = node_config["llm"]
|
||||
self.llm_model = node_config["llm_model"]
|
||||
self.verbose = True if node_config is None else node_config.get(
|
||||
"verbose", False)
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ class GenerateScraperNode(BaseNode):
|
||||
library: str, website: str, node_name: str = "GenerateAnswer"):
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
|
||||
self.llm_model = node_config["llm"]
|
||||
self.llm_model = node_config["llm_model"]
|
||||
self.library = library
|
||||
self.source = website
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ class RAGNode(BaseNode):
|
||||
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "RAG"):
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
|
||||
self.llm_model = node_config["llm"]
|
||||
self.llm_model = node_config["llm_model"]
|
||||
self.embedder_model = node_config.get("embedder_model", None)
|
||||
self.verbose = True if node_config is None else node_config.get(
|
||||
"verbose", False)
|
||||
|
||||
@ -38,7 +38,7 @@ class RobotsNode(BaseNode):
|
||||
node_name: str = "Robots"):
|
||||
super().__init__(node_name, "node", input, output, 1)
|
||||
|
||||
self.llm_model = node_config["llm"]
|
||||
self.llm_model = node_config["llm_model"]
|
||||
self.force_scraping = force_scraping
|
||||
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ class SearchInternetNode(BaseNode):
|
||||
node_name: str = "SearchInternet"):
|
||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||
|
||||
self.llm_model = node_config["llm"]
|
||||
self.llm_model = node_config["llm_model"]
|
||||
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
|
||||
@ -37,7 +37,7 @@ class SearchLinkNode(BaseNode):
|
||||
node_name: str = "GenerateLinks"):
|
||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||
|
||||
self.llm_model = node_config["llm"]
|
||||
self.llm_model = node_config["llm_model"]
|
||||
self.verbose = True if node_config is None else node_config.get("verbose", False)
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user