diff --git a/backend/memory_management.py b/backend/memory_management.py index 5f0c8312..56d19b7c 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -955,6 +955,11 @@ def cast_to_device(tensor, device, dtype, copy=False): else: return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) +def sage_attention_enabled(): + return args.use_sage_attention + +def flash_attention_enabled(): + return args.use_flash_attention def xformers_enabled(): global directml_enabled