refactor: Changed the way embedding model is created in AbstractGraph class and removed handling of embedding model creation from RAGNode. Now AbstractGraph will call a dedicated method for embedding models instead of _create_llm. This makes it easy to use any LLM with any supported embedding model.

This commit is contained in:
S4mpl3r 2024-05-03 16:14:27 +03:30
parent 1219caa4ff
commit 819cbcd3be
3 changed files with 88 additions and 28 deletions

View File

@ -25,7 +25,7 @@ graph_config = {
}, },
"embeddings": { "embeddings": {
"api_key": openai_key, "api_key": openai_key,
"model": "gpt-3.5-turbo", "model": "openai",
}, },
"headless": False "headless": False
} }

View File

@ -5,8 +5,12 @@ AbstractGraph Module
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock from langchain_aws.embeddings.bedrock import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from ..helpers import models_tokens from ..helpers import models_tokens
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
class AbstractGraph(ABC): class AbstractGraph(ABC):
@ -43,7 +47,8 @@ class AbstractGraph(ABC):
self.source = source self.source = source
self.config = config self.config = config
self.llm_model = self._create_llm(config["llm"], chat=True) self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm( self.embedder_model = self._create_default_embedder(
) if "embeddings" not in config else self._create_embedder(
config["embeddings"]) config["embeddings"])
# Set common configuration parameters # Set common configuration parameters
@ -165,6 +170,85 @@ class AbstractGraph(ABC):
else: else:
raise ValueError( raise ValueError(
"Model provided by the configuration not supported") "Model provided by the configuration not supported")
def _create_default_embedder(self) -> object:
"""
Create an embedding model instance based on the chosen llm model.
Returns:
object: An instance of the embedding model client.
Raises:
ValueError: If the model is not supported.
"""
if isinstance(self.llm_model, OpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model
elif isinstance(self.llm_model, AzureOpenAI):
return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, Ollama):
# unwrap the kwargs from the model whihc is a dict
params = self.llm_model._lc_kwargs
# remove streaming and temperature
params.pop("streaming", None)
params.pop("temperature", None)
return OllamaEmbeddings(**params)
elif isinstance(self.llm_model, HuggingFace):
return HuggingFaceHubEmbeddings(model=self.llm_model.model)
elif isinstance(self.llm_model, Bedrock):
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
else:
raise ValueError("Embedding Model missing or not supported")
def _create_embedder(self, embedder_config: dict) -> object:
"""
Create an embedding model instance based on the configuration provided.
Args:
embedder_config (dict): Configuration parameters for the embedding model.
Returns:
object: An instance of the embedding model client.
Raises:
KeyError: If the model is not supported.
"""
# Instantiate the embedding model based on the model name
if "openai" in embedder_config["model"]:
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
elif "azure" in embedder_config["model"]:
return AzureOpenAIEmbeddings()
elif "ollama" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
try:
models_tokens["ollama"][embedder_config["model"]]
except KeyError:
raise KeyError("Model not supported")
return OllamaEmbeddings(**embedder_config)
elif "hugging_face" in embedder_config["model"]:
try:
models_tokens["hugging_face"][embedder_config["model"]]
except KeyError:
raise KeyError("Model not supported")
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
elif "bedrock" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
try:
models_tokens["bedrock"][embedder_config["model"]]
except KeyError:
raise KeyError("Model not supported")
return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
else:
raise ValueError(
"Model provided by the configuration not supported")
def get_state(self, key=None) -> dict: def get_state(self, key=None) -> dict:
""""" """""

View File

@ -87,31 +87,7 @@ class RAGNode(BaseNode):
if self.verbose: if self.verbose:
print("--- (updated chunks metadata) ---") print("--- (updated chunks metadata) ---")
# check if embedder_model is provided, if not use llm_model embeddings = self.embedder_model
embedding_model = self.embedder_model if self.embedder_model else self.llm_model
if isinstance(embedding_model, OpenAI):
embeddings = OpenAIEmbeddings(
api_key=embedding_model.openai_api_key)
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
embeddings = embedding_model
elif isinstance(embedding_model, AzureOpenAI):
embeddings = AzureOpenAIEmbeddings()
elif isinstance(embedding_model, Ollama):
# unwrap the kwargs from the model whihc is a dict
params = embedding_model._lc_kwargs
# remove streaming and temperature
params.pop("streaming", None)
params.pop("temperature", None)
embeddings = OllamaEmbeddings(**params)
elif isinstance(embedding_model, HuggingFace):
embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model)
elif isinstance(embedding_model, Bedrock):
embeddings = BedrockEmbeddings(
client=None, model_id=embedding_model.model_id)
else:
raise ValueError("Embedding Model missing or not supported")
retriever = FAISS.from_documents( retriever = FAISS.from_documents(
chunked_docs, embeddings).as_retriever() chunked_docs, embeddings).as_retriever()