From d3aecc0977ac22bbf147cac294ad78fffbe13407 Mon Sep 17 00:00:00 2001 From: Tarun Menta Date: Tue, 23 Sep 2025 17:22:34 -0400 Subject: [PATCH] Pick correct dtype on T4 GPUs --- surya/foundation/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surya/foundation/loader.py b/surya/foundation/loader.py index 8861b7b..77e6256 100644 --- a/surya/foundation/loader.py +++ b/surya/foundation/loader.py @@ -36,7 +36,7 @@ class FoundationModelLoader(ModelLoader): # emulated bf16, but falls back to very slow kernels, especially for SDPA dtype = settings.MODEL_DTYPE_BFLOAT if device == "cuda" and not torch.cuda.is_bf16_supported( - including_emulation=True + including_emulation=False ): # If the device is cuda, we check if bf16 is supported, and if not, we use float16 dtype = settings.MODEL_DTYPE