feat: add gemini embeddings

This commit is contained in:
VinciGit00 2024-05-05 11:40:44 +02:00
parent da8c72ce13
commit 79daa4c112

View File

@ -1,16 +1,14 @@
""" """
AbstractGraph Module AbstractGraph Module
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from ..helpers import models_tokens from ..helpers import models_tokens
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
class AbstractGraph(ABC): class AbstractGraph(ABC):
@ -229,14 +227,11 @@ class AbstractGraph(ABC):
if 'model_instance' in embedder_config: if 'model_instance' in embedder_config:
return embedder_config['model_instance'] return embedder_config['model_instance']
# Instantiate the embedding model based on the model name # Instantiate the embedding model based on the model name
if "openai" in embedder_config["model"]: if "openai" in embedder_config["model"]:
return OpenAIEmbeddings(api_key=embedder_config["api_key"]) return OpenAIEmbeddings(api_key=embedder_config["api_key"])
elif "azure" in embedder_config["model"]: elif "azure" in embedder_config["model"]:
return AzureOpenAIEmbeddings() return AzureOpenAIEmbeddings()
elif "ollama" in embedder_config["model"]: elif "ollama" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1] embedder_config["model"] = embedder_config["model"].split("/")[-1]
try: try:
@ -244,14 +239,18 @@ class AbstractGraph(ABC):
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return OllamaEmbeddings(**embedder_config) return OllamaEmbeddings(**embedder_config)
elif "hugging_face" in embedder_config["model"]: elif "hugging_face" in embedder_config["model"]:
try: try:
models_tokens["hugging_face"][embedder_config["model"]] models_tokens["hugging_face"][embedder_config["model"]]
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported")from exc raise KeyError("Model not supported")from exc
return HuggingFaceHubEmbeddings(model=embedder_config["model"]) return HuggingFaceHubEmbeddings(model=embedder_config["model"])
elif "gemini" in embedder_config["model"]:
try:
models_tokens["gemini"][embedder_config["model"]]
except KeyError as exc:
raise KeyError("Model not supported")from exc
return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
elif "bedrock" in embedder_config["model"]: elif "bedrock" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1] embedder_config["model"] = embedder_config["model"].split("/")[-1]
try: try: