feat: Fix bug for gemini case when embeddings config not passed

This commit is contained in:
Shubham Kamboj 2024-05-06 15:40:03 +05:30
parent 5e1d5db6da
commit 726de28898
2 changed files with 12 additions and 3 deletions

View File

@ -46,7 +46,7 @@ class AbstractGraph(ABC):
self.source = source
self.config = config
self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self._create_default_embedder(
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
) if "embeddings" not in config else self._create_embedder(
config["embeddings"])
@ -91,6 +91,13 @@ class AbstractGraph(ABC):
self.model_token = models_tokens['mistral'][llm.repo_id]
except KeyError:
raise KeyError("Model not supported")
elif 'Google' in str(type(llm)):
try:
if 'gemini' in llm.model:
self.model_token = models_tokens['gemini'][llm.model]
except KeyError:
raise KeyError("Model not supported")
def _create_llm(self, llm_config: dict, chat=False) -> object:
"""
@ -197,7 +204,7 @@ class AbstractGraph(ABC):
raise ValueError(
"Model provided by the configuration not supported")
def _create_default_embedder(self) -> object:
def _create_default_embedder(self, llm_config=None) -> object:
"""
Create an embedding model instance based on the chosen llm model.
@ -207,6 +214,8 @@ class AbstractGraph(ABC):
Raises:
ValueError: If the model is not supported.
"""
if isinstance(self.llm_model, Gemini):
return GoogleGenerativeAIEmbeddings(google_api_key=llm_config['api_key'], model="models/embedding-001")
if isinstance(self.llm_model, OpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
@ -241,7 +250,6 @@ class AbstractGraph(ABC):
Raises:
KeyError: If the model is not supported.
"""
if 'model_instance' in embedder_config:
return embedder_config['model_instance']
# Instantiate the embedding model based on the model name

View File

@ -26,6 +26,7 @@ models_tokens = {
},
"gemini": {
"gemini-pro": 128000,
"models/embedding-001": 2048
},
"ollama": {