mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
feat: add gemini embeddings
This commit is contained in:
parent
da8c72ce13
commit
79daa4c112
@ -1,16 +1,14 @@
|
||||
"""
|
||||
AbstractGraph Module
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
|
||||
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
|
||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
|
||||
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
|
||||
from ..helpers import models_tokens
|
||||
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
|
||||
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||
|
||||
|
||||
class AbstractGraph(ABC):
|
||||
@ -69,7 +67,7 @@ class AbstractGraph(ABC):
|
||||
self.model_token = models_tokens["azure"][llm.model_name]
|
||||
except KeyError:
|
||||
raise KeyError("Model not supported")
|
||||
|
||||
|
||||
elif 'HuggingFaceEndpoint' in str(type(llm)):
|
||||
if 'mistral' in llm.repo_id:
|
||||
try:
|
||||
@ -229,14 +227,11 @@ class AbstractGraph(ABC):
|
||||
|
||||
if 'model_instance' in embedder_config:
|
||||
return embedder_config['model_instance']
|
||||
|
||||
# Instantiate the embedding model based on the model name
|
||||
if "openai" in embedder_config["model"]:
|
||||
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
|
||||
|
||||
elif "azure" in embedder_config["model"]:
|
||||
return AzureOpenAIEmbeddings()
|
||||
|
||||
elif "ollama" in embedder_config["model"]:
|
||||
embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
||||
try:
|
||||
@ -244,14 +239,18 @@ class AbstractGraph(ABC):
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return OllamaEmbeddings(**embedder_config)
|
||||
|
||||
elif "hugging_face" in embedder_config["model"]:
|
||||
try:
|
||||
models_tokens["hugging_face"][embedder_config["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported")from exc
|
||||
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
|
||||
|
||||
elif "gemini" in embedder_config["model"]:
|
||||
try:
|
||||
models_tokens["gemini"][embedder_config["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported")from exc
|
||||
return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
|
||||
elif "bedrock" in embedder_config["model"]:
|
||||
embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
||||
try:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user