Add in kv cache quantization

This commit is contained in:
Vik Paruchuri 2025-04-11 13:19:57 -04:00
parent f3c4adb29b
commit 00387bdbdb
8 changed files with 128 additions and 107 deletions

52
poetry.lock generated
View File

@ -1232,14 +1232,14 @@ zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "huggingface-hub"
version = "0.29.1"
version = "0.30.2"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
groups = ["main", "dev"]
files = [
{file = "huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5"},
{file = "huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250"},
{file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"},
{file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"},
]
[package.dependencies]
@ -1257,6 +1257,7 @@ cli = ["InquirerPy (==0.3.4)"]
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
hf-transfer = ["hf-transfer (>=0.1.4)"]
hf-xet = ["hf-xet (>=0.1.4)"]
inference = ["aiohttp"]
quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"]
@ -4679,71 +4680,74 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,
[[package]]
name = "transformers"
version = "4.49.0"
version = "4.51.2"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false
python-versions = ">=3.9.0"
groups = ["main"]
files = [
{file = "transformers-4.49.0-py3-none-any.whl", hash = "sha256:6b4fded1c5fee04d384b1014495b4235a2b53c87503d7d592423c06128cbbe03"},
{file = "transformers-4.49.0.tar.gz", hash = "sha256:7e40e640b5b8dc3f48743f5f5adbdce3660c82baafbd3afdfc04143cdbd2089e"},
{file = "transformers-4.51.2-py3-none-any.whl", hash = "sha256:5cb8259098b75ff4b5dd04533a318f7c4750d5307d9617e6d0593526432c404d"},
{file = "transformers-4.51.2.tar.gz", hash = "sha256:ed221c31581e97127cff5de775b05f05d19698b439d7d638ff445502a7f37331"},
]
[package.dependencies]
filelock = "*"
huggingface-hub = ">=0.26.0,<1.0"
huggingface-hub = ">=0.30.0,<1.0"
numpy = ">=1.17"
packaging = ">=20.0"
pyyaml = ">=5.1"
regex = "!=2019.12.17"
requests = "*"
safetensors = ">=0.4.1"
safetensors = ">=0.4.3"
tokenizers = ">=0.21,<0.22"
tqdm = ">=4.27"
[package.extras]
accelerate = ["accelerate (>=0.26.0)"]
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"]
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"]
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "librosa", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"]
audio = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
benchmark = ["optimum-benchmark (>=0.3.0)"]
codecarbon = ["codecarbon (>=2.8.1)"]
deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
flax-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
ftfy = ["ftfy"]
integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
hf-xet = ["hf-xet"]
hub-kernels = ["kernels (>=0.3.2,<0.4)"]
integrations = ["kernels (>=0.3.2,<0.4)", "optuna", "ray[tune] (>=2.7.0)", "sigopt"]
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
modelcreation = ["cookiecutter (==1.7.3)"]
natten = ["natten (>=0.14.6,<0.15.0)"]
num2words = ["num2words"]
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
optuna = ["optuna"]
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.11.2)", "urllib3 (<2.0.0)"]
ray = ["ray[tune] (>=2.7.0)"]
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
ruff = ["ruff (==0.5.1)"]
ruff = ["ruff (==0.11.2)"]
sagemaker = ["sagemaker (>=2.31.0)"]
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
sigopt = ["sigopt"]
sklearn = ["scikit-learn"]
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
tf-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
tiktoken = ["blobfile", "tiktoken"]
timm = ["timm (<=1.0.11)"]
tokenizers = ["tokenizers (>=0.21,<0.22)"]
torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"]
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
torch-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
torchhub = ["filelock", "huggingface-hub (>=0.26.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"]
torchhub = ["filelock", "huggingface-hub (>=0.30.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"]
video = ["av"]
vision = ["Pillow (>=10.0.1,<=15.0)"]
@ -5201,4 +5205,4 @@ propcache = ">=0.2.0"
[metadata]
lock-version = "2.1"
python-versions = "^3.10"
content-hash = "fa33cb2c8bfc8f62754a8067a8532da7a589e3792ef9106ae8af8f85604d8698"
content-hash = "14a2e68d79050c13419742306894ee97191eef7a531b76c78bb8e85caeca7cad"

View File

@ -13,7 +13,7 @@ packages = [
[tool.poetry.dependencies]
python = "^3.10"
transformers = "^4.41.0"
transformers = "^4.51.2"
torch = "^2.5.1"
pydantic = "^2.5.3"
pydantic-settings = "^2.1.0"

View File

@ -173,12 +173,14 @@ class Qwen2Attention(nn.Module):
# IMPORTANT: Do not use causal mask for prefill; Matches training
# This is required for flash attn, which doesn't support a 4D mask as input
# The `is_causal` argument is ignored by SDPA since we pass a 4D attention mask
if self.config._attn_implementation == 'flash_attention_2':
is_prefill = all((
input_shape[1] > 1,
(past_key_value is None) or (past_key_value.get_seq_length(self.layer_idx) == 0)
))
if self.config._attn_implementation == "flash_attention_2":
is_prefill = all(
(
input_shape[1] > 1,
(past_key_value is None)
or (past_key_value.get_seq_length(self.layer_idx) == 0),
)
)
if is_prefill:
self.is_causal = False
else:

View File

@ -165,7 +165,7 @@ class SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin):
num_tiles = image_tiles.shape[0]
input_ids = [self.image_token_id] * num_tiles * self.image_tokens_per_tile
input_ids += [self.register_token_ids][: self.num_register_tokens]
input_ids += self.register_token_ids[: self.num_register_tokens]
# Handle the image being rotated in the imdataset
if rotated:

View File

@ -8,6 +8,7 @@ import torch
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
from transformers import QuantizedCacheConfig
from surya.common.polygon import PolygonBox
from surya.common.surya import SuryaModelConfig, SuryaModelOutput
@ -30,7 +31,10 @@ from surya.recognition.util import (
)
from surya.recognition.schema import TextLine, OCRResult, TextChar
from surya.common.surya.schema import TaskNames
from surya.recognition.cache import ContinuousBatchingCache
from surya.recognition.cache import (
ContinuousBatchingCache,
ContinuousBatchingQuantizedCache,
)
from surya.settings import settings
@ -383,7 +387,23 @@ class RecognitionPredictor(BasePredictor):
needs_boxes = [self.tasks[p.task_name]["needs_bboxes"] for p in prompts]
skip_box_idxs = ~torch.from_numpy(np.array(needs_boxes)).to(self.model.device)
prefill_cache = ContinuousBatchingCache()
if settings.RECOGNITION_MODEL_QUANTIZE:
try:
import hqq # noqa: F401
except Exception:
raise ImportError(
"Please install hqq to use quantized recognition model"
)
# Use quantized cache if setting activated
cache_config = QuantizedCacheConfig(
"HQQ", 8, 1, 1, device=self.model.device, compute_dtype=self.model.dtype
)
prefill_cache = (
ContinuousBatchingCache()
if not settings.RECOGNITION_MODEL_QUANTIZE
else ContinuousBatchingQuantizedCache(cache_config)
)
with settings.INFERENCE_MODE():
outputs = self.model(

View File

@ -1,45 +1,64 @@
import torch
from transformers import DynamicCache
from transformers import DynamicCache, HQQQuantizedCache
from typing import List, Tuple
class ContinuousBatchingCache(DynamicCache):
class ContinuousBatchingMixin:
def pad_left(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
padding_size: int
self, key_states: torch.Tensor, value_states: torch.Tensor, padding_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# Size is assumed to be (batch_size, num_kv_heads, seq_length, head_dim) - To match huggingface
key_padding = torch.zeros((key_states.shape[0], key_states.shape[1], padding_size, key_states.shape[3]), device=key_states.device, dtype=key_states.dtype)
key_states_padded = torch.cat([key_padding, key_states], dim=-2) # Pad along the sequence length dimension (dim=-2)
key_padding = torch.zeros(
(
key_states.shape[0],
key_states.shape[1],
padding_size,
key_states.shape[3],
),
device=key_states.device,
dtype=key_states.dtype,
)
key_states_padded = torch.cat(
[key_padding, key_states], dim=-2
) # Pad along the sequence length dimension (dim=-2)
# Pad value_states to the left by `padding_size`
value_padding = torch.zeros((value_states.shape[0], value_states.shape[1], padding_size, value_states.shape[3]), device=value_states.device, dtype=value_states.dtype)
value_states_padded = torch.cat([value_padding, value_states], dim=-2) # Pad along the sequence length dimension (dim=-2)
value_padding = torch.zeros(
(
value_states.shape[0],
value_states.shape[1],
padding_size,
value_states.shape[3],
),
device=value_states.device,
dtype=value_states.dtype,
)
value_states_padded = torch.cat(
[value_padding, value_states], dim=-2
) # Pad along the sequence length dimension (dim=-2)
return key_states_padded, value_states_padded
# Trim the cache from the left - Useful when longer sequences are evicted and we have long padding on the left
def trim_left(
self,
trim_length: int
):
def trim_left(self, trim_length: int):
for layer_idx in range(len(self)):
# cache shape is (batch_size, num_kv_heads, seq_length, head_dim); trimming along the seq_length dim
self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, trim_length:, :]
self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, trim_length:, :]
self.value_cache[layer_idx] = self.value_cache[layer_idx][
:, :, trim_length:, :
]
def merge(self, new_cache: DynamicCache, merge_idxs: List[int]):
assert len(new_cache) == len(self), (
"The two caches should have the same number of layers"
)
def merge(
self,
new_cache: DynamicCache,
merge_idxs: List[int]
):
assert len(new_cache) == len(self), "The two caches should have the same number of layers"
# We should TECHNICALLY be able to pad these values to 0s now, since they will be attention masked
current_seq_length = self.get_seq_length()
new_cache_seq_length = new_cache.get_seq_length()
offset = current_seq_length - new_cache_seq_length # Generally positive, but negative case is handled too
offset = (
current_seq_length - new_cache_seq_length
) # Generally positive, but negative case is handled too
with torch.inference_mode():
# As long as we set the attention mask and position ids correctly, padding value can be anything
for layer_idx in range(len(self)):
@ -48,11 +67,18 @@ class ContinuousBatchingCache(DynamicCache):
new_k, new_v = self.pad_left(new_k, new_v, offset)
if offset < 0:
adjusted_key_cache, adjusted_value_cache = self.pad_left(self.key_cache[layer_idx], self.value_cache[layer_idx], abs(offset))
adjusted_key_cache, adjusted_value_cache = self.pad_left(
self.key_cache[layer_idx],
self.value_cache[layer_idx],
abs(offset),
)
else:
adjusted_key_cache, adjusted_value_cache = self.key_cache[layer_idx], self.value_cache[layer_idx]
adjusted_key_cache, adjusted_value_cache = (
self.key_cache[layer_idx],
self.value_cache[layer_idx],
)
# TODO Make this assignment batched?
# TODO Make this assignment batched?
for i, merge_idx in enumerate(merge_idxs):
adjusted_key_cache[merge_idx] = new_k[i]
adjusted_value_cache[merge_idx] = new_v[i]
@ -60,4 +86,12 @@ class ContinuousBatchingCache(DynamicCache):
self.key_cache[layer_idx] = adjusted_key_cache
self.value_cache[layer_idx] = adjusted_value_cache
return offset
return offset
class ContinuousBatchingCache(DynamicCache, ContinuousBatchingMixin):
pass
class ContinuousBatchingQuantizedCache(HQQQuantizedCache, ContinuousBatchingMixin):
pass

View File

@ -5,24 +5,11 @@ from transformers import AutoImageProcessor
from surya.common.load import ModelLoader
from surya.common.surya.config import SuryaModelConfig
from surya.common.surya.__init__ import SuryaModel
from surya.common.surya.processor.__init__ import SuryaOCRProcessor
from surya.common.surya import SuryaModel
from surya.common.surya.processor import SuryaOCRProcessor
from surya.common.surya.processor.tokenizer import SuryaOCRTokenizer
from surya.settings import settings
try:
import flash_attn
flash_available = True
except ImportError:
flash_available = False
torch.backends.cuda.enable_cudnn_sdp(settings.ENABLE_CUDNN_ATTENTION)
if not settings.ENABLE_EFFICIENT_ATTENTION:
print("Efficient attention is disabled. This will use significantly more VRAM.")
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
class RecognitionModelLoader(ModelLoader):
def __init__(self, checkpoint: Optional[str] = None):
@ -39,34 +26,12 @@ class RecognitionModelLoader(ModelLoader):
if dtype is None:
dtype = settings.MODEL_DTYPE_BFLOAT
quant_config = {}
if settings.RECOGNITION_MODEL_QUANTIZE:
try:
from torchao.quantization import Int4WeightOnlyConfig
from transformers import TorchAoConfig
except ImportError as e:
raise RuntimeError(
"`hqq` package is required for quantization. Please install it."
) from e
quant_config = Int4WeightOnlyConfig(group_size=64)
quantization_config = TorchAoConfig(quant_type=quant_config)
quant_config = {
"quantization_config": quantization_config,
"device_map": device,
"torch_dtype": "auto",
}
model = SuryaModel.from_pretrained(self.checkpoint, **quant_config)
model = SuryaModel.from_pretrained(self.checkpoint, torch_dtype=dtype).to(
device
)
model = model.eval()
if not settings.RECOGNITION_MODEL_QUANTIZE:
model = model.to(device=device, dtype=dtype)
if flash_available:
model.config.decoder._attn_implementation = "flash_attention_2"
else:
model.config.decoder._attn_implementation = "sdpa"
model.config.decoder._attn_implementation = "sdpa"
if settings.COMPILE_ALL or settings.COMPILE_RECOGNITION:
torch.set_float32_matmul_precision("high")
@ -81,7 +46,7 @@ class RecognitionModelLoader(ModelLoader):
model.decoder = torch.compile(model.decoder, **compile_args)
print(
f"Loaded recognition model {self.checkpoint} on device {model.device} with dtype {dtype}"
f"Loaded recognition model {self.checkpoint} on device {model.device} with dtype {dtype}, using attention mechanism {model.config.decoder._attn_implementation}"
)
return model

View File

@ -13,10 +13,6 @@ class Settings(BaseSettings):
IMAGE_DPI: int = 96 # Used for detection, layout, reading order
IMAGE_DPI_HIGHRES: int = 192 # Used for OCR, table rec
IN_STREAMLIT: bool = False # Whether we're running in streamlit
ENABLE_EFFICIENT_ATTENTION: bool = (
True # Usually keep True, but if you get CUDA errors, setting to False can help
)
ENABLE_CUDNN_ATTENTION: bool = False # Causes issues on many systems when set to True, but can improve performance on certain GPUs
FLATTEN_PDF: bool = True # Flatten PDFs by merging form fields before processing
DISABLE_TQDM: bool = False # Disable tqdm progress bars
S3_BASE_URL: str = "https://models.datalab.to"