mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
refactor: Changed the way embedding model is created in AbstractGraph class and removed handling of embedding model creation from RAGNode. Now AbstractGraph will call a dedicated method for embedding models instead of _create_llm. This makes it easy to use any LLM with any supported embedding model.
This commit is contained in:
parent
1219caa4ff
commit
819cbcd3be
@ -25,7 +25,7 @@ graph_config = {
|
|||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
"api_key": openai_key,
|
"api_key": openai_key,
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "openai",
|
||||||
},
|
},
|
||||||
"headless": False
|
"headless": False
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,8 +5,12 @@ AbstractGraph Module
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock
|
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
|
||||||
|
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
|
||||||
|
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||||
|
|
||||||
from ..helpers import models_tokens
|
from ..helpers import models_tokens
|
||||||
|
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
|
||||||
|
|
||||||
|
|
||||||
class AbstractGraph(ABC):
|
class AbstractGraph(ABC):
|
||||||
@ -43,7 +47,8 @@ class AbstractGraph(ABC):
|
|||||||
self.source = source
|
self.source = source
|
||||||
self.config = config
|
self.config = config
|
||||||
self.llm_model = self._create_llm(config["llm"], chat=True)
|
self.llm_model = self._create_llm(config["llm"], chat=True)
|
||||||
self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
|
self.embedder_model = self._create_default_embedder(
|
||||||
|
) if "embeddings" not in config else self._create_embedder(
|
||||||
config["embeddings"])
|
config["embeddings"])
|
||||||
|
|
||||||
# Set common configuration parameters
|
# Set common configuration parameters
|
||||||
@ -166,6 +171,85 @@ class AbstractGraph(ABC):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Model provided by the configuration not supported")
|
"Model provided by the configuration not supported")
|
||||||
|
|
||||||
|
def _create_default_embedder(self) -> object:
    """
    Create an embedding model instance that matches the chosen llm model.

    This is the fallback used when the graph configuration carries no
    "embeddings" section: the embedder is derived from the type of
    ``self.llm_model``.

    Returns:
        object: An instance of the embedding model client.

    Raises:
        ValueError: If ``self.llm_model`` is not one of the supported types.
    """
    llm = self.llm_model

    if isinstance(llm, OpenAI):
        return OpenAIEmbeddings(api_key=llm.openai_api_key)

    # Checked before AzureOpenAI, same order as the original ladder: an
    # object that already is an embeddings client is returned untouched.
    if isinstance(llm, AzureOpenAIEmbeddings):
        return llm

    if isinstance(llm, AzureOpenAI):
        return AzureOpenAIEmbeddings()

    if isinstance(llm, Ollama):
        # Work on a copy of the model's kwargs dict: popping from
        # `_lc_kwargs` directly would silently strip "streaming" and
        # "temperature" from the LLM instance itself.
        params = dict(llm._lc_kwargs)
        params.pop("streaming", None)
        params.pop("temperature", None)
        return OllamaEmbeddings(**params)

    if isinstance(llm, HuggingFace):
        return HuggingFaceHubEmbeddings(model=llm.model)

    if isinstance(llm, Bedrock):
        return BedrockEmbeddings(client=None, model_id=llm.model_id)

    raise ValueError("Embedding Model missing or not supported")
|
||||||
|
|
||||||
|
def _create_embedder(self, embedder_config: dict) -> object:
    """
    Create an embedding model instance based on the configuration provided.

    The provider is selected by substring match on ``embedder_config["model"]``,
    tested in this order: "openai", "azure", "ollama", "hugging_face",
    "bedrock". A name containing several of these markers therefore takes
    the first matching branch.

    The passed dictionary is never modified, so the same configuration can
    be reused to build further graphs.

    Args:
        embedder_config (dict): Configuration parameters for the embedding model.
            Must contain "model"; the OpenAI branch also reads "api_key".

    Returns:
        object: An instance of the embedding model client.

    Raises:
        KeyError: If a required key is missing from ``embedder_config``, or
            if the ollama / hugging_face / bedrock model name is not listed
            in ``models_tokens``.
        ValueError: If the model name matches none of the supported providers.
    """
    model = embedder_config["model"]

    if "openai" in model:
        return OpenAIEmbeddings(api_key=embedder_config["api_key"])

    if "azure" in model:
        return AzureOpenAIEmbeddings()

    if "ollama" in model:
        # Strip the "ollama/" style prefix locally instead of writing the
        # shortened name back into the caller's config.
        model_name = model.split("/")[-1]
        try:
            models_tokens["ollama"][model_name]
        except KeyError as exc:
            raise KeyError("Model not supported") from exc
        # All remaining config entries are forwarded to the client as-is.
        ollama_params = {**embedder_config, "model": model_name}
        return OllamaEmbeddings(**ollama_params)

    if "hugging_face" in model:
        # Unlike ollama/bedrock, the name is looked up and forwarded
        # without stripping any prefix.
        try:
            models_tokens["hugging_face"][model]
        except KeyError as exc:
            raise KeyError("Model not supported") from exc
        return HuggingFaceHubEmbeddings(model=model)

    if "bedrock" in model:
        model_name = model.split("/")[-1]
        try:
            models_tokens["bedrock"][model_name]
        except KeyError as exc:
            raise KeyError("Model not supported") from exc
        return BedrockEmbeddings(client=None, model_id=model_name)

    raise ValueError(
        "Model provided by the configuration not supported")
|
||||||
|
|
||||||
def get_state(self, key=None) -> dict:
|
def get_state(self, key=None) -> dict:
|
||||||
"""""
|
"""""
|
||||||
Get the final state of the graph.
|
Get the final state of the graph.
|
||||||
|
|||||||
@ -87,31 +87,7 @@ class RAGNode(BaseNode):
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("--- (updated chunks metadata) ---")
|
print("--- (updated chunks metadata) ---")
|
||||||
|
|
||||||
# check if embedder_model is provided, if not use llm_model
|
embeddings = self.embedder_model
|
||||||
embedding_model = self.embedder_model if self.embedder_model else self.llm_model
|
|
||||||
|
|
||||||
if isinstance(embedding_model, OpenAI):
|
|
||||||
embeddings = OpenAIEmbeddings(
|
|
||||||
api_key=embedding_model.openai_api_key)
|
|
||||||
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
|
|
||||||
embeddings = embedding_model
|
|
||||||
elif isinstance(embedding_model, AzureOpenAI):
|
|
||||||
embeddings = AzureOpenAIEmbeddings()
|
|
||||||
elif isinstance(embedding_model, Ollama):
|
|
||||||
# unwrap the kwargs from the model whihc is a dict
|
|
||||||
params = embedding_model._lc_kwargs
|
|
||||||
# remove streaming and temperature
|
|
||||||
params.pop("streaming", None)
|
|
||||||
params.pop("temperature", None)
|
|
||||||
|
|
||||||
embeddings = OllamaEmbeddings(**params)
|
|
||||||
elif isinstance(embedding_model, HuggingFace):
|
|
||||||
embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model)
|
|
||||||
elif isinstance(embedding_model, Bedrock):
|
|
||||||
embeddings = BedrockEmbeddings(
|
|
||||||
client=None, model_id=embedding_model.model_id)
|
|
||||||
else:
|
|
||||||
raise ValueError("Embedding Model missing or not supported")
|
|
||||||
|
|
||||||
retriever = FAISS.from_documents(
|
retriever = FAISS.from_documents(
|
||||||
chunked_docs, embeddings).as_retriever()
|
chunked_docs, embeddings).as_retriever()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user