Merge pull request #154 from epage480/pass-common-params-graph

Pass common params to nodes in graph
This commit is contained in:
Marco Vinciguerra 2024-05-05 16:26:07 +02:00 committed by GitHub
commit 3bef9bb7c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 94 additions and 114 deletions

View File

@ -34,7 +34,7 @@ llm_model = OpenAI(graph_config["llm"])
robot_node = RobotsNode(
input="url",
output=["is_scrapable"],
node_config={"llm": llm_model}
node_config={"llm_model": llm_model}
)
fetch_node = FetchNode(
@ -50,12 +50,12 @@ parse_node = ParseNode(
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={"llm": llm_model},
node_config={"llm_model": llm_model},
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={"llm": llm_model},
node_config={"llm_model": llm_model},
)
# ************************************************

View File

@ -26,7 +26,7 @@ llm_model = Ollama(graph_config["llm"])
robots_node = RobotsNode(
input="url",
output=["is_scrapable"],
node_config={"llm": llm_model,
node_config={"llm_model": llm_model,
"headless": False
}
)

View File

@ -52,16 +52,33 @@ class AbstractGraph(ABC):
) if "embeddings" not in config else self._create_embedder(
config["embeddings"])
# Set common configuration parameters
self.verbose = True if config is None else config.get("verbose", False)
self.headless = True if config is None else config.get(
"headless", True)
# Create the graph
self.graph = self._create_graph()
self.final_state = None
self.execution_info = None
# Set common configuration parameters
self.verbose = True if config is None else config.get("verbose", False)
self.headless = True if config is None else config.get(
"headless", True)
common_params = {"headless": self.headless,
"verbose": self.verbose,
"llm_model": self.llm_model,
"embedder_model": self.embedder_model}
self.set_common_params(common_params, overwrite=False)
def set_common_params(self, params: dict, overwrite=False):
    """
    Forward a set of shared parameters to each node of the graph.

    Every node in ``self.graph.nodes`` has its ``update_config`` method
    called with the same ``params`` and ``overwrite`` values; the decision
    of which keys are actually applied is left to the node.

    Args:
        params (dict): Common parameters and their values.
        overwrite (bool): Passed through unchanged to each node's
            ``update_config``. Defaults to False.
    """
    graph_nodes = self.graph.nodes
    for graph_node in graph_nodes:
        graph_node.update_config(params, overwrite)
def _set_model_token(self, llm):
if 'Azure' in str(type(llm)):

View File

@ -32,34 +32,27 @@ class CSVScraperGraph(AbstractGraph):
fetch_node = FetchNode(
input="csv_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"llm_model": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
}
)
generate_answer_node = GenerateAnswerCSVNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model,
}
)
@ -85,4 +78,4 @@ class CSVScraperGraph(AbstractGraph):
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)
return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")

View File

@ -56,34 +56,27 @@ class JSONScraperGraph(AbstractGraph):
fetch_node = FetchNode(
input="json_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)
@ -113,4 +106,4 @@ class JSONScraperGraph(AbstractGraph):
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)
return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")

View File

@ -61,32 +61,25 @@ class ScriptCreatorGraph(AbstractGraph):
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={"chunk_size": self.model_token,
"verbose": self.verbose
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_scraper_node = GenerateScraperNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={"llm": self.llm_model,
"verbose": self.verbose},
node_config={"llm_model": self.llm_model},
library=self.library,
website=self.source
)
@ -117,4 +110,4 @@ class ScriptCreatorGraph(AbstractGraph):
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)
return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")

View File

@ -50,41 +50,33 @@ class SearchGraph(AbstractGraph):
input="user_prompt",
output=["url"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
output=["doc"]
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)
@ -116,4 +108,4 @@ class SearchGraph(AbstractGraph):
inputs = {"user_prompt": self.prompt}
self.final_state, self.execution_info = self.graph.execute(inputs)
return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")

View File

@ -57,35 +57,28 @@ class SmartScraperGraph(AbstractGraph):
"""
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
output=["doc"]
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)
@ -115,4 +108,4 @@ class SmartScraperGraph(AbstractGraph):
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)
return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")

View File

@ -56,43 +56,34 @@ class SpeechGraph(AbstractGraph):
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
output=["doc"]
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
}
"llm_model": self.llm_model,
"embedder_model": self.embedder_model }
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)
text_to_speech_node = TextToSpeechNode(
input="answer",
output=["audio"],
node_config={
"tts_model": OpenAITextToSpeech(self.config["tts_model"]),
"verbose": self.verbose
"tts_model": OpenAITextToSpeech(self.config["tts_model"])
}
)
@ -131,4 +122,4 @@ class SpeechGraph(AbstractGraph):
"output_path", "output.mp3"))
print(f"Audio saved to {self.config.get('output_path', 'output.mp3')}")
return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")

View File

@ -57,35 +57,28 @@ class XMLScraperGraph(AbstractGraph):
fetch_node = FetchNode(
input="xml_dir",
output=["doc"],
node_config={
"headless": self.headless,
"verbose": self.verbose
}
output=["doc"]
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
"verbose": self.verbose
"chunk_size": self.model_token
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm": self.llm_model,
"embedder_model": self.embedder_model,
"verbose": self.verbose
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm": self.llm_model,
"verbose": self.verbose
"llm_model": self.llm_model
}
)
@ -115,4 +108,4 @@ class XMLScraperGraph(AbstractGraph):
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)
return self.final_state.get("answer", "No answer found.")
return self.final_state.get("answer", "No answer found.")

View File

@ -68,6 +68,21 @@ class BaseNode(ABC):
pass
def update_config(self, params: dict, overwrite: bool = False):
    """
    Updates the node_config dictionary as well as attributes with same key.

    A parameter is applied only when the node already has an attribute of
    that name; other keys in ``params`` are ignored. Applied parameters are
    written both to ``self.node_config`` and to the matching attribute.

    Args:
        params (dict): The dictionary to update node_config with.
        overwrite (bool): If True, keys already present in node_config are
            replaced; if False, existing keys are left untouched.
            Defaults to False.
    """
    # Work on a private shallow copy. If several nodes were constructed from
    # the same dict object, mutating it in place would make the key look
    # "already configured" to the next node, which would then skip setattr
    # and silently keep its constructor default.
    if self.node_config is None:
        self.node_config = {}
    else:
        self.node_config = dict(self.node_config)

    for key, val in params.items():
        if hasattr(self, key) and (overwrite or key not in self.node_config):
            self.node_config[key] = val
            setattr(self, key, val)
def get_input_keys(self, state: dict) -> List[str]:
"""
Determines the necessary state keys based on the input specification.

View File

@ -29,7 +29,7 @@ class FetchNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "Fetch".
"""
def __init__(self, input: str, output: List[str], node_config: Optional[dict], node_name: str = "Fetch"):
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "Fetch"):
super().__init__(node_name, "node", input, output, 1)
self.headless = True if node_config is None else node_config.get("headless", True)

View File

@ -22,14 +22,14 @@ class GenerateAnswerCSVNode(BaseNode):
an answer.
Attributes:
llm: An instance of a language model client, configured for generating answers.
llm_model: An instance of a language model client, configured for generating answers.
node_name (str): The unique identifier name for the node, defaulting
to "GenerateAnswerNodeCsv".
node_type (str): The type of the node, set to "node" indicating a
standard operational node.
Args:
llm: An instance of the language model client (e.g., ChatOpenAI) used
llm_model: An instance of the language model client (e.g., ChatOpenAI) used
for generating answers.
node_name (str, optional): The unique identifier name for the node.
Defaults to "GenerateAnswerNodeCsv".
@ -44,11 +44,11 @@ class GenerateAnswerCSVNode(BaseNode):
"""
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
Args:
llm: An instance of the OpenAIImageToText class.
llm_model: An instance of the OpenAIImageToText class.
node_name (str): name of the node
"""
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.verbose = True if node_config is None else node_config.get(
"verbose", False)

View File

@ -37,7 +37,7 @@ class GenerateAnswerNode(BaseNode):
node_name: str = "GenerateAnswer"):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.verbose = True if node_config is None else node_config.get("verbose", False)
def execute(self, state: dict) -> dict:

View File

@ -22,14 +22,14 @@ class GenerateAnswerCSVNode(BaseNode):
an answer.
Attributes:
llm: An instance of a language model client, configured for generating answers.
llm_model: An instance of a language model client, configured for generating answers.
node_name (str): The unique identifier name for the node, defaulting
to "GenerateAnswerNodeCsv".
node_type (str): The type of the node, set to "node" indicating a
standard operational node.
Args:
llm: An instance of the language model client (e.g., ChatOpenAI) used
llm_model: An instance of the language model client (e.g., ChatOpenAI) used
for generating answers.
node_name (str, optional): The unique identifier name for the node.
Defaults to "GenerateAnswerNodeCsv".
@ -44,11 +44,11 @@ class GenerateAnswerCSVNode(BaseNode):
"""
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
Args:
llm: An instance of the OpenAIImageToText class.
llm_model: An instance of the OpenAIImageToText class.
node_name (str): name of the node
"""
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.verbose = True if node_config is None else node_config.get(
"verbose", False)

View File

@ -40,7 +40,7 @@ class GenerateScraperNode(BaseNode):
library: str, website: str, node_name: str = "GenerateAnswer"):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.library = library
self.source = website

View File

@ -34,7 +34,7 @@ class RAGNode(BaseNode):
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "RAG"):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.embedder_model = node_config.get("embedder_model", None)
self.verbose = True if node_config is None else node_config.get(
"verbose", False)

View File

@ -38,7 +38,7 @@ class RobotsNode(BaseNode):
node_name: str = "Robots"):
super().__init__(node_name, "node", input, output, 1)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.force_scraping = force_scraping
self.verbose = True if node_config is None else node_config.get("verbose", False)

View File

@ -31,7 +31,7 @@ class SearchInternetNode(BaseNode):
node_name: str = "SearchInternet"):
super().__init__(node_name, "node", input, output, 1, node_config)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.verbose = True if node_config is None else node_config.get("verbose", False)
def execute(self, state: dict) -> dict:

View File

@ -37,7 +37,7 @@ class SearchLinkNode(BaseNode):
node_name: str = "GenerateLinks"):
super().__init__(node_name, "node", input, output, 1, node_config)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.verbose = True if node_config is None else node_config.get("verbose", False)
def execute(self, state: dict) -> dict: