mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
refactor: move embeddings code from AbstractGraph to RAGNode
Some checks are pending
/ build (push) Waiting to run
Some checks are pending
/ build (push) Waiting to run
This commit is contained in:
parent
bb73d916a1
commit
a94ebcde00
@ -7,15 +7,8 @@ from typing import Optional
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_community.chat_models import ChatOllama, ErnieBotChat
|
||||
from langchain_aws import BedrockEmbeddings, ChatBedrock
|
||||
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
|
||||
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
|
||||
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
|
||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
|
||||
from langchain_community.chat_models import ErnieBotChat
|
||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
from ..helpers import models_tokens
|
||||
@ -66,8 +59,6 @@ class AbstractGraph(ABC):
|
||||
self.config = config
|
||||
self.schema = schema
|
||||
self.llm_model = self._create_llm(config["llm"])
|
||||
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
|
||||
config["embeddings"])
|
||||
self.verbose = False if config is None else config.get(
|
||||
"verbose", False)
|
||||
self.headless = True if config is None else config.get(
|
||||
@ -237,116 +228,6 @@ class AbstractGraph(ABC):
|
||||
# Raise an error if the model did not match any of the previous cases
|
||||
raise ValueError("Model provided by the configuration not supported")
|
||||
|
||||
def _create_default_embedder(self, llm_config=None) -> object:
|
||||
"""
|
||||
Create an embedding model instance based on the chosen llm model.
|
||||
|
||||
Returns:
|
||||
object: An instance of the embedding model client.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not supported.
|
||||
"""
|
||||
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
|
||||
return GoogleGenerativeAIEmbeddings(
|
||||
google_api_key=llm_config["api_key"], model="models/embedding-001"
|
||||
)
|
||||
if isinstance(self.llm_model, ChatOpenAI):
|
||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
|
||||
base_url=self.llm_model.openai_api_base)
|
||||
elif isinstance(self.llm_model, DeepSeek):
|
||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
|
||||
elif isinstance(self.llm_model, ChatVertexAI):
|
||||
return VertexAIEmbeddings()
|
||||
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
|
||||
return self.llm_model
|
||||
elif isinstance(self.llm_model, AzureChatOpenAI):
|
||||
return AzureOpenAIEmbeddings()
|
||||
elif isinstance(self.llm_model, ChatFireworks):
|
||||
return FireworksEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, ChatNVIDIA):
|
||||
return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, ChatOllama):
|
||||
# unwrap the kwargs from the model whihc is a dict
|
||||
params = self.llm_model._lc_kwargs
|
||||
# remove streaming and temperature
|
||||
params.pop("streaming", None)
|
||||
params.pop("temperature", None)
|
||||
|
||||
return OllamaEmbeddings(**params)
|
||||
elif isinstance(self.llm_model, ChatHuggingFace):
|
||||
return HuggingFaceEmbeddings(model=self.llm_model.model)
|
||||
elif isinstance(self.llm_model, ChatBedrock):
|
||||
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
|
||||
else:
|
||||
raise ValueError("Embedding Model missing or not supported")
|
||||
|
||||
def _create_embedder(self, embedder_config: dict) -> object:
|
||||
"""
|
||||
Create an embedding model instance based on the configuration provided.
|
||||
|
||||
Args:
|
||||
embedder_config (dict): Configuration parameters for the embedding model.
|
||||
|
||||
Returns:
|
||||
object: An instance of the embedding model client.
|
||||
|
||||
Raises:
|
||||
KeyError: If the model is not supported.
|
||||
"""
|
||||
embedder_params = {**embedder_config}
|
||||
if "model_instance" in embedder_config:
|
||||
return embedder_params["model_instance"]
|
||||
# Instantiate the embedding model based on the model name
|
||||
if "openai" in embedder_params["model"]:
|
||||
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
|
||||
if "azure" in embedder_params["model"]:
|
||||
return AzureOpenAIEmbeddings()
|
||||
if "nvidia" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["nvidia"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return NVIDIAEmbeddings(model=embedder_params["model"],
|
||||
nvidia_api_key=embedder_params["api_key"])
|
||||
if "ollama" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["ollama"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return OllamaEmbeddings(**embedder_params)
|
||||
if "hugging_face" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["hugging_face"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return HuggingFaceEmbeddings(model=embedder_params["model"])
|
||||
if "fireworks" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["fireworks"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return FireworksEmbeddings(model=embedder_params["model"])
|
||||
if "gemini" in embedder_params["model"]:
|
||||
try:
|
||||
models_tokens["gemini"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
|
||||
if "bedrock" in embedder_params["model"]:
|
||||
embedder_params["model"] = embedder_params["model"].split("/")[-1]
|
||||
client = embedder_params.get("client", None)
|
||||
try:
|
||||
models_tokens["bedrock"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
|
||||
|
||||
raise ValueError("Model provided by the configuration not supported")
|
||||
|
||||
def get_state(self, key=None) -> dict:
|
||||
""" ""
|
||||
|
||||
@ -14,8 +14,20 @@ from langchain.retrievers.document_compressors import (
|
||||
from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
||||
from langchain_community.vectorstores import FAISS
|
||||
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from langchain_aws import BedrockEmbeddings, ChatBedrock
|
||||
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
|
||||
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
|
||||
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
|
||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
from .base_node import BaseNode
|
||||
from ..helpers import models_tokens
|
||||
from ..models import DeepSeek
|
||||
|
||||
|
||||
class RAGNode(BaseNode):
|
||||
@ -95,10 +107,21 @@ class RAGNode(BaseNode):
|
||||
self.logger.info("--- (updated chunks metadata) ---")
|
||||
|
||||
# check if embedder_model is provided, if not use llm_model
|
||||
self.embedder_model = (
|
||||
self.embedder_model if self.embedder_model else self.llm_model
|
||||
)
|
||||
embeddings = self.embedder_model
|
||||
if self.embedder_model is not None:
|
||||
embeddings = self.embedder_model
|
||||
elif 'embeddings' in self.node_config:
|
||||
try:
|
||||
embeddings = self._create_embedder(self.node_config['embedder_config'])
|
||||
except Exception:
|
||||
try:
|
||||
embeddings = self._create_default_embedder()
|
||||
self.embedder_model = embeddings
|
||||
except ValueError:
|
||||
embeddings = self.llm_model
|
||||
self.embedder_model = self.llm_model
|
||||
else:
|
||||
embeddings = self.llm_model
|
||||
self.embedder_model = self.llm_model
|
||||
|
||||
folder_name = self.node_config.get("cache_path", "cache")
|
||||
|
||||
@ -141,3 +164,116 @@ class RAGNode(BaseNode):
|
||||
|
||||
state.update({self.output[0]: compressed_docs})
|
||||
return state
|
||||
|
||||
|
||||
def _create_default_embedder(self, llm_config=None) -> object:
|
||||
"""
|
||||
Create an embedding model instance based on the chosen llm model.
|
||||
|
||||
Returns:
|
||||
object: An instance of the embedding model client.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not supported.
|
||||
"""
|
||||
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
|
||||
return GoogleGenerativeAIEmbeddings(
|
||||
google_api_key=llm_config["api_key"], model="models/embedding-001"
|
||||
)
|
||||
if isinstance(self.llm_model, ChatOpenAI):
|
||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
|
||||
base_url=self.llm_model.openai_api_base)
|
||||
elif isinstance(self.llm_model, DeepSeek):
|
||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
|
||||
elif isinstance(self.llm_model, ChatVertexAI):
|
||||
return VertexAIEmbeddings()
|
||||
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
|
||||
return self.llm_model
|
||||
elif isinstance(self.llm_model, AzureChatOpenAI):
|
||||
return AzureOpenAIEmbeddings()
|
||||
elif isinstance(self.llm_model, ChatFireworks):
|
||||
return FireworksEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, ChatNVIDIA):
|
||||
return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, ChatOllama):
|
||||
# unwrap the kwargs from the model whihc is a dict
|
||||
params = self.llm_model._lc_kwargs
|
||||
# remove streaming and temperature
|
||||
params.pop("streaming", None)
|
||||
params.pop("temperature", None)
|
||||
|
||||
return OllamaEmbeddings(**params)
|
||||
elif isinstance(self.llm_model, ChatHuggingFace):
|
||||
return HuggingFaceEmbeddings(model=self.llm_model.model)
|
||||
elif isinstance(self.llm_model, ChatBedrock):
|
||||
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
|
||||
else:
|
||||
raise ValueError("Embedding Model missing or not supported")
|
||||
|
||||
|
||||
def _create_embedder(self, embedder_config: dict) -> object:
|
||||
"""
|
||||
Create an embedding model instance based on the configuration provided.
|
||||
|
||||
Args:
|
||||
embedder_config (dict): Configuration parameters for the embedding model.
|
||||
|
||||
Returns:
|
||||
object: An instance of the embedding model client.
|
||||
|
||||
Raises:
|
||||
KeyError: If the model is not supported.
|
||||
"""
|
||||
embedder_params = {**embedder_config}
|
||||
if "model_instance" in embedder_config:
|
||||
return embedder_params["model_instance"]
|
||||
# Instantiate the embedding model based on the model name
|
||||
if "openai" in embedder_params["model"]:
|
||||
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
|
||||
if "azure" in embedder_params["model"]:
|
||||
return AzureOpenAIEmbeddings()
|
||||
if "nvidia" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["nvidia"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return NVIDIAEmbeddings(model=embedder_params["model"],
|
||||
nvidia_api_key=embedder_params["api_key"])
|
||||
if "ollama" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["ollama"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return OllamaEmbeddings(**embedder_params)
|
||||
if "hugging_face" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["hugging_face"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return HuggingFaceEmbeddings(model=embedder_params["model"])
|
||||
if "fireworks" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["fireworks"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return FireworksEmbeddings(model=embedder_params["model"])
|
||||
if "gemini" in embedder_params["model"]:
|
||||
try:
|
||||
models_tokens["gemini"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
|
||||
if "bedrock" in embedder_params["model"]:
|
||||
embedder_params["model"] = embedder_params["model"].split("/")[-1]
|
||||
client = embedder_params.get("client", None)
|
||||
try:
|
||||
models_tokens["bedrock"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
|
||||
|
||||
raise ValueError("Model provided by the configuration not supported")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user