mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
fix(AbstractGraph): pass kwargs to Ernie and Nvidia models
Some checks failed
/ build (push) Has been cancelled
Some checks failed
/ build (push) Has been cancelled
Co-Authored-By: Marco Vinciguerra <88108002+VinciGit00@users.noreply.github.com>
This commit is contained in:
parent
c3f1520240
commit
e6bedb6701
@ -6,8 +6,6 @@
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
# generate-hashes: false
|
||||
# universal: false
|
||||
|
||||
-e file:.
|
||||
aiofiles==24.1.0
|
||||
@ -112,7 +110,6 @@ filelock==3.15.4
|
||||
# via huggingface-hub
|
||||
# via torch
|
||||
# via transformers
|
||||
# via triton
|
||||
fireworks-ai==0.14.0
|
||||
# via langchain-fireworks
|
||||
fonttools==4.53.1
|
||||
@ -362,34 +359,6 @@ numpy==1.26.4
|
||||
# via shapely
|
||||
# via streamlit
|
||||
# via transformers
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
# via nvidia-cudnn-cu12
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.19.3
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.6.20
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
openai==1.37.0
|
||||
# via burr
|
||||
# via langchain-fireworks
|
||||
@ -631,8 +600,6 @@ tqdm==4.66.4
|
||||
transformers==4.43.3
|
||||
# via langchain-huggingface
|
||||
# via sentence-transformers
|
||||
triton==2.2.0
|
||||
# via torch
|
||||
typer==0.12.3
|
||||
# via fastapi-cli
|
||||
typing-extensions==4.12.2
|
||||
@ -676,8 +643,6 @@ uvicorn==0.30.3
|
||||
# via fastapi
|
||||
uvloop==0.19.0
|
||||
# via uvicorn
|
||||
watchdog==4.0.1
|
||||
# via streamlit
|
||||
watchfiles==0.22.0
|
||||
# via uvicorn
|
||||
websockets==12.0
|
||||
|
||||
@ -6,8 +6,6 @@
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
# generate-hashes: false
|
||||
# universal: false
|
||||
|
||||
-e file:.
|
||||
aiohttp==3.9.5
|
||||
@ -69,7 +67,6 @@ filelock==3.15.4
|
||||
# via huggingface-hub
|
||||
# via torch
|
||||
# via transformers
|
||||
# via triton
|
||||
fireworks-ai==0.14.0
|
||||
# via langchain-fireworks
|
||||
free-proxy==1.1.1
|
||||
@ -267,34 +264,6 @@ numpy==1.26.4
|
||||
# via sentence-transformers
|
||||
# via shapely
|
||||
# via transformers
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
# via nvidia-cudnn-cu12
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.19.3
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.6.20
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
openai==1.37.0
|
||||
# via langchain-fireworks
|
||||
# via langchain-openai
|
||||
@ -446,8 +415,6 @@ tqdm==4.66.4
|
||||
transformers==4.43.3
|
||||
# via langchain-huggingface
|
||||
# via sentence-transformers
|
||||
triton==2.2.0
|
||||
# via torch
|
||||
typing-extensions==4.12.2
|
||||
# via anthropic
|
||||
# via anyio
|
||||
|
||||
@ -211,7 +211,7 @@ class AbstractGraph(ABC):
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
return ErnieBotChat(llm_params)
|
||||
return ErnieBotChat(**llm_params)
|
||||
|
||||
if "oneapi" in llm_params["model"]:
|
||||
# take the model after the last dash
|
||||
@ -228,7 +228,7 @@ class AbstractGraph(ABC):
|
||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return ChatNVIDIA(**llm_config)
|
||||
return ChatNVIDIA(**llm_params)
|
||||
|
||||
# Raise an error if the model did not match any of the previous cases
|
||||
raise ValueError("Model provided by the configuration not supported")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user