mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
fix(AbstractGraph): pass kwargs to Ernie and Nvidia models
Some checks failed
/ build (push) Has been cancelled
Some checks failed
/ build (push) Has been cancelled
Co-Authored-By: Marco Vinciguerra <88108002+VinciGit00@users.noreply.github.com>
This commit is contained in:
parent
c3f1520240
commit
e6bedb6701
@ -6,8 +6,6 @@
|
|||||||
# features: []
|
# features: []
|
||||||
# all-features: false
|
# all-features: false
|
||||||
# with-sources: false
|
# with-sources: false
|
||||||
# generate-hashes: false
|
|
||||||
# universal: false
|
|
||||||
|
|
||||||
-e file:.
|
-e file:.
|
||||||
aiofiles==24.1.0
|
aiofiles==24.1.0
|
||||||
@ -112,7 +110,6 @@ filelock==3.15.4
|
|||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via torch
|
# via torch
|
||||||
# via transformers
|
# via transformers
|
||||||
# via triton
|
|
||||||
fireworks-ai==0.14.0
|
fireworks-ai==0.14.0
|
||||||
# via langchain-fireworks
|
# via langchain-fireworks
|
||||||
fonttools==4.53.1
|
fonttools==4.53.1
|
||||||
@ -362,34 +359,6 @@ numpy==1.26.4
|
|||||||
# via shapely
|
# via shapely
|
||||||
# via streamlit
|
# via streamlit
|
||||||
# via transformers
|
# via transformers
|
||||||
nvidia-cublas-cu12==12.1.3.1
|
|
||||||
# via nvidia-cudnn-cu12
|
|
||||||
# via nvidia-cusolver-cu12
|
|
||||||
# via torch
|
|
||||||
nvidia-cuda-cupti-cu12==12.1.105
|
|
||||||
# via torch
|
|
||||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
|
||||||
# via torch
|
|
||||||
nvidia-cuda-runtime-cu12==12.1.105
|
|
||||||
# via torch
|
|
||||||
nvidia-cudnn-cu12==8.9.2.26
|
|
||||||
# via torch
|
|
||||||
nvidia-cufft-cu12==11.0.2.54
|
|
||||||
# via torch
|
|
||||||
nvidia-curand-cu12==10.3.2.106
|
|
||||||
# via torch
|
|
||||||
nvidia-cusolver-cu12==11.4.5.107
|
|
||||||
# via torch
|
|
||||||
nvidia-cusparse-cu12==12.1.0.106
|
|
||||||
# via nvidia-cusolver-cu12
|
|
||||||
# via torch
|
|
||||||
nvidia-nccl-cu12==2.19.3
|
|
||||||
# via torch
|
|
||||||
nvidia-nvjitlink-cu12==12.6.20
|
|
||||||
# via nvidia-cusolver-cu12
|
|
||||||
# via nvidia-cusparse-cu12
|
|
||||||
nvidia-nvtx-cu12==12.1.105
|
|
||||||
# via torch
|
|
||||||
openai==1.37.0
|
openai==1.37.0
|
||||||
# via burr
|
# via burr
|
||||||
# via langchain-fireworks
|
# via langchain-fireworks
|
||||||
@ -631,8 +600,6 @@ tqdm==4.66.4
|
|||||||
transformers==4.43.3
|
transformers==4.43.3
|
||||||
# via langchain-huggingface
|
# via langchain-huggingface
|
||||||
# via sentence-transformers
|
# via sentence-transformers
|
||||||
triton==2.2.0
|
|
||||||
# via torch
|
|
||||||
typer==0.12.3
|
typer==0.12.3
|
||||||
# via fastapi-cli
|
# via fastapi-cli
|
||||||
typing-extensions==4.12.2
|
typing-extensions==4.12.2
|
||||||
@ -676,8 +643,6 @@ uvicorn==0.30.3
|
|||||||
# via fastapi
|
# via fastapi
|
||||||
uvloop==0.19.0
|
uvloop==0.19.0
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
watchdog==4.0.1
|
|
||||||
# via streamlit
|
|
||||||
watchfiles==0.22.0
|
watchfiles==0.22.0
|
||||||
# via uvicorn
|
# via uvicorn
|
||||||
websockets==12.0
|
websockets==12.0
|
||||||
|
|||||||
@ -6,8 +6,6 @@
|
|||||||
# features: []
|
# features: []
|
||||||
# all-features: false
|
# all-features: false
|
||||||
# with-sources: false
|
# with-sources: false
|
||||||
# generate-hashes: false
|
|
||||||
# universal: false
|
|
||||||
|
|
||||||
-e file:.
|
-e file:.
|
||||||
aiohttp==3.9.5
|
aiohttp==3.9.5
|
||||||
@ -69,7 +67,6 @@ filelock==3.15.4
|
|||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via torch
|
# via torch
|
||||||
# via transformers
|
# via transformers
|
||||||
# via triton
|
|
||||||
fireworks-ai==0.14.0
|
fireworks-ai==0.14.0
|
||||||
# via langchain-fireworks
|
# via langchain-fireworks
|
||||||
free-proxy==1.1.1
|
free-proxy==1.1.1
|
||||||
@ -267,34 +264,6 @@ numpy==1.26.4
|
|||||||
# via sentence-transformers
|
# via sentence-transformers
|
||||||
# via shapely
|
# via shapely
|
||||||
# via transformers
|
# via transformers
|
||||||
nvidia-cublas-cu12==12.1.3.1
|
|
||||||
# via nvidia-cudnn-cu12
|
|
||||||
# via nvidia-cusolver-cu12
|
|
||||||
# via torch
|
|
||||||
nvidia-cuda-cupti-cu12==12.1.105
|
|
||||||
# via torch
|
|
||||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
|
||||||
# via torch
|
|
||||||
nvidia-cuda-runtime-cu12==12.1.105
|
|
||||||
# via torch
|
|
||||||
nvidia-cudnn-cu12==8.9.2.26
|
|
||||||
# via torch
|
|
||||||
nvidia-cufft-cu12==11.0.2.54
|
|
||||||
# via torch
|
|
||||||
nvidia-curand-cu12==10.3.2.106
|
|
||||||
# via torch
|
|
||||||
nvidia-cusolver-cu12==11.4.5.107
|
|
||||||
# via torch
|
|
||||||
nvidia-cusparse-cu12==12.1.0.106
|
|
||||||
# via nvidia-cusolver-cu12
|
|
||||||
# via torch
|
|
||||||
nvidia-nccl-cu12==2.19.3
|
|
||||||
# via torch
|
|
||||||
nvidia-nvjitlink-cu12==12.6.20
|
|
||||||
# via nvidia-cusolver-cu12
|
|
||||||
# via nvidia-cusparse-cu12
|
|
||||||
nvidia-nvtx-cu12==12.1.105
|
|
||||||
# via torch
|
|
||||||
openai==1.37.0
|
openai==1.37.0
|
||||||
# via langchain-fireworks
|
# via langchain-fireworks
|
||||||
# via langchain-openai
|
# via langchain-openai
|
||||||
@ -446,8 +415,6 @@ tqdm==4.66.4
|
|||||||
transformers==4.43.3
|
transformers==4.43.3
|
||||||
# via langchain-huggingface
|
# via langchain-huggingface
|
||||||
# via sentence-transformers
|
# via sentence-transformers
|
||||||
triton==2.2.0
|
|
||||||
# via torch
|
|
||||||
typing-extensions==4.12.2
|
typing-extensions==4.12.2
|
||||||
# via anthropic
|
# via anthropic
|
||||||
# via anyio
|
# via anyio
|
||||||
|
|||||||
@ -211,7 +211,7 @@ class AbstractGraph(ABC):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
print("model not found, using default token size (8192)")
|
print("model not found, using default token size (8192)")
|
||||||
self.model_token = 8192
|
self.model_token = 8192
|
||||||
return ErnieBotChat(llm_params)
|
return ErnieBotChat(**llm_params)
|
||||||
|
|
||||||
if "oneapi" in llm_params["model"]:
|
if "oneapi" in llm_params["model"]:
|
||||||
# take the model after the last dash
|
# take the model after the last dash
|
||||||
@ -228,7 +228,7 @@ class AbstractGraph(ABC):
|
|||||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise KeyError("Model not supported") from exc
|
raise KeyError("Model not supported") from exc
|
||||||
return ChatNVIDIA(**llm_config)
|
return ChatNVIDIA(**llm_params)
|
||||||
|
|
||||||
# Raise an error if the model did not match any of the previous cases
|
# Raise an error if the model did not match any of the previous cases
|
||||||
raise ValueError("Model provided by the configuration not supported")
|
raise ValueError("Model provided by the configuration not supported")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user