mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
feat: add gemini embeddings
This commit is contained in:
parent
da8c72ce13
commit
79daa4c112
@ -1,16 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
AbstractGraph Module
|
AbstractGraph Module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
|
|
||||||
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
|
|
||||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||||
|
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
|
||||||
from ..helpers import models_tokens
|
from ..helpers import models_tokens
|
||||||
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
|
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
|
||||||
|
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
|
||||||
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
class AbstractGraph(ABC):
|
class AbstractGraph(ABC):
|
||||||
@ -229,14 +227,11 @@ class AbstractGraph(ABC):
|
|||||||
|
|
||||||
if 'model_instance' in embedder_config:
|
if 'model_instance' in embedder_config:
|
||||||
return embedder_config['model_instance']
|
return embedder_config['model_instance']
|
||||||
|
|
||||||
# Instantiate the embedding model based on the model name
|
# Instantiate the embedding model based on the model name
|
||||||
if "openai" in embedder_config["model"]:
|
if "openai" in embedder_config["model"]:
|
||||||
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
|
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
|
||||||
|
|
||||||
elif "azure" in embedder_config["model"]:
|
elif "azure" in embedder_config["model"]:
|
||||||
return AzureOpenAIEmbeddings()
|
return AzureOpenAIEmbeddings()
|
||||||
|
|
||||||
elif "ollama" in embedder_config["model"]:
|
elif "ollama" in embedder_config["model"]:
|
||||||
embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
||||||
try:
|
try:
|
||||||
@ -244,14 +239,18 @@ class AbstractGraph(ABC):
|
|||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise KeyError("Model not supported") from exc
|
raise KeyError("Model not supported") from exc
|
||||||
return OllamaEmbeddings(**embedder_config)
|
return OllamaEmbeddings(**embedder_config)
|
||||||
|
|
||||||
elif "hugging_face" in embedder_config["model"]:
|
elif "hugging_face" in embedder_config["model"]:
|
||||||
try:
|
try:
|
||||||
models_tokens["hugging_face"][embedder_config["model"]]
|
models_tokens["hugging_face"][embedder_config["model"]]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise KeyError("Model not supported")from exc
|
raise KeyError("Model not supported")from exc
|
||||||
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
|
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
|
||||||
|
elif "gemini" in embedder_config["model"]:
|
||||||
|
try:
|
||||||
|
models_tokens["gemini"][embedder_config["model"]]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError("Model not supported")from exc
|
||||||
|
return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
|
||||||
elif "bedrock" in embedder_config["model"]:
|
elif "bedrock" in embedder_config["model"]:
|
||||||
embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user