refactor: move embeddings code from AbstractGraph to RAGNode
Some checks are pending
/ build (push) Waiting to run

This commit is contained in:
Federico Aguzzi 2024-08-01 11:53:17 +02:00
parent bb73d916a1
commit a94ebcde00
2 changed files with 142 additions and 125 deletions

View File

@ -7,15 +7,8 @@ from typing import Optional
import uuid import uuid
from pydantic import BaseModel from pydantic import BaseModel
from langchain_community.chat_models import ChatOllama, ErnieBotChat from langchain_community.chat_models import ErnieBotChat
from langchain_aws import BedrockEmbeddings, ChatBedrock from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
from ..helpers import models_tokens from ..helpers import models_tokens
@ -66,8 +59,6 @@ class AbstractGraph(ABC):
self.config = config self.config = config
self.schema = schema self.schema = schema
self.llm_model = self._create_llm(config["llm"]) self.llm_model = self._create_llm(config["llm"])
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
config["embeddings"])
self.verbose = False if config is None else config.get( self.verbose = False if config is None else config.get(
"verbose", False) "verbose", False)
self.headless = True if config is None else config.get( self.headless = True if config is None else config.get(
@ -237,116 +228,6 @@ class AbstractGraph(ABC):
# Raise an error if the model did not match any of the previous cases # Raise an error if the model did not match any of the previous cases
raise ValueError("Model provided by the configuration not supported") raise ValueError("Model provided by the configuration not supported")
def _create_default_embedder(self, llm_config=None) -> object:
"""
Create an embedding model instance based on the chosen llm model.
Returns:
object: An instance of the embedding model client.
Raises:
ValueError: If the model is not supported.
"""
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
return GoogleGenerativeAIEmbeddings(
google_api_key=llm_config["api_key"], model="models/embedding-001"
)
if isinstance(self.llm_model, ChatOpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
base_url=self.llm_model.openai_api_base)
elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, ChatVertexAI):
return VertexAIEmbeddings()
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model
elif isinstance(self.llm_model, AzureChatOpenAI):
return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, ChatFireworks):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, ChatNVIDIA):
return NVIDIAEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, ChatOllama):
# unwrap the kwargs from the model whihc is a dict
params = self.llm_model._lc_kwargs
# remove streaming and temperature
params.pop("streaming", None)
params.pop("temperature", None)
return OllamaEmbeddings(**params)
elif isinstance(self.llm_model, ChatHuggingFace):
return HuggingFaceEmbeddings(model=self.llm_model.model)
elif isinstance(self.llm_model, ChatBedrock):
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
else:
raise ValueError("Embedding Model missing or not supported")
def _create_embedder(self, embedder_config: dict) -> object:
"""
Create an embedding model instance based on the configuration provided.
Args:
embedder_config (dict): Configuration parameters for the embedding model.
Returns:
object: An instance of the embedding model client.
Raises:
KeyError: If the model is not supported.
"""
embedder_params = {**embedder_config}
if "model_instance" in embedder_config:
return embedder_params["model_instance"]
# Instantiate the embedding model based on the model name
if "openai" in embedder_params["model"]:
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
if "azure" in embedder_params["model"]:
return AzureOpenAIEmbeddings()
if "nvidia" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["nvidia"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return NVIDIAEmbeddings(model=embedder_params["model"],
nvidia_api_key=embedder_params["api_key"])
if "ollama" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["ollama"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OllamaEmbeddings(**embedder_params)
if "hugging_face" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["hugging_face"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return HuggingFaceEmbeddings(model=embedder_params["model"])
if "fireworks" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["fireworks"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return FireworksEmbeddings(model=embedder_params["model"])
if "gemini" in embedder_params["model"]:
try:
models_tokens["gemini"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
if "bedrock" in embedder_params["model"]:
embedder_params["model"] = embedder_params["model"].split("/")[-1]
client = embedder_params.get("client", None)
try:
models_tokens["bedrock"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
raise ValueError("Model provided by the configuration not supported")
def get_state(self, key=None) -> dict: def get_state(self, key=None) -> dict:
""" "" """ ""

View File

@ -14,8 +14,20 @@ from langchain.retrievers.document_compressors import (
from langchain_community.document_transformers import EmbeddingsRedundantFilter from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS from langchain_community.vectorstores import FAISS
from langchain_community.chat_models import ChatOllama
from langchain_aws import BedrockEmbeddings, ChatBedrock
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
from ..utils.logging import get_logger from ..utils.logging import get_logger
from .base_node import BaseNode from .base_node import BaseNode
from ..helpers import models_tokens
from ..models import DeepSeek
class RAGNode(BaseNode): class RAGNode(BaseNode):
@ -95,10 +107,21 @@ class RAGNode(BaseNode):
self.logger.info("--- (updated chunks metadata) ---") self.logger.info("--- (updated chunks metadata) ---")
# check if embedder_model is provided, if not use llm_model # check if embedder_model is provided, if not use llm_model
self.embedder_model = ( if self.embedder_model is not None:
self.embedder_model if self.embedder_model else self.llm_model embeddings = self.embedder_model
) elif 'embeddings' in self.node_config:
embeddings = self.embedder_model try:
embeddings = self._create_embedder(self.node_config['embedder_config'])
except Exception:
try:
embeddings = self._create_default_embedder()
self.embedder_model = embeddings
except ValueError:
embeddings = self.llm_model
self.embedder_model = self.llm_model
else:
embeddings = self.llm_model
self.embedder_model = self.llm_model
folder_name = self.node_config.get("cache_path", "cache") folder_name = self.node_config.get("cache_path", "cache")
@ -141,3 +164,116 @@ class RAGNode(BaseNode):
state.update({self.output[0]: compressed_docs}) state.update({self.output[0]: compressed_docs})
return state return state
def _create_default_embedder(self, llm_config=None) -> object:
"""
Create an embedding model instance based on the chosen llm model.
Returns:
object: An instance of the embedding model client.
Raises:
ValueError: If the model is not supported.
"""
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
return GoogleGenerativeAIEmbeddings(
google_api_key=llm_config["api_key"], model="models/embedding-001"
)
if isinstance(self.llm_model, ChatOpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
base_url=self.llm_model.openai_api_base)
elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, ChatVertexAI):
return VertexAIEmbeddings()
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model
elif isinstance(self.llm_model, AzureChatOpenAI):
return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, ChatFireworks):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, ChatNVIDIA):
return NVIDIAEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, ChatOllama):
# unwrap the kwargs from the model whihc is a dict
params = self.llm_model._lc_kwargs
# remove streaming and temperature
params.pop("streaming", None)
params.pop("temperature", None)
return OllamaEmbeddings(**params)
elif isinstance(self.llm_model, ChatHuggingFace):
return HuggingFaceEmbeddings(model=self.llm_model.model)
elif isinstance(self.llm_model, ChatBedrock):
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
else:
raise ValueError("Embedding Model missing or not supported")
def _create_embedder(self, embedder_config: dict) -> object:
"""
Create an embedding model instance based on the configuration provided.
Args:
embedder_config (dict): Configuration parameters for the embedding model.
Returns:
object: An instance of the embedding model client.
Raises:
KeyError: If the model is not supported.
"""
embedder_params = {**embedder_config}
if "model_instance" in embedder_config:
return embedder_params["model_instance"]
# Instantiate the embedding model based on the model name
if "openai" in embedder_params["model"]:
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
if "azure" in embedder_params["model"]:
return AzureOpenAIEmbeddings()
if "nvidia" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["nvidia"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return NVIDIAEmbeddings(model=embedder_params["model"],
nvidia_api_key=embedder_params["api_key"])
if "ollama" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["ollama"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OllamaEmbeddings(**embedder_params)
if "hugging_face" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["hugging_face"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return HuggingFaceEmbeddings(model=embedder_params["model"])
if "fireworks" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["fireworks"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return FireworksEmbeddings(model=embedder_params["model"])
if "gemini" in embedder_params["model"]:
try:
models_tokens["gemini"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
if "bedrock" in embedder_params["model"]:
embedder_params["model"] = embedder_params["model"].split("/")[-1]
client = embedder_params.get("client", None)
try:
models_tokens["bedrock"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
raise ValueError("Model provided by the configuration not supported")