From 726de288982700dab8ab9f22af8e26f01c6198a7 Mon Sep 17 00:00:00 2001 From: Shubham Kamboj Date: Mon, 6 May 2024 15:40:03 +0530 Subject: [PATCH] feat: Fix bug for gemini case when embeddings config not passed --- scrapegraphai/graphs/abstract_graph.py | 14 +++++++++++--- scrapegraphai/helpers/models_tokens.py | 1 + 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 089b0f95..096ce84e 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -46,7 +46,7 @@ class AbstractGraph(ABC): self.source = source self.config = config self.llm_model = self._create_llm(config["llm"], chat=True) - self.embedder_model = self._create_default_embedder( + self.embedder_model = self._create_default_embedder(llm_config=config["llm"] ) if "embeddings" not in config else self._create_embedder( config["embeddings"]) @@ -91,6 +91,13 @@ class AbstractGraph(ABC): self.model_token = models_tokens['mistral'][llm.repo_id] except KeyError: raise KeyError("Model not supported") + + elif 'Google' in str(type(llm)): + try: + if 'gemini' in llm.model: + self.model_token = models_tokens['gemini'][llm.model] + except KeyError: + raise KeyError("Model not supported") def _create_llm(self, llm_config: dict, chat=False) -> object: """ @@ -197,7 +204,7 @@ class AbstractGraph(ABC): raise ValueError( "Model provided by the configuration not supported") - def _create_default_embedder(self) -> object: + def _create_default_embedder(self, llm_config=None) -> object: """ Create an embedding model instance based on the chosen llm model. @@ -207,6 +214,8 @@ class AbstractGraph(ABC): Raises: ValueError: If the model is not supported. """ + if isinstance(self.llm_model, Gemini): + return GoogleGenerativeAIEmbeddings(google_api_key=llm_config['api_key'], model="models/embedding-001") if isinstance(self.llm_model, OpenAI): return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) elif isinstance(self.llm_model, AzureOpenAIEmbeddings): @@ -241,7 +250,6 @@ class AbstractGraph(ABC): Raises: KeyError: If the model is not supported. """ - if 'model_instance' in embedder_config: return embedder_config['model_instance'] # Instantiate the embedding model based on the model name diff --git a/scrapegraphai/helpers/models_tokens.py b/scrapegraphai/helpers/models_tokens.py index 121ae63c..ed73285e 100644 --- a/scrapegraphai/helpers/models_tokens.py +++ b/scrapegraphai/helpers/models_tokens.py @@ -26,6 +26,7 @@ models_tokens = { }, "gemini": { "gemini-pro": 128000, + "models/embedding-001": 2048 }, "ollama": {