Pick correct dtype on T4 GPUs
Some checks failed
Integration test / build (push) Has been cancelled
Unit tests / build (t4_gpu) (push) Has been cancelled
Unit tests / build (ubuntu-latest) (push) Has been cancelled
Unit tests / build (windows-latest) (push) Has been cancelled
Test CLI scripts / build (push) Has been cancelled

This commit is contained in:
Tarun Menta 2025-09-23 17:22:34 -04:00
parent eb179cc543
commit d3aecc0977
No known key found for this signature in database

View File

@@ -36,7 +36,7 @@ class FoundationModelLoader(ModelLoader):
# emulated bf16, but falls back to very slow kernels, especially for SDPA
dtype = settings.MODEL_DTYPE_BFLOAT
if device == "cuda" and not torch.cuda.is_bf16_supported(
- including_emulation=True
+ including_emulation=False
):
# If the device is cuda, we check if bf16 is supported, and if not, we use float16
dtype = settings.MODEL_DTYPE