refactor: move embeddings code from AbstractGraph to RAGNode
Some checks are pending
/ build (push) Waiting to run

This commit is contained in:
Federico Aguzzi 2024-08-01 11:53:17 +02:00
parent bb73d916a1
commit a94ebcde00
2 changed files with 142 additions and 125 deletions

View File

@ -7,15 +7,8 @@ from typing import Optional
import uuid
from pydantic import BaseModel
from langchain_community.chat_models import ChatOllama, ErnieBotChat
from langchain_aws import BedrockEmbeddings, ChatBedrock
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
from langchain_community.chat_models import ErnieBotChat
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain.chat_models import init_chat_model
from ..helpers import models_tokens
@ -66,8 +59,6 @@ class AbstractGraph(ABC):
self.config = config
self.schema = schema
self.llm_model = self._create_llm(config["llm"])
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
config["embeddings"])
self.verbose = False if config is None else config.get(
"verbose", False)
self.headless = True if config is None else config.get(
@ -237,116 +228,6 @@ class AbstractGraph(ABC):
# Raise an error if the model did not match any of the previous cases
raise ValueError("Model provided by the configuration not supported")
def _create_default_embedder(self, llm_config=None) -> object:
"""
Create an embedding model instance based on the chosen llm model.
Returns:
object: An instance of the embedding model client.
Raises:
ValueError: If the model is not supported.
"""
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
return GoogleGenerativeAIEmbeddings(
google_api_key=llm_config["api_key"], model="models/embedding-001"
)
if isinstance(self.llm_model, ChatOpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
base_url=self.llm_model.openai_api_base)
elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, ChatVertexAI):
return VertexAIEmbeddings()
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model
elif isinstance(self.llm_model, AzureChatOpenAI):
return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, ChatFireworks):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, ChatNVIDIA):
return NVIDIAEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, ChatOllama):
# unwrap the kwargs from the model whihc is a dict
params = self.llm_model._lc_kwargs
# remove streaming and temperature
params.pop("streaming", None)
params.pop("temperature", None)
return OllamaEmbeddings(**params)
elif isinstance(self.llm_model, ChatHuggingFace):
return HuggingFaceEmbeddings(model=self.llm_model.model)
elif isinstance(self.llm_model, ChatBedrock):
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
else:
raise ValueError("Embedding Model missing or not supported")
def _create_embedder(self, embedder_config: dict) -> object:
"""
Create an embedding model instance based on the configuration provided.
Args:
embedder_config (dict): Configuration parameters for the embedding model.
Returns:
object: An instance of the embedding model client.
Raises:
KeyError: If the model is not supported.
"""
embedder_params = {**embedder_config}
if "model_instance" in embedder_config:
return embedder_params["model_instance"]
# Instantiate the embedding model based on the model name
if "openai" in embedder_params["model"]:
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
if "azure" in embedder_params["model"]:
return AzureOpenAIEmbeddings()
if "nvidia" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["nvidia"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return NVIDIAEmbeddings(model=embedder_params["model"],
nvidia_api_key=embedder_params["api_key"])
if "ollama" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["ollama"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OllamaEmbeddings(**embedder_params)
if "hugging_face" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["hugging_face"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return HuggingFaceEmbeddings(model=embedder_params["model"])
if "fireworks" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["fireworks"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return FireworksEmbeddings(model=embedder_params["model"])
if "gemini" in embedder_params["model"]:
try:
models_tokens["gemini"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
if "bedrock" in embedder_params["model"]:
embedder_params["model"] = embedder_params["model"].split("/")[-1]
client = embedder_params.get("client", None)
try:
models_tokens["bedrock"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
raise ValueError("Model provided by the configuration not supported")
def get_state(self, key=None) -> dict:
""" ""

View File

@ -14,8 +14,20 @@ from langchain.retrievers.document_compressors import (
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
from langchain_community.chat_models import ChatOllama
from langchain_aws import BedrockEmbeddings, ChatBedrock
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..helpers import models_tokens
from ..models import DeepSeek
class RAGNode(BaseNode):
@ -95,10 +107,21 @@ class RAGNode(BaseNode):
self.logger.info("--- (updated chunks metadata) ---")
# check if embedder_model is provided, if not use llm_model
self.embedder_model = (
self.embedder_model if self.embedder_model else self.llm_model
)
embeddings = self.embedder_model
if self.embedder_model is not None:
embeddings = self.embedder_model
elif 'embeddings' in self.node_config:
try:
embeddings = self._create_embedder(self.node_config['embedder_config'])
except Exception:
try:
embeddings = self._create_default_embedder()
self.embedder_model = embeddings
except ValueError:
embeddings = self.llm_model
self.embedder_model = self.llm_model
else:
embeddings = self.llm_model
self.embedder_model = self.llm_model
folder_name = self.node_config.get("cache_path", "cache")
@ -141,3 +164,116 @@ class RAGNode(BaseNode):
state.update({self.output[0]: compressed_docs})
return state
def _create_default_embedder(self, llm_config=None) -> object:
"""
Create an embedding model instance based on the chosen llm model.
Returns:
object: An instance of the embedding model client.
Raises:
ValueError: If the model is not supported.
"""
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
return GoogleGenerativeAIEmbeddings(
google_api_key=llm_config["api_key"], model="models/embedding-001"
)
if isinstance(self.llm_model, ChatOpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
base_url=self.llm_model.openai_api_base)
elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, ChatVertexAI):
return VertexAIEmbeddings()
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model
elif isinstance(self.llm_model, AzureChatOpenAI):
return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, ChatFireworks):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, ChatNVIDIA):
return NVIDIAEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, ChatOllama):
# unwrap the kwargs from the model whihc is a dict
params = self.llm_model._lc_kwargs
# remove streaming and temperature
params.pop("streaming", None)
params.pop("temperature", None)
return OllamaEmbeddings(**params)
elif isinstance(self.llm_model, ChatHuggingFace):
return HuggingFaceEmbeddings(model=self.llm_model.model)
elif isinstance(self.llm_model, ChatBedrock):
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
else:
raise ValueError("Embedding Model missing or not supported")
def _create_embedder(self, embedder_config: dict) -> object:
"""
Create an embedding model instance based on the configuration provided.
Args:
embedder_config (dict): Configuration parameters for the embedding model.
Returns:
object: An instance of the embedding model client.
Raises:
KeyError: If the model is not supported.
"""
embedder_params = {**embedder_config}
if "model_instance" in embedder_config:
return embedder_params["model_instance"]
# Instantiate the embedding model based on the model name
if "openai" in embedder_params["model"]:
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
if "azure" in embedder_params["model"]:
return AzureOpenAIEmbeddings()
if "nvidia" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["nvidia"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return NVIDIAEmbeddings(model=embedder_params["model"],
nvidia_api_key=embedder_params["api_key"])
if "ollama" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["ollama"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OllamaEmbeddings(**embedder_params)
if "hugging_face" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["hugging_face"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return HuggingFaceEmbeddings(model=embedder_params["model"])
if "fireworks" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["fireworks"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return FireworksEmbeddings(model=embedder_params["model"])
if "gemini" in embedder_params["model"]:
try:
models_tokens["gemini"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
if "bedrock" in embedder_params["model"]:
embedder_params["model"] = embedder_params["model"].split("/")[-1]
client = embedder_params.get("client", None)
try:
models_tokens["bedrock"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
raise ValueError("Model provided by the configuration not supported")