feat: add nvidia connection

This commit is contained in:
Marco Vinciguerra 2024-07-22 11:56:33 +02:00
parent f078fe8a3f
commit fc0dadb8f8
7 changed files with 75 additions and 1 deletions

View File

@ -20,6 +20,7 @@ dependencies = [
"langchain-groq==0.1.3",
"langchain-aws==0.1.3",
"langchain-anthropic==0.1.11",
"langchain-nvidia-ai-endpoints==0.1.6",
"html2text==2024.2.26",
"faiss-cpu==1.8.0",
"beautifulsoup4==4.12.3",

View File

@ -14,6 +14,7 @@ aiohttp==3.9.5
# via langchain
# via langchain-community
# via langchain-fireworks
# via langchain-nvidia-ai-endpoints
aiosignal==1.3.1
# via aiohttp
alabaster==0.7.16
@ -268,6 +269,7 @@ langchain-core==0.1.52
# via langchain-google-genai
# via langchain-google-vertexai
# via langchain-groq
# via langchain-nvidia-ai-endpoints
# via langchain-openai
# via langchain-text-splitters
langchain-fireworks==0.1.3
@ -278,6 +280,8 @@ langchain-google-vertexai==1.0.4
# via scrapegraphai
langchain-groq==0.1.3
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.1.6
# via scrapegraphai
langchain-openai==0.1.6
# via scrapegraphai
langchain-text-splitters==0.0.2
@ -348,6 +352,7 @@ pandas==2.2.2
# via streamlit
pillow==10.3.0
# via fireworks-ai
# via langchain-nvidia-ai-endpoints
# via matplotlib
# via streamlit
platformdirs==4.2.2

View File

@ -12,6 +12,7 @@ aiohttp==3.9.5
# via langchain
# via langchain-community
# via langchain-fireworks
# via langchain-nvidia-ai-endpoints
aiosignal==1.3.1
# via aiohttp
annotated-types==0.7.0
@ -187,6 +188,7 @@ langchain-core==0.1.52
# via langchain-google-genai
# via langchain-google-vertexai
# via langchain-groq
# via langchain-nvidia-ai-endpoints
# via langchain-openai
# via langchain-text-splitters
langchain-fireworks==0.1.3
@ -197,6 +199,8 @@ langchain-google-vertexai==1.0.4
# via scrapegraphai
langchain-groq==0.1.3
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.1.6
# via scrapegraphai
langchain-openai==0.1.6
# via scrapegraphai
langchain-text-splitters==0.0.2
@ -238,6 +242,7 @@ pandas==2.2.2
# via scrapegraphai
pillow==10.3.0
# via fireworks-ai
# via langchain-nvidia-ai-endpoints
playwright==1.43.0
# via scrapegraphai
# via undetected-playwright

View File

@ -14,6 +14,7 @@ from langchain_google_vertexai import VertexAIEmbeddings
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_fireworks import FireworksEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from ..helpers import models_tokens
from ..models import (
Anthropic,
@ -26,7 +27,8 @@ from ..models import (
OpenAI,
OneApi,
Fireworks,
VertexAI
VertexAI,
Nvidia
)
from ..models.ernie import Ernie
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@ -180,6 +182,13 @@ class AbstractGraph(ABC):
except KeyError as exc:
raise KeyError("Model not supported") from exc
return AzureOpenAI(llm_params)
elif "nvidia" in llm_params["model"]:
try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Nvidia(llm_params)
elif "gemini" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
@ -305,6 +314,8 @@ class AbstractGraph(ABC):
return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, Fireworks):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Nvidia):
return NVIDIAEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Ollama):
# unwrap the kwargs from the model whihc is a dict
params = self.llm_model._lc_kwargs
@ -341,6 +352,14 @@ class AbstractGraph(ABC):
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
elif "azure" in embedder_params["model"]:
return AzureOpenAIEmbeddings()
if "nvidia" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["nvidia"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return NVIDIAEmbeddings(model=embedder_params["model"],
nvidia_api_key=embedder_params["api_key"])
elif "ollama" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:

View File

@ -79,6 +79,24 @@ models_tokens = {
"oneapi": {
"qwen-turbo": 6000
},
"nvidia": {
"meta/llama3-70b-instruct": 419,
"meta/llama3-8b-instruct": 419,
"nemotron-4-340b-instruct": 1024,
"databricks/dbrx-instruct": 4096,
"google/codegemma-7b": 8192,
"google/gemma-2b": 2048,
"google/gemma-7b": 8192,
"google/recurrentgemma-2b": 2048,
"meta/codellama-70b": 16384,
"meta/llama2-70b": 4096,
"microsoft/phi-3-mini-128k-instruct": 122880,
"mistralai/mistral-7b-instruct-v0.2": 4096,
"mistralai/mistral-large": 8192,
"mistralai/mixtral-8x22b-instruct-v0.1": 32768,
"mistralai/mixtral-8x7b-instruct-v0.1": 8192,
"snowflake/arctic": 16384,
},
"groq": {
"llama3-8b-8192": 8192,
"llama3-70b-8192": 8192,

View File

@ -16,3 +16,4 @@ from .deepseek import DeepSeek
from .oneapi import OneApi
from .fireworks import Fireworks
from .vertex import VertexAI
from .nvidia import Nvidia

View File

@ -0,0 +1,25 @@
"""
This is a Python wrapper class for ChatNVIDIA.
It provides default configuration and could be extended with additional methods if needed.
The purpose of this wrapper is to simplify the creation of instances of ChatNVIDIA by providing
default configurations for certain parameters,
allowing users to focus on specifying other important parameters without having
to understand all the details of the underlying class's constructor.
It inherits from the base class ChatNVIDIA and overrides
its init method to provide a more user-friendly interface.
The constructor takes one argument: llm_config, which is used to initialize the superclass
with default configuration.
"""
from langchain_nvidia_ai_endpoints import ChatNVIDIA
class Nvidia(ChatNVIDIA):
""" A wrapper for the Nvidia class that provides default configuration
and could be extended with additional methods if needed.
Args:
llm_config (dict): Configuration parameters for the language model.
"""
def __init__(self, llm_config: dict):
super().__init__(**llm_config)