From 00387bdbdb6bf40ee4c48eb930ced556c4370fa2 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Fri, 11 Apr 2025 13:19:57 -0400 Subject: [PATCH] Add in kv cache quantization --- poetry.lock | 52 +++++++------- pyproject.toml | 2 +- surya/common/surya/decoder/__init__.py | 14 ++-- surya/common/surya/processor/__init__.py | 2 +- surya/recognition/__init__.py | 24 ++++++- surya/recognition/cache.py | 88 ++++++++++++++++-------- surya/recognition/loader.py | 49 ++----------- surya/settings.py | 4 -- 8 files changed, 128 insertions(+), 107 deletions(-) diff --git a/poetry.lock b/poetry.lock index 6f9b6e9..8e438d4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1232,14 +1232,14 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "huggingface-hub" -version = "0.29.1" +version = "0.30.2" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" groups = ["main", "dev"] files = [ - {file = "huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5"}, - {file = "huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250"}, + {file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"}, + {file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"}, ] [package.dependencies] @@ -1257,6 +1257,7 @@ cli = ["InquirerPy (==0.3.4)"] dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] hf-transfer = ["hf-transfer (>=0.1.4)"] +hf-xet = ["hf-xet (>=0.1.4)"] inference = ["aiohttp"] quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"] tensorflow = ["graphviz", "pydot", "tensorflow"] @@ -4679,71 +4680,74 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "transformers" -version = "4.49.0" +version = "4.51.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.9.0" groups = ["main"] files = [ - {file = "transformers-4.49.0-py3-none-any.whl", hash = "sha256:6b4fded1c5fee04d384b1014495b4235a2b53c87503d7d592423c06128cbbe03"}, - {file = "transformers-4.49.0.tar.gz", hash = "sha256:7e40e640b5b8dc3f48743f5f5adbdce3660c82baafbd3afdfc04143cdbd2089e"}, + {file = "transformers-4.51.2-py3-none-any.whl", hash = "sha256:5cb8259098b75ff4b5dd04533a318f7c4750d5307d9617e6d0593526432c404d"}, + {file = "transformers-4.51.2.tar.gz", hash = "sha256:ed221c31581e97127cff5de775b05f05d19698b439d7d638ff445502a7f37331"}, ] [package.dependencies] filelock = "*" -huggingface-hub = ">=0.26.0,<1.0" +huggingface-hub = ">=0.30.0,<1.0" numpy = ">=1.17" packaging = ">=20.0" pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" -safetensors = ">=0.4.1" +safetensors = ">=0.4.3" tokenizers = ">=0.21,<0.22" tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.26.0)"] agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"] -audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "librosa", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"] +audio = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] benchmark = ["optimum-benchmark (>=0.3.0)"] codecarbon = ["codecarbon (>=2.8.1)"] deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] -flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +flax-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +hf-xet = ["hf-xet"] +hub-kernels = ["kernels (>=0.3.2,<0.4)"] +integrations = ["kernels (>=0.3.2,<0.4)", "optuna", "ray[tune] (>=2.7.0)", "sigopt"] ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] modelcreation = ["cookiecutter (==1.7.3)"] natten = ["natten (>=0.14.6,<0.15.0)"] +num2words = ["num2words"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.11.2)", "urllib3 (<2.0.0)"] ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] -ruff = ["ruff (==0.5.1)"] +ruff = ["ruff (==0.11.2)"] sagemaker = ["sagemaker (>=2.31.0)"] sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] serving = ["fastapi", "pydantic", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] -speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] -tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +tf-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] tiktoken = ["blobfile", "tiktoken"] timm = ["timm (<=1.0.11)"] tokenizers = ["tokenizers (>=0.21,<0.22)"] torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"] -torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.26.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.30.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"] video = ["av"] vision = ["Pillow (>=10.0.1,<=15.0)"] @@ -5201,4 +5205,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "fa33cb2c8bfc8f62754a8067a8532da7a589e3792ef9106ae8af8f85604d8698" +content-hash = "14a2e68d79050c13419742306894ee97191eef7a531b76c78bb8e85caeca7cad" diff --git a/pyproject.toml b/pyproject.toml index c18abbf..6df67bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.10" -transformers = "^4.41.0" +transformers = "^4.51.2" torch = "^2.5.1" pydantic = "^2.5.3" pydantic-settings = "^2.1.0" diff --git a/surya/common/surya/decoder/__init__.py b/surya/common/surya/decoder/__init__.py index 1f8528c..0b77cda 100644 --- a/surya/common/surya/decoder/__init__.py +++ b/surya/common/surya/decoder/__init__.py @@ -173,12 +173,14 @@ class Qwen2Attention(nn.Module): # IMPORTANT: Do not use causal mask for prefill; Matches training # This is required for flash attn, which doesn't support a 4D mask as input # The `is_causal` argument is ignored by SDPA since we pass a 4D attention mask - if self.config._attn_implementation == 'flash_attention_2': - is_prefill = all(( - input_shape[1] > 1, - (past_key_value is None) or (past_key_value.get_seq_length(self.layer_idx) == 0) - )) - + if self.config._attn_implementation == "flash_attention_2": + is_prefill = all( + ( + input_shape[1] > 1, + (past_key_value is None) + or (past_key_value.get_seq_length(self.layer_idx) == 0), + ) + ) if is_prefill: self.is_causal = False else: diff --git a/surya/common/surya/processor/__init__.py b/surya/common/surya/processor/__init__.py index 96770fb..1b56e69 100644 --- a/surya/common/surya/processor/__init__.py +++ b/surya/common/surya/processor/__init__.py @@ -165,7 +165,7 @@ class SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin): num_tiles = image_tiles.shape[0] input_ids = [self.image_token_id] * num_tiles * self.image_tokens_per_tile - input_ids += [self.register_token_ids][: self.num_register_tokens] + input_ids += self.register_token_ids[: self.num_register_tokens] # Handle the image being rotated in the imdataset if rotated: diff --git a/surya/recognition/__init__.py b/surya/recognition/__init__.py index 25d3261..64dd69a 100644 --- a/surya/recognition/__init__.py +++ b/surya/recognition/__init__.py @@ -8,6 +8,7 @@ import torch from PIL import Image from tqdm import tqdm import torch.nn.functional as F +from transformers import QuantizedCacheConfig from surya.common.polygon import PolygonBox from surya.common.surya import SuryaModelConfig, SuryaModelOutput @@ -30,7 +31,10 @@ from surya.recognition.util import ( ) from surya.recognition.schema import TextLine, OCRResult, TextChar from surya.common.surya.schema import TaskNames -from surya.recognition.cache import ContinuousBatchingCache +from surya.recognition.cache import ( + ContinuousBatchingCache, + ContinuousBatchingQuantizedCache, +) from surya.settings import settings @@ -383,7 +387,23 @@ class RecognitionPredictor(BasePredictor): needs_boxes = [self.tasks[p.task_name]["needs_bboxes"] for p in prompts] skip_box_idxs = ~torch.from_numpy(np.array(needs_boxes)).to(self.model.device) - prefill_cache = ContinuousBatchingCache() + if settings.RECOGNITION_MODEL_QUANTIZE: + try: + import hqq # noqa: F401 + except Exception: + raise ImportError( + "Please install hqq to use quantized recognition model" + ) + + # Use quantized cache if setting activated + cache_config = QuantizedCacheConfig( + "HQQ", 8, 1, 1, device=self.model.device, compute_dtype=self.model.dtype + ) + prefill_cache = ( + ContinuousBatchingCache() + if not settings.RECOGNITION_MODEL_QUANTIZE + else ContinuousBatchingQuantizedCache(cache_config) + ) with settings.INFERENCE_MODE(): outputs = self.model( diff --git a/surya/recognition/cache.py b/surya/recognition/cache.py index d4d6e95..1feea26 100644 --- a/surya/recognition/cache.py +++ b/surya/recognition/cache.py @@ -1,45 +1,64 @@ import torch -from transformers import DynamicCache +from transformers import DynamicCache, HQQQuantizedCache from typing import List, Tuple -class ContinuousBatchingCache(DynamicCache): + +class ContinuousBatchingMixin: def pad_left( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - padding_size: int + self, key_states: torch.Tensor, value_states: torch.Tensor, padding_size: int ) -> Tuple[torch.Tensor, torch.Tensor]: # Size is assumed to be (batch_size, num_kv_heads, seq_length, head_dim) - To match huggingface - key_padding = torch.zeros((key_states.shape[0], key_states.shape[1], padding_size, key_states.shape[3]), device=key_states.device, dtype=key_states.dtype) - key_states_padded = torch.cat([key_padding, key_states], dim=-2) # Pad along the sequence length dimension (dim=-2) + key_padding = torch.zeros( + ( + key_states.shape[0], + key_states.shape[1], + padding_size, + key_states.shape[3], + ), + device=key_states.device, + dtype=key_states.dtype, + ) + key_states_padded = torch.cat( + [key_padding, key_states], dim=-2 + ) # Pad along the sequence length dimension (dim=-2) # Pad value_states to the left by `padding_size` - value_padding = torch.zeros((value_states.shape[0], value_states.shape[1], padding_size, value_states.shape[3]), device=value_states.device, dtype=value_states.dtype) - value_states_padded = torch.cat([value_padding, value_states], dim=-2) # Pad along the sequence length dimension (dim=-2) + value_padding = torch.zeros( + ( + value_states.shape[0], + value_states.shape[1], + padding_size, + value_states.shape[3], + ), + device=value_states.device, + dtype=value_states.dtype, + ) + value_states_padded = torch.cat( + [value_padding, value_states], dim=-2 + ) # Pad along the sequence length dimension (dim=-2) return key_states_padded, value_states_padded # Trim the cache from the left - Useful when longer sequences are evicted and we have long padding on the left - def trim_left( - self, - trim_length: int - ): + def trim_left(self, trim_length: int): for layer_idx in range(len(self)): # cache sape is (batch_size, num_kv_heads, seq_length, head_dim); Trimming from head dim self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, trim_length:, :] - self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, trim_length:, :] + self.value_cache[layer_idx] = self.value_cache[layer_idx][ + :, :, trim_length:, : + ] + + def merge(self, new_cache: DynamicCache, merge_idxs: List[int]): + assert len(new_cache) == len(self), ( + "The two caches should have the same number of layers" + ) - def merge( - self, - new_cache: DynamicCache, - merge_idxs: List[int] - ): - assert len(new_cache) == len(self), "The two caches should have the same number of layers" - # We should TECHNICALLY be able to pad these values to 0s now, since they will be attention masked current_seq_length = self.get_seq_length() new_cache_seq_length = new_cache.get_seq_length() - offset = current_seq_length - new_cache_seq_length # Generally positive, but negative case is handled too + offset = ( + current_seq_length - new_cache_seq_length + ) # Generally positive, but negative case is handled too with torch.inference_mode(): # As long as we set the attention mask and position ids correctly, padding value can be anything for layer_idx in range(len(self)): @@ -48,11 +67,18 @@ class ContinuousBatchingCache(DynamicCache): new_k, new_v = self.pad_left(new_k, new_v, offset) if offset < 0: - adjusted_key_cache, adjusted_value_cache = self.pad_left(self.key_cache[layer_idx], self.value_cache[layer_idx], abs(offset)) + adjusted_key_cache, adjusted_value_cache = self.pad_left( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + abs(offset), + ) else: - adjusted_key_cache, adjusted_value_cache = self.key_cache[layer_idx], self.value_cache[layer_idx] + adjusted_key_cache, adjusted_value_cache = ( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + ) - # TODO Make this assignment batched? + # TODO Make this assignment batched? for i, merge_idx in enumerate(merge_idxs): adjusted_key_cache[merge_idx] = new_k[i] adjusted_value_cache[merge_idx] = new_v[i] @@ -60,4 +86,12 @@ class ContinuousBatchingCache(DynamicCache): self.key_cache[layer_idx] = adjusted_key_cache self.value_cache[layer_idx] = adjusted_value_cache - return offset \ No newline at end of file + return offset + + +class ContinuousBatchingCache(DynamicCache, ContinuousBatchingMixin): + pass + + +class ContinuousBatchingQuantizedCache(HQQQuantizedCache, ContinuousBatchingMixin): + pass diff --git a/surya/recognition/loader.py b/surya/recognition/loader.py index 09a681d..da3dd99 100644 --- a/surya/recognition/loader.py +++ b/surya/recognition/loader.py @@ -5,24 +5,11 @@ from transformers import AutoImageProcessor from surya.common.load import ModelLoader from surya.common.surya.config import SuryaModelConfig -from surya.common.surya.__init__ import SuryaModel -from surya.common.surya.processor.__init__ import SuryaOCRProcessor +from surya.common.surya import SuryaModel +from surya.common.surya.processor import SuryaOCRProcessor from surya.common.surya.processor.tokenizer import SuryaOCRTokenizer from surya.settings import settings -try: - import flash_attn - flash_available = True -except ImportError: - flash_available = False - -torch.backends.cuda.enable_cudnn_sdp(settings.ENABLE_CUDNN_ATTENTION) -if not settings.ENABLE_EFFICIENT_ATTENTION: - print("Efficient attention is disabled. This will use significantly more VRAM.") - torch.backends.cuda.enable_mem_efficient_sdp(False) - torch.backends.cuda.enable_flash_sdp(False) - torch.backends.cuda.enable_math_sdp(True) - class RecognitionModelLoader(ModelLoader): def __init__(self, checkpoint: Optional[str] = None): @@ -39,34 +26,12 @@ class RecognitionModelLoader(ModelLoader): if dtype is None: dtype = settings.MODEL_DTYPE_BFLOAT - quant_config = {} - if settings.RECOGNITION_MODEL_QUANTIZE: - try: - from torchao.quantization import Int4WeightOnlyConfig - from transformers import TorchAoConfig - except ImportError as e: - raise RuntimeError( - "`hqq` package is required for quantization. Please install it." - ) from e - - quant_config = Int4WeightOnlyConfig(group_size=64) - quantization_config = TorchAoConfig(quant_type=quant_config) - quant_config = { - "quantization_config": quantization_config, - "device_map": device, - "torch_dtype": "auto", - } - - model = SuryaModel.from_pretrained(self.checkpoint, **quant_config) + model = SuryaModel.from_pretrained(self.checkpoint, torch_dtype=dtype).to( + device + ) model = model.eval() - if not settings.RECOGNITION_MODEL_QUANTIZE: - model = model.to(device=device, dtype=dtype) - - if flash_available: - model.config.decoder._attn_implementation = "flash_attention_2" - else: - model.config.decoder._attn_implementation = "sdpa" + model.config.decoder._attn_implementation = "sdpa" if settings.COMPILE_ALL or settings.COMPILE_RECOGNITION: torch.set_float32_matmul_precision("high") @@ -81,7 +46,7 @@ class RecognitionModelLoader(ModelLoader): model.decoder = torch.compile(model.decoder, **compile_args) print( - f"Loaded recognition model {self.checkpoint} on device {model.device} with dtype {dtype}" + f"Loaded recognition model {self.checkpoint} on device {model.device} with dtype {dtype}, using attention mechanism {model.config.decoder._attn_implementation}" ) return model diff --git a/surya/settings.py b/surya/settings.py index 118564c..b1b6777 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -13,10 +13,6 @@ class Settings(BaseSettings): IMAGE_DPI: int = 96 # Used for detection, layout, reading order IMAGE_DPI_HIGHRES: int = 192 # Used for OCR, table rec IN_STREAMLIT: bool = False # Whether we're running in streamlit - ENABLE_EFFICIENT_ATTENTION: bool = ( - True # Usually keep True, but if you get CUDA errors, setting to False can help - ) - ENABLE_CUDNN_ATTENTION: bool = False # Causes issues on many systems when set to True, but can improve performance on certain GPUs FLATTEN_PDF: bool = True # Flatten PDFs by merging form fields before processing DISABLE_TQDM: bool = False # Disable tqdm progress bars S3_BASE_URL: str = "https://models.datalab.to"