fix sage attention implementation

This commit is contained in:
maybleMyers 2025-05-27 02:01:28 -07:00
parent f581c39302
commit fd3a3b856a
2 changed files with 15 additions and 7 deletions

View File

@ -20,13 +20,15 @@ set the distilled config scale to 1, and the normal config scale to something li
use a very long positive prompt and a very long negative prompt.
forge doesn't seem to work with all quantized model, Q4_K_S fail, but Q4_1 work
To update to torch 2.7.0 with cuda 12.8 on windows, navigate to your home directory ie c:/chromaforge and run these commands:
To update to torch 2.7.0 with cuda 12.8 on windows and install sage attention, navigate to your home directory ie c:/chromaforge and run these commands:
venv/scripts/activate
pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128
pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128
pip install -U "triton-windows<3.4"
pip install .\sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl
## Changlog
5/27/2025
Fixed the sage attention implementation to work with chroma.
5/25/2025
Add support for sage and flash attention from this pr: https://github.com/lllyasviel/stable-diffusion-webui-forge/pull/2881 from @spawner1145
use the methods by adding --use-sage-attention or --use-flash-attention ... upon testing by a few people does not seem to have an increase on speed at all.

View File

@ -355,13 +355,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
# sageattn doesn't work with sd1.5
if q.shape[-1] // heads not in [64, 96, 128]:
dim_per_head = q.shape[-1] // heads
#print(f"[DEBUG] Entering attention_sage. dim_per_head: {dim_per_head}")
if dim_per_head not in [5, 64, 96, 128]:
#print(f"[DEBUG] Sage Attention: dim_per_head {dim_per_head} not supported. Falling back.") # <--- ADD THIS
if memory_management.flash_attention_enabled():
#print("[DEBUG] Sage Attention: Falling back to Flash Attention.") # <--- ADD THIS
return attention_flash(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
elif memory_management.xformers_enabled():
#print("[DEBUG] Sage Attention: Falling back to xFormers.") # <--- ADD THIS
return attention_xformers(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
#print("[DEBUG] Sage Attention: Falling back to PyTorch SDPA.") # <--- ADD THIS
return attention_pytorch(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
# sageattn doesn't work with sd1.5
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
@ -381,7 +387,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
#print("[DEBUG] Sage Attention: Calling sageattn library.")
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
if not skip_output_reshape: