diff --git a/surya/foundation/loader.py b/surya/foundation/loader.py index 8861b7b..77e6256 100644 --- a/surya/foundation/loader.py +++ b/surya/foundation/loader.py @@ -36,7 +36,7 @@ class FoundationModelLoader(ModelLoader): # emulated bf16, but falls back to very slow kernels, especially for SDPA dtype = settings.MODEL_DTYPE_BFLOAT if device == "cuda" and not torch.cuda.is_bf16_supported( - including_emulation=True + including_emulation=False ): # If the device is cuda, we check if bf16 is supported, and if not, we use float16 dtype = settings.MODEL_DTYPE