refactor: remove redundant wrappers for Ernie and Nvidia

This commit is contained in:
Federico Aguzzi 2024-07-30 10:59:12 +02:00
parent 9275486240
commit bc2c9967d2
3 changed files with 6 additions and 48 deletions

View File

@ -19,14 +19,14 @@ from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
from langchain_community.chat_models import ErnieBotChat
from ..helpers import models_tokens
from ..models import (
OneApi,
Nvidia,
DeepSeek
)
from ..models.ernie import Ernie
from langchain.chat_models import init_chat_model
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@ -192,7 +192,7 @@ class AbstractGraph(ABC):
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Nvidia(llm_params)
return ChatNVIDIA(llm_params)
elif "gemini" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
@ -289,7 +289,7 @@ class AbstractGraph(ABC):
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return Ernie(llm_params)
return ErnieBotChat(llm_params)
else:
raise ValueError("Model provided by the configuration not supported")
@ -320,7 +320,7 @@ class AbstractGraph(ABC):
return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, ChatFireworks):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Nvidia):
elif isinstance(self.llm_model, ChatNVIDIA):
return NVIDIAEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, ChatOllama):
# unwrap the kwargs from the model whihc is a dict

View File

@ -1,17 +0,0 @@
"""
Ernie Module
"""
from langchain_community.chat_models import ErnieBotChat
class Ernie(ErnieBotChat):
"""
A wrapper for the ErnieBotChat class that provides default configuration
and could be extended with additional methods if needed.
Args:
llm_config (dict): Configuration parameters for the language model.
"""
def __init__(self, llm_config: dict):
super().__init__(**llm_config)

View File

@ -1,25 +0,0 @@
"""
This is a Python wrapper class for ChatNVIDIA.
It provides default configuration and could be extended with additional methods if needed.
The purpose of this wrapper is to simplify the creation of instances of ChatNVIDIA by providing
default configurations for certain parameters,
allowing users to focus on specifying other important parameters without having
to understand all the details of the underlying class's constructor.
It inherits from the base class ChatNVIDIA and overrides
its init method to provide a more user-friendly interface.
The constructor takes one argument: llm_config, which is used to initialize the superclass
with default configuration.
"""
from langchain_nvidia_ai_endpoints import ChatNVIDIA
class Nvidia(ChatNVIDIA):
""" A wrapper for the Nvidia class that provides default configuration
and could be extended with additional methods if needed.
Args:
llm_config (dict): Configuration parameters for the language model.
"""
def __init__(self, llm_config: dict):
super().__init__(**llm_config)