diff --git a/pyproject.toml b/pyproject.toml index c42bf33b..0fba1f33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "langchain-groq==0.1.3", "langchain-aws==0.1.3", "langchain-anthropic==0.1.11", + "langchain-nvidia-ai-endpoints==0.1.6", "html2text==2024.2.26", "faiss-cpu==1.8.0", "beautifulsoup4==4.12.3", diff --git a/requirements-dev.lock b/requirements-dev.lock index b0bcaaa0..0cd32e1d 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -14,6 +14,7 @@ aiohttp==3.9.5 # via langchain # via langchain-community # via langchain-fireworks + # via langchain-nvidia-ai-endpoints aiosignal==1.3.1 # via aiohttp alabaster==0.7.16 @@ -268,6 +269,7 @@ langchain-core==0.1.52 # via langchain-google-genai # via langchain-google-vertexai # via langchain-groq + # via langchain-nvidia-ai-endpoints # via langchain-openai # via langchain-text-splitters langchain-fireworks==0.1.3 @@ -278,6 +280,8 @@ langchain-google-vertexai==1.0.4 # via scrapegraphai langchain-groq==0.1.3 # via scrapegraphai +langchain-nvidia-ai-endpoints==0.1.6 + # via scrapegraphai langchain-openai==0.1.6 # via scrapegraphai langchain-text-splitters==0.0.2 @@ -348,6 +352,7 @@ pandas==2.2.2 # via streamlit pillow==10.3.0 # via fireworks-ai + # via langchain-nvidia-ai-endpoints # via matplotlib # via streamlit platformdirs==4.2.2 diff --git a/requirements.lock b/requirements.lock index 7a8bb455..f5624a09 100644 --- a/requirements.lock +++ b/requirements.lock @@ -12,6 +12,7 @@ aiohttp==3.9.5 # via langchain # via langchain-community # via langchain-fireworks + # via langchain-nvidia-ai-endpoints aiosignal==1.3.1 # via aiohttp annotated-types==0.7.0 @@ -187,6 +188,7 @@ langchain-core==0.1.52 # via langchain-google-genai # via langchain-google-vertexai # via langchain-groq + # via langchain-nvidia-ai-endpoints # via langchain-openai # via langchain-text-splitters langchain-fireworks==0.1.3 @@ -197,6 +199,8 @@ langchain-google-vertexai==1.0.4 # via scrapegraphai langchain-groq==0.1.3 # via scrapegraphai +langchain-nvidia-ai-endpoints==0.1.6 + # via scrapegraphai langchain-openai==0.1.6 # via scrapegraphai langchain-text-splitters==0.0.2 @@ -238,6 +242,7 @@ pandas==2.2.2 # via scrapegraphai pillow==10.3.0 # via fireworks-ai + # via langchain-nvidia-ai-endpoints playwright==1.43.0 # via scrapegraphai # via undetected-playwright diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 7f8ec4ea..3323a096 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -14,6 +14,7 @@ from langchain_google_vertexai import VertexAIEmbeddings from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings from langchain_fireworks import FireworksEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings +from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings from ..helpers import models_tokens from ..models import ( Anthropic, @@ -26,7 +27,8 @@ from ..models import ( OpenAI, OneApi, Fireworks, - VertexAI + VertexAI, + Nvidia ) from ..models.ernie import Ernie from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info @@ -180,6 +182,13 @@ class AbstractGraph(ABC): except KeyError as exc: raise KeyError("Model not supported") from exc return AzureOpenAI(llm_params) + elif "nvidia" in llm_params["model"]: + try: + self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]] + llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) + except KeyError as exc: + raise KeyError("Model not supported") from exc + return Nvidia(llm_params) elif "gemini" in llm_params["model"]: llm_params["model"] = llm_params["model"].split("/")[-1] try: @@ -305,6 +314,8 @@ class AbstractGraph(ABC): return AzureOpenAIEmbeddings() elif isinstance(self.llm_model, Fireworks): return FireworksEmbeddings(model=self.llm_model.model_name) + elif isinstance(self.llm_model, Nvidia): + return NVIDIAEmbeddings(model=self.llm_model.model_name) elif isinstance(self.llm_model, Ollama): # unwrap the kwargs from the model whihc is a dict params = self.llm_model._lc_kwargs @@ -341,6 +352,14 @@ class AbstractGraph(ABC): return OpenAIEmbeddings(api_key=embedder_params["api_key"]) elif "azure" in embedder_params["model"]: return AzureOpenAIEmbeddings() + if "nvidia" in embedder_params["model"]: + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) + try: + models_tokens["nvidia"][embedder_params["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return NVIDIAEmbeddings(model=embedder_params["model"], + nvidia_api_key=embedder_params["api_key"]) elif "ollama" in embedder_params["model"]: embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) try: diff --git a/scrapegraphai/helpers/models_tokens.py b/scrapegraphai/helpers/models_tokens.py index df990bf4..b3d61065 100644 --- a/scrapegraphai/helpers/models_tokens.py +++ b/scrapegraphai/helpers/models_tokens.py @@ -79,6 +79,24 @@ models_tokens = { "oneapi": { "qwen-turbo": 6000 }, + "nvidia": { + "meta/llama3-70b-instruct": 419, + "meta/llama3-8b-instruct": 419, + "nemotron-4-340b-instruct": 1024, + "databricks/dbrx-instruct": 4096, + "google/codegemma-7b": 8192, + "google/gemma-2b": 2048, + "google/gemma-7b": 8192, + "google/recurrentgemma-2b": 2048, + "meta/codellama-70b": 16384, + "meta/llama2-70b": 4096, + "microsoft/phi-3-mini-128k-instruct": 122880, + "mistralai/mistral-7b-instruct-v0.2": 4096, + "mistralai/mistral-large": 8192, + "mistralai/mixtral-8x22b-instruct-v0.1": 32768, + "mistralai/mixtral-8x7b-instruct-v0.1": 8192, + "snowflake/arctic": 16384, + }, "groq": { "llama3-8b-8192": 8192, "llama3-70b-8192": 8192, diff --git a/scrapegraphai/models/__init__.py b/scrapegraphai/models/__init__.py index a408d9ac..bfcb84d6 100644 --- a/scrapegraphai/models/__init__.py +++ b/scrapegraphai/models/__init__.py @@ -16,3 +16,4 @@ from .deepseek import DeepSeek from .oneapi import OneApi from .fireworks import Fireworks from .vertex import VertexAI +from .nvidia import Nvidia diff --git a/scrapegraphai/models/nvidia.py b/scrapegraphai/models/nvidia.py new file mode 100644 index 00000000..48ce3c0f --- /dev/null +++ b/scrapegraphai/models/nvidia.py @@ -0,0 +1,25 @@ +""" +This is a Python wrapper class for ChatNVIDIA. +It provides default configuration and could be extended with additional methods if needed. +The purpose of this wrapper is to simplify the creation of instances of ChatNVIDIA by providing +default configurations for certain parameters, +allowing users to focus on specifying other important parameters without having +to understand all the details of the underlying class's constructor. +It inherits from the base class ChatNVIDIA and overrides +its init method to provide a more user-friendly interface. +The constructor takes one argument: llm_config, which is used to initialize the superclass +with default configuration. +""" + +from langchain_nvidia_ai_endpoints import ChatNVIDIA + +class Nvidia(ChatNVIDIA): + """ A wrapper for the Nvidia class that provides default configuration + and could be extended with additional methods if needed. + + Args: + llm_config (dict): Configuration parameters for the language model. + """ + + def __init__(self, llm_config: dict): + super().__init__(**llm_config)