mirror of
https://github.com/VikParuchuri/surya.git
synced 2026-06-12 21:02:45 +08:00
Add in kv cache quantization
This commit is contained in:
parent
f3c4adb29b
commit
00387bdbdb
52
poetry.lock
generated
52
poetry.lock
generated
@ -1232,14 +1232,14 @@ zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.29.1"
|
||||
version = "0.30.2"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5"},
|
||||
{file = "huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250"},
|
||||
{file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"},
|
||||
{file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -1257,6 +1257,7 @@ cli = ["InquirerPy (==0.3.4)"]
|
||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
|
||||
hf-transfer = ["hf-transfer (>=0.1.4)"]
|
||||
hf-xet = ["hf-xet (>=0.1.4)"]
|
||||
inference = ["aiohttp"]
|
||||
quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"]
|
||||
tensorflow = ["graphviz", "pydot", "tensorflow"]
|
||||
@ -4679,71 +4680,74 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "4.49.0"
|
||||
version = "4.51.2"
|
||||
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
|
||||
optional = false
|
||||
python-versions = ">=3.9.0"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "transformers-4.49.0-py3-none-any.whl", hash = "sha256:6b4fded1c5fee04d384b1014495b4235a2b53c87503d7d592423c06128cbbe03"},
|
||||
{file = "transformers-4.49.0.tar.gz", hash = "sha256:7e40e640b5b8dc3f48743f5f5adbdce3660c82baafbd3afdfc04143cdbd2089e"},
|
||||
{file = "transformers-4.51.2-py3-none-any.whl", hash = "sha256:5cb8259098b75ff4b5dd04533a318f7c4750d5307d9617e6d0593526432c404d"},
|
||||
{file = "transformers-4.51.2.tar.gz", hash = "sha256:ed221c31581e97127cff5de775b05f05d19698b439d7d638ff445502a7f37331"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
huggingface-hub = ">=0.26.0,<1.0"
|
||||
huggingface-hub = ">=0.30.0,<1.0"
|
||||
numpy = ">=1.17"
|
||||
packaging = ">=20.0"
|
||||
pyyaml = ">=5.1"
|
||||
regex = "!=2019.12.17"
|
||||
requests = "*"
|
||||
safetensors = ">=0.4.1"
|
||||
safetensors = ">=0.4.3"
|
||||
tokenizers = ">=0.21,<0.22"
|
||||
tqdm = ">=4.27"
|
||||
|
||||
[package.extras]
|
||||
accelerate = ["accelerate (>=0.26.0)"]
|
||||
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"]
|
||||
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"]
|
||||
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "librosa", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"]
|
||||
audio = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
benchmark = ["optimum-benchmark (>=0.3.0)"]
|
||||
codecarbon = ["codecarbon (>=2.8.1)"]
|
||||
deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
|
||||
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
|
||||
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
|
||||
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
|
||||
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
flax-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
ftfy = ["ftfy"]
|
||||
integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
|
||||
hf-xet = ["hf-xet"]
|
||||
hub-kernels = ["kernels (>=0.3.2,<0.4)"]
|
||||
integrations = ["kernels (>=0.3.2,<0.4)", "optuna", "ray[tune] (>=2.7.0)", "sigopt"]
|
||||
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
|
||||
modelcreation = ["cookiecutter (==1.7.3)"]
|
||||
natten = ["natten (>=0.14.6,<0.15.0)"]
|
||||
num2words = ["num2words"]
|
||||
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
|
||||
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
||||
optuna = ["optuna"]
|
||||
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
|
||||
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.11.2)", "urllib3 (<2.0.0)"]
|
||||
ray = ["ray[tune] (>=2.7.0)"]
|
||||
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
|
||||
ruff = ["ruff (==0.5.1)"]
|
||||
ruff = ["ruff (==0.11.2)"]
|
||||
sagemaker = ["sagemaker (>=2.31.0)"]
|
||||
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
|
||||
serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
|
||||
sigopt = ["sigopt"]
|
||||
sklearn = ["scikit-learn"]
|
||||
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
tf-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
tiktoken = ["blobfile", "tiktoken"]
|
||||
timm = ["timm (<=1.0.11)"]
|
||||
tokenizers = ["tokenizers (>=0.21,<0.22)"]
|
||||
torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"]
|
||||
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
torch-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
|
||||
torchhub = ["filelock", "huggingface-hub (>=0.26.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"]
|
||||
torchhub = ["filelock", "huggingface-hub (>=0.30.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"]
|
||||
video = ["av"]
|
||||
vision = ["Pillow (>=10.0.1,<=15.0)"]
|
||||
|
||||
@ -5201,4 +5205,4 @@ propcache = ">=0.2.0"
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "fa33cb2c8bfc8f62754a8067a8532da7a589e3792ef9106ae8af8f85604d8698"
|
||||
content-hash = "14a2e68d79050c13419742306894ee97191eef7a531b76c78bb8e85caeca7cad"
|
||||
|
||||
@ -13,7 +13,7 @@ packages = [
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
transformers = "^4.41.0"
|
||||
transformers = "^4.51.2"
|
||||
torch = "^2.5.1"
|
||||
pydantic = "^2.5.3"
|
||||
pydantic-settings = "^2.1.0"
|
||||
|
||||
@ -173,12 +173,14 @@ class Qwen2Attention(nn.Module):
|
||||
# IMPORTANT: Do not use causal mask for prefill; Matches training
|
||||
# This is required for flash attn, which doesn't support a 4D mask as input
|
||||
# The `is_causal` argument is ignored by SDPA since we pass a 4D attention mask
|
||||
if self.config._attn_implementation == 'flash_attention_2':
|
||||
is_prefill = all((
|
||||
input_shape[1] > 1,
|
||||
(past_key_value is None) or (past_key_value.get_seq_length(self.layer_idx) == 0)
|
||||
))
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
is_prefill = all(
|
||||
(
|
||||
input_shape[1] > 1,
|
||||
(past_key_value is None)
|
||||
or (past_key_value.get_seq_length(self.layer_idx) == 0),
|
||||
)
|
||||
)
|
||||
if is_prefill:
|
||||
self.is_causal = False
|
||||
else:
|
||||
|
||||
@ -165,7 +165,7 @@ class SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin):
|
||||
|
||||
num_tiles = image_tiles.shape[0]
|
||||
input_ids = [self.image_token_id] * num_tiles * self.image_tokens_per_tile
|
||||
input_ids += [self.register_token_ids][: self.num_register_tokens]
|
||||
input_ids += self.register_token_ids[: self.num_register_tokens]
|
||||
|
||||
# Handle the image being rotated in the imdataset
|
||||
if rotated:
|
||||
|
||||
@ -8,6 +8,7 @@ import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch.nn.functional as F
|
||||
from transformers import QuantizedCacheConfig
|
||||
|
||||
from surya.common.polygon import PolygonBox
|
||||
from surya.common.surya import SuryaModelConfig, SuryaModelOutput
|
||||
@ -30,7 +31,10 @@ from surya.recognition.util import (
|
||||
)
|
||||
from surya.recognition.schema import TextLine, OCRResult, TextChar
|
||||
from surya.common.surya.schema import TaskNames
|
||||
from surya.recognition.cache import ContinuousBatchingCache
|
||||
from surya.recognition.cache import (
|
||||
ContinuousBatchingCache,
|
||||
ContinuousBatchingQuantizedCache,
|
||||
)
|
||||
from surya.settings import settings
|
||||
|
||||
|
||||
@ -383,7 +387,23 @@ class RecognitionPredictor(BasePredictor):
|
||||
needs_boxes = [self.tasks[p.task_name]["needs_bboxes"] for p in prompts]
|
||||
skip_box_idxs = ~torch.from_numpy(np.array(needs_boxes)).to(self.model.device)
|
||||
|
||||
prefill_cache = ContinuousBatchingCache()
|
||||
if settings.RECOGNITION_MODEL_QUANTIZE:
|
||||
try:
|
||||
import hqq # noqa: F401
|
||||
except Exception:
|
||||
raise ImportError(
|
||||
"Please install hqq to use quantized recognition model"
|
||||
)
|
||||
|
||||
# Use quantized cache if setting activated
|
||||
cache_config = QuantizedCacheConfig(
|
||||
"HQQ", 8, 1, 1, device=self.model.device, compute_dtype=self.model.dtype
|
||||
)
|
||||
prefill_cache = (
|
||||
ContinuousBatchingCache()
|
||||
if not settings.RECOGNITION_MODEL_QUANTIZE
|
||||
else ContinuousBatchingQuantizedCache(cache_config)
|
||||
)
|
||||
|
||||
with settings.INFERENCE_MODE():
|
||||
outputs = self.model(
|
||||
|
||||
@ -1,45 +1,64 @@
|
||||
import torch
|
||||
from transformers import DynamicCache
|
||||
from transformers import DynamicCache, HQQQuantizedCache
|
||||
from typing import List, Tuple
|
||||
|
||||
class ContinuousBatchingCache(DynamicCache):
|
||||
|
||||
class ContinuousBatchingMixin:
|
||||
def pad_left(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
padding_size: int
|
||||
self, key_states: torch.Tensor, value_states: torch.Tensor, padding_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Size is assumed to be (batch_size, num_kv_heads, seq_length, head_dim) - To match huggingface
|
||||
key_padding = torch.zeros((key_states.shape[0], key_states.shape[1], padding_size, key_states.shape[3]), device=key_states.device, dtype=key_states.dtype)
|
||||
key_states_padded = torch.cat([key_padding, key_states], dim=-2) # Pad along the sequence length dimension (dim=-2)
|
||||
key_padding = torch.zeros(
|
||||
(
|
||||
key_states.shape[0],
|
||||
key_states.shape[1],
|
||||
padding_size,
|
||||
key_states.shape[3],
|
||||
),
|
||||
device=key_states.device,
|
||||
dtype=key_states.dtype,
|
||||
)
|
||||
key_states_padded = torch.cat(
|
||||
[key_padding, key_states], dim=-2
|
||||
) # Pad along the sequence length dimension (dim=-2)
|
||||
|
||||
# Pad value_states to the left by `padding_size`
|
||||
value_padding = torch.zeros((value_states.shape[0], value_states.shape[1], padding_size, value_states.shape[3]), device=value_states.device, dtype=value_states.dtype)
|
||||
value_states_padded = torch.cat([value_padding, value_states], dim=-2) # Pad along the sequence length dimension (dim=-2)
|
||||
value_padding = torch.zeros(
|
||||
(
|
||||
value_states.shape[0],
|
||||
value_states.shape[1],
|
||||
padding_size,
|
||||
value_states.shape[3],
|
||||
),
|
||||
device=value_states.device,
|
||||
dtype=value_states.dtype,
|
||||
)
|
||||
value_states_padded = torch.cat(
|
||||
[value_padding, value_states], dim=-2
|
||||
) # Pad along the sequence length dimension (dim=-2)
|
||||
|
||||
return key_states_padded, value_states_padded
|
||||
|
||||
# Trim the cache from the left - Useful when longer sequences are evicted and we have long padding on the left
|
||||
def trim_left(
|
||||
self,
|
||||
trim_length: int
|
||||
):
|
||||
def trim_left(self, trim_length: int):
|
||||
for layer_idx in range(len(self)):
|
||||
# cache shape is (batch_size, num_kv_heads, seq_length, head_dim); Trimming along the seq_length dim
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, trim_length:, :]
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, trim_length:, :]
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx][
|
||||
:, :, trim_length:, :
|
||||
]
|
||||
|
||||
def merge(self, new_cache: DynamicCache, merge_idxs: List[int]):
|
||||
assert len(new_cache) == len(self), (
|
||||
"The two caches should have the same number of layers"
|
||||
)
|
||||
|
||||
def merge(
|
||||
self,
|
||||
new_cache: DynamicCache,
|
||||
merge_idxs: List[int]
|
||||
):
|
||||
assert len(new_cache) == len(self), "The two caches should have the same number of layers"
|
||||
|
||||
# We should TECHNICALLY be able to pad these values to 0s now, since they will be attention masked
|
||||
current_seq_length = self.get_seq_length()
|
||||
new_cache_seq_length = new_cache.get_seq_length()
|
||||
offset = current_seq_length - new_cache_seq_length # Generally positive, but negative case is handled too
|
||||
offset = (
|
||||
current_seq_length - new_cache_seq_length
|
||||
) # Generally positive, but negative case is handled too
|
||||
with torch.inference_mode():
|
||||
# As long as we set the attention mask and position ids correctly, padding value can be anything
|
||||
for layer_idx in range(len(self)):
|
||||
@ -48,11 +67,18 @@ class ContinuousBatchingCache(DynamicCache):
|
||||
new_k, new_v = self.pad_left(new_k, new_v, offset)
|
||||
|
||||
if offset < 0:
|
||||
adjusted_key_cache, adjusted_value_cache = self.pad_left(self.key_cache[layer_idx], self.value_cache[layer_idx], abs(offset))
|
||||
adjusted_key_cache, adjusted_value_cache = self.pad_left(
|
||||
self.key_cache[layer_idx],
|
||||
self.value_cache[layer_idx],
|
||||
abs(offset),
|
||||
)
|
||||
else:
|
||||
adjusted_key_cache, adjusted_value_cache = self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
adjusted_key_cache, adjusted_value_cache = (
|
||||
self.key_cache[layer_idx],
|
||||
self.value_cache[layer_idx],
|
||||
)
|
||||
|
||||
# TODO Make this assignment batched?
|
||||
# TODO Make this assignment batched?
|
||||
for i, merge_idx in enumerate(merge_idxs):
|
||||
adjusted_key_cache[merge_idx] = new_k[i]
|
||||
adjusted_value_cache[merge_idx] = new_v[i]
|
||||
@ -60,4 +86,12 @@ class ContinuousBatchingCache(DynamicCache):
|
||||
self.key_cache[layer_idx] = adjusted_key_cache
|
||||
self.value_cache[layer_idx] = adjusted_value_cache
|
||||
|
||||
return offset
|
||||
return offset
|
||||
|
||||
|
||||
class ContinuousBatchingCache(DynamicCache, ContinuousBatchingMixin):
|
||||
pass
|
||||
|
||||
|
||||
class ContinuousBatchingQuantizedCache(HQQQuantizedCache, ContinuousBatchingMixin):
|
||||
pass
|
||||
|
||||
@ -5,24 +5,11 @@ from transformers import AutoImageProcessor
|
||||
|
||||
from surya.common.load import ModelLoader
|
||||
from surya.common.surya.config import SuryaModelConfig
|
||||
from surya.common.surya.__init__ import SuryaModel
|
||||
from surya.common.surya.processor.__init__ import SuryaOCRProcessor
|
||||
from surya.common.surya import SuryaModel
|
||||
from surya.common.surya.processor import SuryaOCRProcessor
|
||||
from surya.common.surya.processor.tokenizer import SuryaOCRTokenizer
|
||||
from surya.settings import settings
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
flash_available = True
|
||||
except ImportError:
|
||||
flash_available = False
|
||||
|
||||
torch.backends.cuda.enable_cudnn_sdp(settings.ENABLE_CUDNN_ATTENTION)
|
||||
if not settings.ENABLE_EFFICIENT_ATTENTION:
|
||||
print("Efficient attention is disabled. This will use significantly more VRAM.")
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
|
||||
|
||||
class RecognitionModelLoader(ModelLoader):
|
||||
def __init__(self, checkpoint: Optional[str] = None):
|
||||
@ -39,34 +26,12 @@ class RecognitionModelLoader(ModelLoader):
|
||||
if dtype is None:
|
||||
dtype = settings.MODEL_DTYPE_BFLOAT
|
||||
|
||||
quant_config = {}
|
||||
if settings.RECOGNITION_MODEL_QUANTIZE:
|
||||
try:
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
from transformers import TorchAoConfig
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"`hqq` package is required for quantization. Please install it."
|
||||
) from e
|
||||
|
||||
quant_config = Int4WeightOnlyConfig(group_size=64)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
quant_config = {
|
||||
"quantization_config": quantization_config,
|
||||
"device_map": device,
|
||||
"torch_dtype": "auto",
|
||||
}
|
||||
|
||||
model = SuryaModel.from_pretrained(self.checkpoint, **quant_config)
|
||||
model = SuryaModel.from_pretrained(self.checkpoint, torch_dtype=dtype).to(
|
||||
device
|
||||
)
|
||||
model = model.eval()
|
||||
|
||||
if not settings.RECOGNITION_MODEL_QUANTIZE:
|
||||
model = model.to(device=device, dtype=dtype)
|
||||
|
||||
if flash_available:
|
||||
model.config.decoder._attn_implementation = "flash_attention_2"
|
||||
else:
|
||||
model.config.decoder._attn_implementation = "sdpa"
|
||||
model.config.decoder._attn_implementation = "sdpa"
|
||||
|
||||
if settings.COMPILE_ALL or settings.COMPILE_RECOGNITION:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
@ -81,7 +46,7 @@ class RecognitionModelLoader(ModelLoader):
|
||||
model.decoder = torch.compile(model.decoder, **compile_args)
|
||||
|
||||
print(
|
||||
f"Loaded recognition model {self.checkpoint} on device {model.device} with dtype {dtype}"
|
||||
f"Loaded recognition model {self.checkpoint} on device {model.device} with dtype {dtype}, using attention mechanism {model.config.decoder._attn_implementation}"
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@ -13,10 +13,6 @@ class Settings(BaseSettings):
|
||||
IMAGE_DPI: int = 96 # Used for detection, layout, reading order
|
||||
IMAGE_DPI_HIGHRES: int = 192 # Used for OCR, table rec
|
||||
IN_STREAMLIT: bool = False # Whether we're running in streamlit
|
||||
ENABLE_EFFICIENT_ATTENTION: bool = (
|
||||
True # Usually keep True, but if you get CUDA errors, setting to False can help
|
||||
)
|
||||
ENABLE_CUDNN_ATTENTION: bool = False # Causes issues on many systems when set to True, but can improve performance on certain GPUs
|
||||
FLATTEN_PDF: bool = True # Flatten PDFs by merging form fields before processing
|
||||
DISABLE_TQDM: bool = False # Disable tqdm progress bars
|
||||
S3_BASE_URL: str = "https://models.datalab.to"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user