mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
refactor: remove redundant wrappers for Ernie and Nvidia
This commit is contained in:
parent
9275486240
commit
bc2c9967d2
@ -19,14 +19,14 @@ from langchain_google_genai import ChatGoogleGenerativeAI
|
|||||||
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
||||||
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
|
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
|
||||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
|
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
|
||||||
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
|
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
|
||||||
|
from langchain_community.chat_models import ErnieBotChat
|
||||||
from ..helpers import models_tokens
|
from ..helpers import models_tokens
|
||||||
from ..models import (
|
from ..models import (
|
||||||
OneApi,
|
OneApi,
|
||||||
Nvidia,
|
|
||||||
DeepSeek
|
DeepSeek
|
||||||
)
|
)
|
||||||
from ..models.ernie import Ernie
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
|
|
||||||
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
|
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
|
||||||
@ -192,7 +192,7 @@ class AbstractGraph(ABC):
|
|||||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise KeyError("Model not supported") from exc
|
raise KeyError("Model not supported") from exc
|
||||||
return Nvidia(llm_params)
|
return ChatNVIDIA(llm_params)
|
||||||
elif "gemini" in llm_params["model"]:
|
elif "gemini" in llm_params["model"]:
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||||
try:
|
try:
|
||||||
@ -289,7 +289,7 @@ class AbstractGraph(ABC):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
print("model not found, using default token size (8192)")
|
print("model not found, using default token size (8192)")
|
||||||
self.model_token = 8192
|
self.model_token = 8192
|
||||||
return Ernie(llm_params)
|
return ErnieBotChat(llm_params)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Model provided by the configuration not supported")
|
raise ValueError("Model provided by the configuration not supported")
|
||||||
|
|
||||||
@ -320,7 +320,7 @@ class AbstractGraph(ABC):
|
|||||||
return AzureOpenAIEmbeddings()
|
return AzureOpenAIEmbeddings()
|
||||||
elif isinstance(self.llm_model, ChatFireworks):
|
elif isinstance(self.llm_model, ChatFireworks):
|
||||||
return FireworksEmbeddings(model=self.llm_model.model_name)
|
return FireworksEmbeddings(model=self.llm_model.model_name)
|
||||||
elif isinstance(self.llm_model, Nvidia):
|
elif isinstance(self.llm_model, ChatNVIDIA):
|
||||||
return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
||||||
elif isinstance(self.llm_model, ChatOllama):
|
elif isinstance(self.llm_model, ChatOllama):
|
||||||
# unwrap the kwargs from the model whihc is a dict
|
# unwrap the kwargs from the model whihc is a dict
|
||||||
|
|||||||
@ -1,17 +0,0 @@
|
|||||||
"""
|
|
||||||
Ernie Module
|
|
||||||
"""
|
|
||||||
from langchain_community.chat_models import ErnieBotChat
|
|
||||||
|
|
||||||
|
|
||||||
class Ernie(ErnieBotChat):
|
|
||||||
"""
|
|
||||||
A wrapper for the ErnieBotChat class that provides default configuration
|
|
||||||
and could be extended with additional methods if needed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_config (dict): Configuration parameters for the language model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, llm_config: dict):
|
|
||||||
super().__init__(**llm_config)
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
"""
|
|
||||||
This is a Python wrapper class for ChatNVIDIA.
|
|
||||||
It provides default configuration and could be extended with additional methods if needed.
|
|
||||||
The purpose of this wrapper is to simplify the creation of instances of ChatNVIDIA by providing
|
|
||||||
default configurations for certain parameters,
|
|
||||||
allowing users to focus on specifying other important parameters without having
|
|
||||||
to understand all the details of the underlying class's constructor.
|
|
||||||
It inherits from the base class ChatNVIDIA and overrides
|
|
||||||
its init method to provide a more user-friendly interface.
|
|
||||||
The constructor takes one argument: llm_config, which is used to initialize the superclass
|
|
||||||
with default configuration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
|
||||||
|
|
||||||
class Nvidia(ChatNVIDIA):
|
|
||||||
""" A wrapper for the Nvidia class that provides default configuration
|
|
||||||
and could be extended with additional methods if needed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_config (dict): Configuration parameters for the language model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, llm_config: dict):
|
|
||||||
super().__init__(**llm_config)
|
|
||||||
Loading…
Reference in New Issue
Block a user