refactor gen answ node

This commit is contained in:
Marco Vinciguerra 2024-05-23 13:45:23 +02:00
parent 1774b18059
commit 909af8d912
3 changed files with 22 additions and 59 deletions

View File

@ -8,13 +8,9 @@ from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from ..helpers import models_tokens
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from ..helpers import models_tokens
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek
class AbstractGraph(ABC):
"""
Scaffolding class for creating a graph representation and executing it.
@ -22,7 +18,6 @@ class AbstractGraph(ABC):
prompt (str): The prompt for the graph.
source (str): The source of the graph.
config (dict): Configuration parameters for the graph.
schema (str): The schema for the graph output.
llm_model: An instance of a language model client, configured for generating answers.
embedder_model: An instance of an embedding model client,
configured for generating embeddings.
@ -33,7 +28,6 @@ class AbstractGraph(ABC):
prompt (str): The prompt for the graph.
config (dict): Configuration parameters for the graph.
source (str, optional): The source of the graph.
schema (str, optional): The schema for the graph output.
Example:
>>> class MyGraph(AbstractGraph):
@ -45,21 +39,15 @@ class AbstractGraph(ABC):
>>> result = my_graph.run()
"""
def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[str] = None):
def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.prompt = prompt
self.source = source
self.config = config
self.schema = schema
self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
) if "embeddings" not in config else self._create_embedder(
config["embeddings"])
self.verbose = False if config is None else config.get(
"verbose", False)
self.headless = True if config is None else config.get(
"headless", True)
self.loader_kwargs = config.get("loader_kwargs", {})
# Create the graph
self.graph = self._create_graph()
@ -67,20 +55,18 @@ class AbstractGraph(ABC):
self.execution_info = None
# Set common configuration parameters
self.verbose = False if config is None else config.get(
"verbose", False)
self.headless = True if config is None else config.get(
"headless", True)
self.loader_kwargs = config.get("loader_kwargs", {})
common_params = {
"headless": self.headless,
"verbose": self.verbose,
"loader_kwargs": self.loader_kwargs,
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
}
common_params = {"headless": self.headless,
"loader_kwargs": self.loader_kwargs,
"llm_model": self.llm_model,
"embedder_model": self.embedder_model}
self.set_common_params(common_params, overwrite=False)
def set_common_params(self, params: dict, overwrite=False):
@ -93,7 +79,7 @@ class AbstractGraph(ABC):
for node in self.graph.nodes:
node.update_config(params, overwrite)
def _set_model_token(self, llm):
if 'Azure' in str(type(llm)):
@ -171,7 +157,7 @@ class AbstractGraph(ABC):
raise KeyError("Model not supported") from exc
return Anthropic(llm_params)
elif "ollama" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
llm_params["model"] = llm_params["model"].split("/")[-1]
# allow user to set model_tokens in config
try:
@ -245,8 +231,6 @@ class AbstractGraph(ABC):
model="models/embedding-001")
if isinstance(self.llm_model, OpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model
elif isinstance(self.llm_model, AzureOpenAI):
@ -282,31 +266,30 @@ class AbstractGraph(ABC):
if 'model_instance' in embedder_config:
return embedder_config['model_instance']
# Instantiate the embedding model based on the model name
if "openai" in embedder_config["model"].split("/")[0]:
if "openai" in embedder_config["model"]:
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
elif "azure" in embedder_config["model"]:
return AzureOpenAIEmbeddings()
elif "ollama" in embedder_config["model"].split("/")[0]:
print("ciao")
embedder_config["model"] = embedder_config["model"].split("ollama/")[-1]
elif "ollama" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
try:
models_tokens["ollama"][embedder_config["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OllamaEmbeddings(**embedder_config)
elif "hugging_face" in embedder_config["model"].split("/")[0]:
elif "hugging_face" in embedder_config["model"]:
try:
models_tokens["hugging_face"][embedder_config["model"]]
except KeyError as exc:
raise KeyError("Model not supported")from exc
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
elif "gemini" in embedder_config["model"].split("/")[0]:
elif "gemini" in embedder_config["model"]:
try:
models_tokens["gemini"][embedder_config["model"]]
except KeyError as exc:
raise KeyError("Model not supported")from exc
return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
elif "bedrock" in embedder_config["model"].split("/")[0]:
elif "bedrock" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
client = embedder_config.get('client', None)
try:

View File

@ -11,7 +11,7 @@ from ..nodes import (
FetchNode,
ParseNode,
RAGNode,
GenerateAnswerNode
GenerateAnswerPDFNode
)
@ -48,7 +48,7 @@ class PDFScraperGraph(AbstractGraph):
"""
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
super().__init__(prompt, config, source, schema)
super().__init__(prompt, config, source)
self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"
@ -64,41 +64,21 @@ class PDFScraperGraph(AbstractGraph):
input='pdf | pdf_dir',
output=["doc", "link_urls", "img_urls"],
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={
"chunk_size": self.model_token,
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model,
}
)
generate_answer_node = GenerateAnswerNode(
generate_answer_node_pdf = GenerateAnswerPDFNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={
"llm_model": self.llm_model,
"schema": self.schema,
}
)
return BaseGraph(
nodes=[
fetch_node,
parse_node,
rag_node,
generate_answer_node,
generate_answer_node_pdf,
],
edges=[
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
(fetch_node, generate_answer_node_pdf)
],
entry_point=fetch_node
)

View File

@ -49,7 +49,7 @@ class GenerateAnswerPDFNode(BaseNode):
node_name (str): name of the node
"""
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.verbose = False if node_config is None else node_config.get(
"verbose", False)