fix(AbstractGraph): pass kwargs to Ernie and Nvidia models
Some checks failed
/ build (push) Has been cancelled

Co-Authored-By: Marco Vinciguerra <88108002+VinciGit00@users.noreply.github.com>
This commit is contained in:
Federico Aguzzi 2024-08-12 10:11:45 +02:00
parent c3f1520240
commit e6bedb6701
3 changed files with 2 additions and 70 deletions

View File

@ -6,8 +6,6 @@
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
-e file:.
aiofiles==24.1.0
@ -112,7 +110,6 @@ filelock==3.15.4
# via huggingface-hub
# via torch
# via transformers
# via triton
fireworks-ai==0.14.0
# via langchain-fireworks
fonttools==4.53.1
@ -362,34 +359,6 @@ numpy==1.26.4
# via shapely
# via streamlit
# via transformers
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.19.3
# via torch
nvidia-nvjitlink-cu12==12.6.20
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
openai==1.37.0
# via burr
# via langchain-fireworks
@ -631,8 +600,6 @@ tqdm==4.66.4
transformers==4.43.3
# via langchain-huggingface
# via sentence-transformers
triton==2.2.0
# via torch
typer==0.12.3
# via fastapi-cli
typing-extensions==4.12.2
@ -676,8 +643,6 @@ uvicorn==0.30.3
# via fastapi
uvloop==0.19.0
# via uvicorn
watchdog==4.0.1
# via streamlit
watchfiles==0.22.0
# via uvicorn
websockets==12.0

View File

@ -6,8 +6,6 @@
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false
-e file:.
aiohttp==3.9.5
@ -69,7 +67,6 @@ filelock==3.15.4
# via huggingface-hub
# via torch
# via transformers
# via triton
fireworks-ai==0.14.0
# via langchain-fireworks
free-proxy==1.1.1
@ -267,34 +264,6 @@ numpy==1.26.4
# via sentence-transformers
# via shapely
# via transformers
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.19.3
# via torch
nvidia-nvjitlink-cu12==12.6.20
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
openai==1.37.0
# via langchain-fireworks
# via langchain-openai
@ -446,8 +415,6 @@ tqdm==4.66.4
transformers==4.43.3
# via langchain-huggingface
# via sentence-transformers
triton==2.2.0
# via torch
typing-extensions==4.12.2
# via anthropic
# via anyio

View File

@ -211,7 +211,7 @@ class AbstractGraph(ABC):
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return ErnieBotChat(llm_params)
return ErnieBotChat(**llm_params)
if "oneapi" in llm_params["model"]:
# take the model after the last dash
@ -228,7 +228,7 @@ class AbstractGraph(ABC):
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return ChatNVIDIA(**llm_config)
return ChatNVIDIA(**llm_params)
# Raise an error if the model did not match any of the previous cases
raise ValueError("Model provided by the configuration not supported")