mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
feat: Fix bug for gemini case when embeddings config not passed
This commit is contained in:
parent
5e1d5db6da
commit
726de28898
@ -46,7 +46,7 @@ class AbstractGraph(ABC):
|
|||||||
self.source = source
|
self.source = source
|
||||||
self.config = config
|
self.config = config
|
||||||
self.llm_model = self._create_llm(config["llm"], chat=True)
|
self.llm_model = self._create_llm(config["llm"], chat=True)
|
||||||
self.embedder_model = self._create_default_embedder(
|
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
|
||||||
) if "embeddings" not in config else self._create_embedder(
|
) if "embeddings" not in config else self._create_embedder(
|
||||||
config["embeddings"])
|
config["embeddings"])
|
||||||
|
|
||||||
@ -91,6 +91,13 @@ class AbstractGraph(ABC):
|
|||||||
self.model_token = models_tokens['mistral'][llm.repo_id]
|
self.model_token = models_tokens['mistral'][llm.repo_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise KeyError("Model not supported")
|
raise KeyError("Model not supported")
|
||||||
|
|
||||||
|
elif 'Google' in str(type(llm)):
|
||||||
|
try:
|
||||||
|
if 'gemini' in llm.model:
|
||||||
|
self.model_token = models_tokens['gemini'][llm.model]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError("Model not supported")
|
||||||
|
|
||||||
def _create_llm(self, llm_config: dict, chat=False) -> object:
|
def _create_llm(self, llm_config: dict, chat=False) -> object:
|
||||||
"""
|
"""
|
||||||
@ -197,7 +204,7 @@ class AbstractGraph(ABC):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Model provided by the configuration not supported")
|
"Model provided by the configuration not supported")
|
||||||
|
|
||||||
def _create_default_embedder(self) -> object:
|
def _create_default_embedder(self, llm_config=None) -> object:
|
||||||
"""
|
"""
|
||||||
Create an embedding model instance based on the chosen llm model.
|
Create an embedding model instance based on the chosen llm model.
|
||||||
|
|
||||||
@ -207,6 +214,8 @@ class AbstractGraph(ABC):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the model is not supported.
|
ValueError: If the model is not supported.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(self.llm_model, Gemini):
|
||||||
|
return GoogleGenerativeAIEmbeddings(google_api_key=llm_config['api_key'], model="models/embedding-001")
|
||||||
if isinstance(self.llm_model, OpenAI):
|
if isinstance(self.llm_model, OpenAI):
|
||||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
|
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
|
||||||
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
|
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
|
||||||
@ -241,7 +250,6 @@ class AbstractGraph(ABC):
|
|||||||
Raises:
|
Raises:
|
||||||
KeyError: If the model is not supported.
|
KeyError: If the model is not supported.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if 'model_instance' in embedder_config:
|
if 'model_instance' in embedder_config:
|
||||||
return embedder_config['model_instance']
|
return embedder_config['model_instance']
|
||||||
# Instantiate the embedding model based on the model name
|
# Instantiate the embedding model based on the model name
|
||||||
|
|||||||
@ -26,6 +26,7 @@ models_tokens = {
|
|||||||
},
|
},
|
||||||
"gemini": {
|
"gemini": {
|
||||||
"gemini-pro": 128000,
|
"gemini-pro": 128000,
|
||||||
|
"models/embedding-001": 2048
|
||||||
},
|
},
|
||||||
|
|
||||||
"ollama": {
|
"ollama": {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user