From bc2c9967d2f13ade6eeb7b23e9b423f6e79aa890 Mon Sep 17 00:00:00 2001 From: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com> Date: Tue, 30 Jul 2024 10:59:12 +0200 Subject: [PATCH] refactor: remove redundant wrappers for Ernie and Nvidia --- scrapegraphai/graphs/abstract_graph.py | 12 ++++++------ scrapegraphai/models/ernie.py | 17 ----------------- scrapegraphai/models/nvidia.py | 25 ------------------------- 3 files changed, 6 insertions(+), 48 deletions(-) delete mode 100644 scrapegraphai/models/ernie.py delete mode 100644 scrapegraphai/models/nvidia.py diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index f27d1aee..50de0a94 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -19,14 +19,14 @@ from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings from langchain_fireworks import FireworksEmbeddings, ChatFireworks from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI -from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings +from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA +from langchain_community.chat_models import ErnieBotChat from ..helpers import models_tokens from ..models import ( OneApi, - Nvidia, DeepSeek ) -from ..models.ernie import Ernie + from langchain.chat_models import init_chat_model from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info @@ -192,7 +192,7 @@ class AbstractGraph(ABC): llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) except KeyError as exc: raise KeyError("Model not supported") from exc - return Nvidia(llm_params) + return ChatNVIDIA(llm_params) elif "gemini" in llm_params["model"]: llm_params["model"] = llm_params["model"].split("/")[-1] try: @@ -289,7 +289,7 @@ class AbstractGraph(ABC): except KeyError: print("model not found, using default token size (8192)") self.model_token = 8192 - return Ernie(llm_params) + return ErnieBotChat(llm_params) else: raise ValueError("Model provided by the configuration not supported") @@ -320,7 +320,7 @@ class AbstractGraph(ABC): return AzureOpenAIEmbeddings() elif isinstance(self.llm_model, ChatFireworks): return FireworksEmbeddings(model=self.llm_model.model_name) - elif isinstance(self.llm_model, Nvidia): + elif isinstance(self.llm_model, ChatNVIDIA): return NVIDIAEmbeddings(model=self.llm_model.model_name) elif isinstance(self.llm_model, ChatOllama): # unwrap the kwargs from the model whihc is a dict diff --git a/scrapegraphai/models/ernie.py b/scrapegraphai/models/ernie.py deleted file mode 100644 index 75e2a261..00000000 --- a/scrapegraphai/models/ernie.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Ernie Module -""" -from langchain_community.chat_models import ErnieBotChat - - -class Ernie(ErnieBotChat): - """ - A wrapper for the ErnieBotChat class that provides default configuration - and could be extended with additional methods if needed. - - Args: - llm_config (dict): Configuration parameters for the language model. - """ - - def __init__(self, llm_config: dict): - super().__init__(**llm_config) diff --git a/scrapegraphai/models/nvidia.py b/scrapegraphai/models/nvidia.py deleted file mode 100644 index 48ce3c0f..00000000 --- a/scrapegraphai/models/nvidia.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -This is a Python wrapper class for ChatNVIDIA. -It provides default configuration and could be extended with additional methods if needed. -The purpose of this wrapper is to simplify the creation of instances of ChatNVIDIA by providing -default configurations for certain parameters, -allowing users to focus on specifying other important parameters without having -to understand all the details of the underlying class's constructor. -It inherits from the base class ChatNVIDIA and overrides -its init method to provide a more user-friendly interface. -The constructor takes one argument: llm_config, which is used to initialize the superclass -with default configuration. -""" - -from langchain_nvidia_ai_endpoints import ChatNVIDIA - -class Nvidia(ChatNVIDIA): - """ A wrapper for the Nvidia class that provides default configuration - and could be extended with additional methods if needed. - - Args: - llm_config (dict): Configuration parameters for the language model. - """ - - def __init__(self, llm_config: dict): - super().__init__(**llm_config)