diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 306901e8..4ed08057 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -7,15 +7,8 @@ from typing import Optional import uuid from pydantic import BaseModel -from langchain_community.chat_models import ChatOllama, ErnieBotChat -from langchain_aws import BedrockEmbeddings, ChatBedrock -from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings -from langchain_community.embeddings import OllamaEmbeddings -from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI -from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings -from langchain_fireworks import FireworksEmbeddings, ChatFireworks -from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI -from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA +from langchain_community.chat_models import ErnieBotChat +from langchain_nvidia_ai_endpoints import ChatNVIDIA from langchain.chat_models import init_chat_model from ..helpers import models_tokens @@ -66,8 +59,6 @@ class AbstractGraph(ABC): self.config = config self.schema = schema self.llm_model = self._create_llm(config["llm"]) - self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder( - config["embeddings"]) self.verbose = False if config is None else config.get( "verbose", False) self.headless = True if config is None else config.get( @@ -237,116 +228,6 @@ class AbstractGraph(ABC): # Raise an error if the model did not match any of the previous cases raise ValueError("Model provided by the configuration not supported") - def _create_default_embedder(self, llm_config=None) -> object: - """ - Create an embedding model instance based on the chosen llm model. - - Returns: - object: An instance of the embedding model client. - - Raises: - ValueError: If the model is not supported. - """ - if isinstance(self.llm_model, ChatGoogleGenerativeAI): - return GoogleGenerativeAIEmbeddings( - google_api_key=llm_config["api_key"], model="models/embedding-001" - ) - if isinstance(self.llm_model, ChatOpenAI): - return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, - base_url=self.llm_model.openai_api_base) - elif isinstance(self.llm_model, DeepSeek): - return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) - elif isinstance(self.llm_model, ChatVertexAI): - return VertexAIEmbeddings() - elif isinstance(self.llm_model, AzureOpenAIEmbeddings): - return self.llm_model - elif isinstance(self.llm_model, AzureChatOpenAI): - return AzureOpenAIEmbeddings() - elif isinstance(self.llm_model, ChatFireworks): - return FireworksEmbeddings(model=self.llm_model.model_name) - elif isinstance(self.llm_model, ChatNVIDIA): - return NVIDIAEmbeddings(model=self.llm_model.model_name) - elif isinstance(self.llm_model, ChatOllama): - # unwrap the kwargs from the model whihc is a dict - params = self.llm_model._lc_kwargs - # remove streaming and temperature - params.pop("streaming", None) - params.pop("temperature", None) - - return OllamaEmbeddings(**params) - elif isinstance(self.llm_model, ChatHuggingFace): - return HuggingFaceEmbeddings(model=self.llm_model.model) - elif isinstance(self.llm_model, ChatBedrock): - return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id) - else: - raise ValueError("Embedding Model missing or not supported") - - def _create_embedder(self, embedder_config: dict) -> object: - """ - Create an embedding model instance based on the configuration provided. - - Args: - embedder_config (dict): Configuration parameters for the embedding model. - - Returns: - object: An instance of the embedding model client. - - Raises: - KeyError: If the model is not supported. - """ - embedder_params = {**embedder_config} - if "model_instance" in embedder_config: - return embedder_params["model_instance"] - # Instantiate the embedding model based on the model name - if "openai" in embedder_params["model"]: - return OpenAIEmbeddings(api_key=embedder_params["api_key"]) - if "azure" in embedder_params["model"]: - return AzureOpenAIEmbeddings() - if "nvidia" in embedder_params["model"]: - embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) - try: - models_tokens["nvidia"][embedder_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - return NVIDIAEmbeddings(model=embedder_params["model"], - nvidia_api_key=embedder_params["api_key"]) - if "ollama" in embedder_params["model"]: - embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) - try: - models_tokens["ollama"][embedder_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - return OllamaEmbeddings(**embedder_params) - if "hugging_face" in embedder_params["model"]: - embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) - try: - models_tokens["hugging_face"][embedder_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - return HuggingFaceEmbeddings(model=embedder_params["model"]) - if "fireworks" in embedder_params["model"]: - embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) - try: - models_tokens["fireworks"][embedder_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - return FireworksEmbeddings(model=embedder_params["model"]) - if "gemini" in embedder_params["model"]: - try: - models_tokens["gemini"][embedder_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - return GoogleGenerativeAIEmbeddings(model=embedder_params["model"]) - if "bedrock" in embedder_params["model"]: - embedder_params["model"] = embedder_params["model"].split("/")[-1] - client = embedder_params.get("client", None) - try: - models_tokens["bedrock"][embedder_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - return BedrockEmbeddings(client=client, model_id=embedder_params["model"]) - - raise ValueError("Model provided by the configuration not supported") def get_state(self, key=None) -> dict: """ "" diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index a4f58191..952daa6c 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -14,8 +14,20 @@ from langchain.retrievers.document_compressors import ( from langchain_community.document_transformers import EmbeddingsRedundantFilter from langchain_community.vectorstores import FAISS +from langchain_community.chat_models import ChatOllama +from langchain_aws import BedrockEmbeddings, ChatBedrock +from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings +from langchain_community.embeddings import OllamaEmbeddings +from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings +from langchain_fireworks import FireworksEmbeddings, ChatFireworks +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI +from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA + from ..utils.logging import get_logger from .base_node import BaseNode +from ..helpers import models_tokens +from ..models import DeepSeek class RAGNode(BaseNode): @@ -95,10 +107,21 @@ class RAGNode(BaseNode): self.logger.info("--- (updated chunks metadata) ---") # check if embedder_model is provided, if not use llm_model - self.embedder_model = ( - self.embedder_model if self.embedder_model else self.llm_model - ) - embeddings = self.embedder_model + if self.embedder_model is not None: + embeddings = self.embedder_model + elif 'embeddings' in self.node_config: + try: + embeddings = self._create_embedder(self.node_config['embedder_config']) + except Exception: + try: + embeddings = self._create_default_embedder() + self.embedder_model = embeddings + except ValueError: + embeddings = self.llm_model + self.embedder_model = self.llm_model + else: + embeddings = self.llm_model + self.embedder_model = self.llm_model folder_name = self.node_config.get("cache_path", "cache") @@ -141,3 +164,116 @@ class RAGNode(BaseNode): state.update({self.output[0]: compressed_docs}) return state + + + def _create_default_embedder(self, llm_config=None) -> object: + """ + Create an embedding model instance based on the chosen llm model. + + Returns: + object: An instance of the embedding model client. + + Raises: + ValueError: If the model is not supported. + """ + if isinstance(self.llm_model, ChatGoogleGenerativeAI): + return GoogleGenerativeAIEmbeddings( + google_api_key=llm_config["api_key"], model="models/embedding-001" + ) + if isinstance(self.llm_model, ChatOpenAI): + return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, + base_url=self.llm_model.openai_api_base) + elif isinstance(self.llm_model, DeepSeek): + return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) + elif isinstance(self.llm_model, ChatVertexAI): + return VertexAIEmbeddings() + elif isinstance(self.llm_model, AzureOpenAIEmbeddings): + return self.llm_model + elif isinstance(self.llm_model, AzureChatOpenAI): + return AzureOpenAIEmbeddings() + elif isinstance(self.llm_model, ChatFireworks): + return FireworksEmbeddings(model=self.llm_model.model_name) + elif isinstance(self.llm_model, ChatNVIDIA): + return NVIDIAEmbeddings(model=self.llm_model.model_name) + elif isinstance(self.llm_model, ChatOllama): + # unwrap the kwargs from the model whihc is a dict + params = self.llm_model._lc_kwargs + # remove streaming and temperature + params.pop("streaming", None) + params.pop("temperature", None) + + return OllamaEmbeddings(**params) + elif isinstance(self.llm_model, ChatHuggingFace): + return HuggingFaceEmbeddings(model=self.llm_model.model) + elif isinstance(self.llm_model, ChatBedrock): + return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id) + else: + raise ValueError("Embedding Model missing or not supported") + + + def _create_embedder(self, embedder_config: dict) -> object: + """ + Create an embedding model instance based on the configuration provided. + + Args: + embedder_config (dict): Configuration parameters for the embedding model. + + Returns: + object: An instance of the embedding model client. + + Raises: + KeyError: If the model is not supported. + """ + embedder_params = {**embedder_config} + if "model_instance" in embedder_config: + return embedder_params["model_instance"] + # Instantiate the embedding model based on the model name + if "openai" in embedder_params["model"]: + return OpenAIEmbeddings(api_key=embedder_params["api_key"]) + if "azure" in embedder_params["model"]: + return AzureOpenAIEmbeddings() + if "nvidia" in embedder_params["model"]: + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) + try: + models_tokens["nvidia"][embedder_params["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return NVIDIAEmbeddings(model=embedder_params["model"], + nvidia_api_key=embedder_params["api_key"]) + if "ollama" in embedder_params["model"]: + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) + try: + models_tokens["ollama"][embedder_params["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return OllamaEmbeddings(**embedder_params) + if "hugging_face" in embedder_params["model"]: + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) + try: + models_tokens["hugging_face"][embedder_params["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return HuggingFaceEmbeddings(model=embedder_params["model"]) + if "fireworks" in embedder_params["model"]: + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) + try: + models_tokens["fireworks"][embedder_params["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return FireworksEmbeddings(model=embedder_params["model"]) + if "gemini" in embedder_params["model"]: + try: + models_tokens["gemini"][embedder_params["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return GoogleGenerativeAIEmbeddings(model=embedder_params["model"]) + if "bedrock" in embedder_params["model"]: + embedder_params["model"] = embedder_params["model"].split("/")[-1] + client = embedder_params.get("client", None) + try: + models_tokens["bedrock"][embedder_params["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return BedrockEmbeddings(client=client, model_id=embedder_params["model"]) + + raise ValueError("Model provided by the configuration not supported")