From fd3a3b856abf433c558f52bd2c012b2cd8d1036a Mon Sep 17 00:00:00 2001 From: maybleMyers Date: Tue, 27 May 2025 02:01:28 -0700 Subject: [PATCH] fix sage attention implementation --- README.md | 10 ++++++---- backend/attention.py | 12 +++++++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 33c11272..fdc5ded3 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,15 @@ set the distilled config scale to 1, and the normal config scale to something li use a very long positive prompt and a very long negative prompt. forge doesn't seem to work with all quantized model, Q4_K_S fail, but Q4_1 work -To update to torch 2.7.0 with cuda 12.8 on windows, navigate to your home directory ie c:/chromaforge and run these commands: +To update to torch 2.7.0 with cuda 12.8 on windows and install sage attention, navigate to your home directory ie c:/chromaforge and run these commands: venv/scripts/activate - pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128 - +pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128 +pip install -U "triton-windows<3.4" +pip install .\sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl ## Changlog - +5/27/2025 + Fixed the sage attention implementation to work with chroma. 5/25/2025 Add support for sage and flash attention from this pr: https://github.com/lllyasviel/stable-diffusion-webui-forge/pull/2881 from @spawner1145 use the methods by adding --use-sage-attention or --use-flash-attention ... upon testing by a few people does not seem to have an increase on speed at all. diff --git a/backend/attention.py b/backend/attention.py index 077111ee..e745948b 100644 --- a/backend/attention.py +++ b/backend/attention.py @@ -355,13 +355,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha return out def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): - # sageattn doesn't work with sd1.5 - if q.shape[-1] // heads not in [64, 96, 128]: + dim_per_head = q.shape[-1] // heads + #print(f"[DEBUG] Entering attention_sage. dim_per_head: {dim_per_head}") + if dim_per_head not in [5, 64, 96, 128]: + #print(f"[DEBUG] Sage Attention: dim_per_head {dim_per_head} not supported. Falling back.") # <--- ADD THIS if memory_management.flash_attention_enabled(): + #print("[DEBUG] Sage Attention: Falling back to Flash Attention.") # <--- ADD THIS return attention_flash(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape) elif memory_management.xformers_enabled(): + #print("[DEBUG] Sage Attention: Falling back to xFormers.") # <--- ADD THIS return attention_xformers(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape) + #print("[DEBUG] Sage Attention: Falling back to PyTorch SDPA.") # <--- ADD THIS return attention_pytorch(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape) + # sageattn doesn't work with sd1.5 if skip_reshape: b, _, _, dim_head = q.shape tensor_layout="HND" @@ -381,7 +387,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= # add a heads dimension if there isn't already one if mask.ndim == 3: mask = mask.unsqueeze(1) - + #print("[DEBUG] Sage Attention: Calling sageattn library.") out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) if tensor_layout == "HND": if not skip_output_reshape: