mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
feat: add nvidia connection
This commit is contained in:
parent
f078fe8a3f
commit
fc0dadb8f8
@ -20,6 +20,7 @@ dependencies = [
|
||||
"langchain-groq==0.1.3",
|
||||
"langchain-aws==0.1.3",
|
||||
"langchain-anthropic==0.1.11",
|
||||
"langchain-nvidia-ai-endpoints==0.1.6",
|
||||
"html2text==2024.2.26",
|
||||
"faiss-cpu==1.8.0",
|
||||
"beautifulsoup4==4.12.3",
|
||||
|
||||
@ -14,6 +14,7 @@ aiohttp==3.9.5
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
alabaster==0.7.16
|
||||
@ -268,6 +269,7 @@ langchain-core==0.1.52
|
||||
# via langchain-google-genai
|
||||
# via langchain-google-vertexai
|
||||
# via langchain-groq
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
# via langchain-openai
|
||||
# via langchain-text-splitters
|
||||
langchain-fireworks==0.1.3
|
||||
@ -278,6 +280,8 @@ langchain-google-vertexai==1.0.4
|
||||
# via scrapegraphai
|
||||
langchain-groq==0.1.3
|
||||
# via scrapegraphai
|
||||
langchain-nvidia-ai-endpoints==0.1.6
|
||||
# via scrapegraphai
|
||||
langchain-openai==0.1.6
|
||||
# via scrapegraphai
|
||||
langchain-text-splitters==0.0.2
|
||||
@ -348,6 +352,7 @@ pandas==2.2.2
|
||||
# via streamlit
|
||||
pillow==10.3.0
|
||||
# via fireworks-ai
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
# via matplotlib
|
||||
# via streamlit
|
||||
platformdirs==4.2.2
|
||||
|
||||
@ -12,6 +12,7 @@ aiohttp==3.9.5
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
@ -187,6 +188,7 @@ langchain-core==0.1.52
|
||||
# via langchain-google-genai
|
||||
# via langchain-google-vertexai
|
||||
# via langchain-groq
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
# via langchain-openai
|
||||
# via langchain-text-splitters
|
||||
langchain-fireworks==0.1.3
|
||||
@ -197,6 +199,8 @@ langchain-google-vertexai==1.0.4
|
||||
# via scrapegraphai
|
||||
langchain-groq==0.1.3
|
||||
# via scrapegraphai
|
||||
langchain-nvidia-ai-endpoints==0.1.6
|
||||
# via scrapegraphai
|
||||
langchain-openai==0.1.6
|
||||
# via scrapegraphai
|
||||
langchain-text-splitters==0.0.2
|
||||
@ -238,6 +242,7 @@ pandas==2.2.2
|
||||
# via scrapegraphai
|
||||
pillow==10.3.0
|
||||
# via fireworks-ai
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
playwright==1.43.0
|
||||
# via scrapegraphai
|
||||
# via undetected-playwright
|
||||
|
||||
@ -14,6 +14,7 @@ from langchain_google_vertexai import VertexAIEmbeddings
|
||||
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
||||
from langchain_fireworks import FireworksEmbeddings
|
||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
|
||||
from ..helpers import models_tokens
|
||||
from ..models import (
|
||||
Anthropic,
|
||||
@ -26,7 +27,8 @@ from ..models import (
|
||||
OpenAI,
|
||||
OneApi,
|
||||
Fireworks,
|
||||
VertexAI
|
||||
VertexAI,
|
||||
Nvidia
|
||||
)
|
||||
from ..models.ernie import Ernie
|
||||
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
|
||||
@ -180,6 +182,13 @@ class AbstractGraph(ABC):
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return AzureOpenAI(llm_params)
|
||||
elif "nvidia" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
|
||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return Nvidia(llm_params)
|
||||
elif "gemini" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
@ -305,6 +314,8 @@ class AbstractGraph(ABC):
|
||||
return AzureOpenAIEmbeddings()
|
||||
elif isinstance(self.llm_model, Fireworks):
|
||||
return FireworksEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, Nvidia):
|
||||
return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, Ollama):
|
||||
# unwrap the kwargs from the model whihc is a dict
|
||||
params = self.llm_model._lc_kwargs
|
||||
@ -341,6 +352,14 @@ class AbstractGraph(ABC):
|
||||
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
|
||||
elif "azure" in embedder_params["model"]:
|
||||
return AzureOpenAIEmbeddings()
|
||||
if "nvidia" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["nvidia"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return NVIDIAEmbeddings(model=embedder_params["model"],
|
||||
nvidia_api_key=embedder_params["api_key"])
|
||||
elif "ollama" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
|
||||
@ -79,6 +79,24 @@ models_tokens = {
|
||||
"oneapi": {
|
||||
"qwen-turbo": 6000
|
||||
},
|
||||
"nvidia": {
|
||||
"meta/llama3-70b-instruct": 419,
|
||||
"meta/llama3-8b-instruct": 419,
|
||||
"nemotron-4-340b-instruct": 1024,
|
||||
"databricks/dbrx-instruct": 4096,
|
||||
"google/codegemma-7b": 8192,
|
||||
"google/gemma-2b": 2048,
|
||||
"google/gemma-7b": 8192,
|
||||
"google/recurrentgemma-2b": 2048,
|
||||
"meta/codellama-70b": 16384,
|
||||
"meta/llama2-70b": 4096,
|
||||
"microsoft/phi-3-mini-128k-instruct": 122880,
|
||||
"mistralai/mistral-7b-instruct-v0.2": 4096,
|
||||
"mistralai/mistral-large": 8192,
|
||||
"mistralai/mixtral-8x22b-instruct-v0.1": 32768,
|
||||
"mistralai/mixtral-8x7b-instruct-v0.1": 8192,
|
||||
"snowflake/arctic": 16384,
|
||||
},
|
||||
"groq": {
|
||||
"llama3-8b-8192": 8192,
|
||||
"llama3-70b-8192": 8192,
|
||||
|
||||
@ -16,3 +16,4 @@ from .deepseek import DeepSeek
|
||||
from .oneapi import OneApi
|
||||
from .fireworks import Fireworks
|
||||
from .vertex import VertexAI
|
||||
from .nvidia import Nvidia
|
||||
|
||||
25
scrapegraphai/models/nvidia.py
Normal file
25
scrapegraphai/models/nvidia.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""
|
||||
This is a Python wrapper class for ChatNVIDIA.
|
||||
It provides default configuration and could be extended with additional methods if needed.
|
||||
The purpose of this wrapper is to simplify the creation of instances of ChatNVIDIA by providing
|
||||
default configurations for certain parameters,
|
||||
allowing users to focus on specifying other important parameters without having
|
||||
to understand all the details of the underlying class's constructor.
|
||||
It inherits from the base class ChatNVIDIA and overrides
|
||||
its init method to provide a more user-friendly interface.
|
||||
The constructor takes one argument: llm_config, which is used to initialize the superclass
|
||||
with default configuration.
|
||||
"""
|
||||
|
||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||
|
||||
class Nvidia(ChatNVIDIA):
|
||||
""" A wrapper for the Nvidia class that provides default configuration
|
||||
and could be extended with additional methods if needed.
|
||||
|
||||
Args:
|
||||
llm_config (dict): Configuration parameters for the language model.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_config: dict):
|
||||
super().__init__(**llm_config)
|
||||
Loading…
Reference in New Issue
Block a user