mirror of
https://github.com/VikParuchuri/surya.git
synced 2026-06-04 21:03:53 +08:00
Pick correct dtype on T4 GPUs
Some checks failed
Integration test / build (push) Has been cancelled
Unit tests / build (t4_gpu) (push) Has been cancelled
Unit tests / build (ubuntu-latest) (push) Has been cancelled
Unit tests / build (windows-latest) (push) Has been cancelled
Test CLI scripts / build (push) Has been cancelled
Some checks failed
Integration test / build (push) Has been cancelled
Unit tests / build (t4_gpu) (push) Has been cancelled
Unit tests / build (ubuntu-latest) (push) Has been cancelled
Unit tests / build (windows-latest) (push) Has been cancelled
Test CLI scripts / build (push) Has been cancelled
This commit is contained in:
parent
eb179cc543
commit
d3aecc0977
@@ -36,7 +36,7 @@ class FoundationModelLoader(ModelLoader):
|
||||
# emulated bf16, but falls back to very slow kernels, especially for SDPA
|
||||
dtype = settings.MODEL_DTYPE_BFLOAT
|
||||
if device == "cuda" and not torch.cuda.is_bf16_supported(
|
||||
- including_emulation=True
|
||||
+ including_emulation=False
|
||||
):
|
||||
# If the device is cuda, we check if bf16 is supported, and if not, we use float16
|
||||
dtype = settings.MODEL_DTYPE
|
||||
|
||||
Loading…
Reference in New Issue
Block a user