mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
refactor: move embeddings code from AbstractGraph to RAGNode
Some checks are pending
/ build (push) Waiting to run
Some checks are pending
/ build (push) Waiting to run
This commit is contained in:
parent
bb73d916a1
commit
a94ebcde00
@ -7,15 +7,8 @@ from typing import Optional
|
|||||||
import uuid
|
import uuid
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain_community.chat_models import ChatOllama, ErnieBotChat
|
from langchain_community.chat_models import ErnieBotChat
|
||||||
from langchain_aws import BedrockEmbeddings, ChatBedrock
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||||
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
|
|
||||||
from langchain_community.embeddings import OllamaEmbeddings
|
|
||||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
|
|
||||||
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
|
|
||||||
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
|
|
||||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
|
|
||||||
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
|
|
||||||
from ..helpers import models_tokens
|
from ..helpers import models_tokens
|
||||||
@ -66,8 +59,6 @@ class AbstractGraph(ABC):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.llm_model = self._create_llm(config["llm"])
|
self.llm_model = self._create_llm(config["llm"])
|
||||||
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
|
|
||||||
config["embeddings"])
|
|
||||||
self.verbose = False if config is None else config.get(
|
self.verbose = False if config is None else config.get(
|
||||||
"verbose", False)
|
"verbose", False)
|
||||||
self.headless = True if config is None else config.get(
|
self.headless = True if config is None else config.get(
|
||||||
@ -237,116 +228,6 @@ class AbstractGraph(ABC):
|
|||||||
# Raise an error if the model did not match any of the previous cases
|
# Raise an error if the model did not match any of the previous cases
|
||||||
raise ValueError("Model provided by the configuration not supported")
|
raise ValueError("Model provided by the configuration not supported")
|
||||||
|
|
||||||
def _create_default_embedder(self, llm_config=None) -> object:
|
|
||||||
"""
|
|
||||||
Create an embedding model instance based on the chosen llm model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
object: An instance of the embedding model client.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the model is not supported.
|
|
||||||
"""
|
|
||||||
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
|
|
||||||
return GoogleGenerativeAIEmbeddings(
|
|
||||||
google_api_key=llm_config["api_key"], model="models/embedding-001"
|
|
||||||
)
|
|
||||||
if isinstance(self.llm_model, ChatOpenAI):
|
|
||||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
|
|
||||||
base_url=self.llm_model.openai_api_base)
|
|
||||||
elif isinstance(self.llm_model, DeepSeek):
|
|
||||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
|
|
||||||
elif isinstance(self.llm_model, ChatVertexAI):
|
|
||||||
return VertexAIEmbeddings()
|
|
||||||
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
|
|
||||||
return self.llm_model
|
|
||||||
elif isinstance(self.llm_model, AzureChatOpenAI):
|
|
||||||
return AzureOpenAIEmbeddings()
|
|
||||||
elif isinstance(self.llm_model, ChatFireworks):
|
|
||||||
return FireworksEmbeddings(model=self.llm_model.model_name)
|
|
||||||
elif isinstance(self.llm_model, ChatNVIDIA):
|
|
||||||
return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
|
||||||
elif isinstance(self.llm_model, ChatOllama):
|
|
||||||
# unwrap the kwargs from the model whihc is a dict
|
|
||||||
params = self.llm_model._lc_kwargs
|
|
||||||
# remove streaming and temperature
|
|
||||||
params.pop("streaming", None)
|
|
||||||
params.pop("temperature", None)
|
|
||||||
|
|
||||||
return OllamaEmbeddings(**params)
|
|
||||||
elif isinstance(self.llm_model, ChatHuggingFace):
|
|
||||||
return HuggingFaceEmbeddings(model=self.llm_model.model)
|
|
||||||
elif isinstance(self.llm_model, ChatBedrock):
|
|
||||||
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
|
|
||||||
else:
|
|
||||||
raise ValueError("Embedding Model missing or not supported")
|
|
||||||
|
|
||||||
def _create_embedder(self, embedder_config: dict) -> object:
|
|
||||||
"""
|
|
||||||
Create an embedding model instance based on the configuration provided.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedder_config (dict): Configuration parameters for the embedding model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
object: An instance of the embedding model client.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If the model is not supported.
|
|
||||||
"""
|
|
||||||
embedder_params = {**embedder_config}
|
|
||||||
if "model_instance" in embedder_config:
|
|
||||||
return embedder_params["model_instance"]
|
|
||||||
# Instantiate the embedding model based on the model name
|
|
||||||
if "openai" in embedder_params["model"]:
|
|
||||||
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
|
|
||||||
if "azure" in embedder_params["model"]:
|
|
||||||
return AzureOpenAIEmbeddings()
|
|
||||||
if "nvidia" in embedder_params["model"]:
|
|
||||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
|
||||||
try:
|
|
||||||
models_tokens["nvidia"][embedder_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
return NVIDIAEmbeddings(model=embedder_params["model"],
|
|
||||||
nvidia_api_key=embedder_params["api_key"])
|
|
||||||
if "ollama" in embedder_params["model"]:
|
|
||||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
|
||||||
try:
|
|
||||||
models_tokens["ollama"][embedder_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
return OllamaEmbeddings(**embedder_params)
|
|
||||||
if "hugging_face" in embedder_params["model"]:
|
|
||||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
|
||||||
try:
|
|
||||||
models_tokens["hugging_face"][embedder_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
return HuggingFaceEmbeddings(model=embedder_params["model"])
|
|
||||||
if "fireworks" in embedder_params["model"]:
|
|
||||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
|
||||||
try:
|
|
||||||
models_tokens["fireworks"][embedder_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
return FireworksEmbeddings(model=embedder_params["model"])
|
|
||||||
if "gemini" in embedder_params["model"]:
|
|
||||||
try:
|
|
||||||
models_tokens["gemini"][embedder_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
|
|
||||||
if "bedrock" in embedder_params["model"]:
|
|
||||||
embedder_params["model"] = embedder_params["model"].split("/")[-1]
|
|
||||||
client = embedder_params.get("client", None)
|
|
||||||
try:
|
|
||||||
models_tokens["bedrock"][embedder_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
|
|
||||||
|
|
||||||
raise ValueError("Model provided by the configuration not supported")
|
|
||||||
|
|
||||||
def get_state(self, key=None) -> dict:
|
def get_state(self, key=None) -> dict:
|
||||||
""" ""
|
""" ""
|
||||||
|
|||||||
@ -14,8 +14,20 @@ from langchain.retrievers.document_compressors import (
|
|||||||
from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
||||||
from langchain_community.vectorstores import FAISS
|
from langchain_community.vectorstores import FAISS
|
||||||
|
|
||||||
|
from langchain_community.chat_models import ChatOllama
|
||||||
|
from langchain_aws import BedrockEmbeddings, ChatBedrock
|
||||||
|
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
|
||||||
|
from langchain_community.embeddings import OllamaEmbeddings
|
||||||
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
|
||||||
|
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
|
||||||
|
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
|
||||||
|
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
|
||||||
|
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
|
||||||
|
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
|
from ..helpers import models_tokens
|
||||||
|
from ..models import DeepSeek
|
||||||
|
|
||||||
|
|
||||||
class RAGNode(BaseNode):
|
class RAGNode(BaseNode):
|
||||||
@ -95,10 +107,21 @@ class RAGNode(BaseNode):
|
|||||||
self.logger.info("--- (updated chunks metadata) ---")
|
self.logger.info("--- (updated chunks metadata) ---")
|
||||||
|
|
||||||
# check if embedder_model is provided, if not use llm_model
|
# check if embedder_model is provided, if not use llm_model
|
||||||
self.embedder_model = (
|
if self.embedder_model is not None:
|
||||||
self.embedder_model if self.embedder_model else self.llm_model
|
embeddings = self.embedder_model
|
||||||
)
|
elif 'embeddings' in self.node_config:
|
||||||
embeddings = self.embedder_model
|
try:
|
||||||
|
embeddings = self._create_embedder(self.node_config['embedder_config'])
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
embeddings = self._create_default_embedder()
|
||||||
|
self.embedder_model = embeddings
|
||||||
|
except ValueError:
|
||||||
|
embeddings = self.llm_model
|
||||||
|
self.embedder_model = self.llm_model
|
||||||
|
else:
|
||||||
|
embeddings = self.llm_model
|
||||||
|
self.embedder_model = self.llm_model
|
||||||
|
|
||||||
folder_name = self.node_config.get("cache_path", "cache")
|
folder_name = self.node_config.get("cache_path", "cache")
|
||||||
|
|
||||||
@ -141,3 +164,116 @@ class RAGNode(BaseNode):
|
|||||||
|
|
||||||
state.update({self.output[0]: compressed_docs})
|
state.update({self.output[0]: compressed_docs})
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def _create_default_embedder(self, llm_config=None) -> object:
|
||||||
|
"""
|
||||||
|
Create an embedding model instance based on the chosen llm model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: An instance of the embedding model client.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the model is not supported.
|
||||||
|
"""
|
||||||
|
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
|
||||||
|
return GoogleGenerativeAIEmbeddings(
|
||||||
|
google_api_key=llm_config["api_key"], model="models/embedding-001"
|
||||||
|
)
|
||||||
|
if isinstance(self.llm_model, ChatOpenAI):
|
||||||
|
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
|
||||||
|
base_url=self.llm_model.openai_api_base)
|
||||||
|
elif isinstance(self.llm_model, DeepSeek):
|
||||||
|
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
|
||||||
|
elif isinstance(self.llm_model, ChatVertexAI):
|
||||||
|
return VertexAIEmbeddings()
|
||||||
|
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
|
||||||
|
return self.llm_model
|
||||||
|
elif isinstance(self.llm_model, AzureChatOpenAI):
|
||||||
|
return AzureOpenAIEmbeddings()
|
||||||
|
elif isinstance(self.llm_model, ChatFireworks):
|
||||||
|
return FireworksEmbeddings(model=self.llm_model.model_name)
|
||||||
|
elif isinstance(self.llm_model, ChatNVIDIA):
|
||||||
|
return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
||||||
|
elif isinstance(self.llm_model, ChatOllama):
|
||||||
|
# unwrap the kwargs from the model whihc is a dict
|
||||||
|
params = self.llm_model._lc_kwargs
|
||||||
|
# remove streaming and temperature
|
||||||
|
params.pop("streaming", None)
|
||||||
|
params.pop("temperature", None)
|
||||||
|
|
||||||
|
return OllamaEmbeddings(**params)
|
||||||
|
elif isinstance(self.llm_model, ChatHuggingFace):
|
||||||
|
return HuggingFaceEmbeddings(model=self.llm_model.model)
|
||||||
|
elif isinstance(self.llm_model, ChatBedrock):
|
||||||
|
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
|
||||||
|
else:
|
||||||
|
raise ValueError("Embedding Model missing or not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def _create_embedder(self, embedder_config: dict) -> object:
|
||||||
|
"""
|
||||||
|
Create an embedding model instance based on the configuration provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedder_config (dict): Configuration parameters for the embedding model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: An instance of the embedding model client.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the model is not supported.
|
||||||
|
"""
|
||||||
|
embedder_params = {**embedder_config}
|
||||||
|
if "model_instance" in embedder_config:
|
||||||
|
return embedder_params["model_instance"]
|
||||||
|
# Instantiate the embedding model based on the model name
|
||||||
|
if "openai" in embedder_params["model"]:
|
||||||
|
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
|
||||||
|
if "azure" in embedder_params["model"]:
|
||||||
|
return AzureOpenAIEmbeddings()
|
||||||
|
if "nvidia" in embedder_params["model"]:
|
||||||
|
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||||
|
try:
|
||||||
|
models_tokens["nvidia"][embedder_params["model"]]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError("Model not supported") from exc
|
||||||
|
return NVIDIAEmbeddings(model=embedder_params["model"],
|
||||||
|
nvidia_api_key=embedder_params["api_key"])
|
||||||
|
if "ollama" in embedder_params["model"]:
|
||||||
|
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||||
|
try:
|
||||||
|
models_tokens["ollama"][embedder_params["model"]]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError("Model not supported") from exc
|
||||||
|
return OllamaEmbeddings(**embedder_params)
|
||||||
|
if "hugging_face" in embedder_params["model"]:
|
||||||
|
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||||
|
try:
|
||||||
|
models_tokens["hugging_face"][embedder_params["model"]]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError("Model not supported") from exc
|
||||||
|
return HuggingFaceEmbeddings(model=embedder_params["model"])
|
||||||
|
if "fireworks" in embedder_params["model"]:
|
||||||
|
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||||
|
try:
|
||||||
|
models_tokens["fireworks"][embedder_params["model"]]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError("Model not supported") from exc
|
||||||
|
return FireworksEmbeddings(model=embedder_params["model"])
|
||||||
|
if "gemini" in embedder_params["model"]:
|
||||||
|
try:
|
||||||
|
models_tokens["gemini"][embedder_params["model"]]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError("Model not supported") from exc
|
||||||
|
return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
|
||||||
|
if "bedrock" in embedder_params["model"]:
|
||||||
|
embedder_params["model"] = embedder_params["model"].split("/")[-1]
|
||||||
|
client = embedder_params.get("client", None)
|
||||||
|
try:
|
||||||
|
models_tokens["bedrock"][embedder_params["model"]]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError("Model not supported") from exc
|
||||||
|
return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
|
||||||
|
|
||||||
|
raise ValueError("Model provided by the configuration not supported")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user