diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index aa4a0cbe..83b5b712 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -188,7 +188,6 @@ class AbstractGraph(ABC): Raises: ValueError: If the model is not supported. """ - if isinstance(self.llm_model, OpenAI): return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) elif isinstance(self.llm_model, AzureOpenAIEmbeddings): @@ -223,6 +222,9 @@ class AbstractGraph(ABC): Raises: KeyError: If the model is not supported. """ + + if 'model_instance' in embedder_config: + return embedder_config['model_instance'] # Instantiate the embedding model based on the model name if "openai" in embedder_config["model"]: diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index 86de7d7b..b883845a 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -8,9 +8,6 @@ from langchain.retrievers import ContextualCompressionRetriever from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline from langchain_community.document_transformers import EmbeddingsRedundantFilter from langchain_community.vectorstores import FAISS -from langchain_community.embeddings import OllamaEmbeddings -from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings -from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings from .base_node import BaseNode @@ -86,33 +83,7 @@ class RAGNode(BaseNode): print("--- (updated chunks metadata) ---") # check if embedder_model is provided, if not use llm_model - embedding_model = self.embedder_model if self.embedder_model else self.llm_model - - if isinstance(embedding_model, OpenAI): - embeddings = OpenAIEmbeddings( - api_key=embedding_model.openai_api_key) - elif isinstance(embedding_model, AzureOpenAIEmbeddings): - embeddings = embedding_model - elif isinstance(embedding_model, HuggingFaceInferenceAPIEmbeddings): - embeddings = embedding_model - - elif isinstance(embedding_model, AzureOpenAI): - embeddings = AzureOpenAIEmbeddings() - elif isinstance(embedding_model, Ollama): - # unwrap the kwargs from the model whihc is a dict - params = embedding_model._lc_kwargs - # remove streaming and temperature - params.pop("streaming", None) - params.pop("temperature", None) - - embeddings = OllamaEmbeddings(**params) - elif isinstance(embedding_model, HuggingFace): - embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model) - elif isinstance(embedding_model, Bedrock): - embeddings = BedrockEmbeddings( - client=None, model_id=embedding_model.model_id) - else: - raise ValueError("Embedding Model missing or not supported") + self.embedder_model = self.embedder_model if self.embedder_model else self.llm_model embeddings = self.embedder_model retriever = FAISS.from_documents(