mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-13 21:01:06 +08:00
fix sage attention implementation
This commit is contained in:
parent
f581c39302
commit
fd3a3b856a
10
README.md
10
README.md
@ -20,13 +20,15 @@ set the distilled config scale to 1, and the normal config scale to something li
|
||||
use a very long positive prompt and a very long negative prompt.
|
||||
Forge doesn't seem to work with all quantized models: Q4_K_S fails, but Q4_1 works.
|
||||
|
||||
To update to torch 2.7.0 with cuda 12.8 on windows, navigate to your home directory ie c:/chromaforge and run these commands:
|
||||
To update to torch 2.7.0 with CUDA 12.8 on Windows and install Sage Attention, navigate to your home directory (e.g. c:/chromaforge) and run these commands:
|
||||
venv/scripts/activate
|
||||
pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128
|
||||
|
||||
pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128
|
||||
pip install -U "triton-windows<3.4"
|
||||
pip install .\sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl
|
||||
|
||||
## Changelog
|
||||
|
||||
5/27/2025
|
||||
Fixed the sage attention implementation to work with chroma.
|
||||
5/25/2025
|
||||
Add support for sage and flash attention from this pr: https://github.com/lllyasviel/stable-diffusion-webui-forge/pull/2881 from @spawner1145
|
||||
Enable either method by adding --use-sage-attention or --use-flash-attention ... in testing by a few people, neither seems to increase speed at all.
|
||||
|
||||
@ -355,13 +355,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
return out
|
||||
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
# sageattn doesn't work with sd1.5
|
||||
if q.shape[-1] // heads not in [64, 96, 128]:
|
||||
dim_per_head = q.shape[-1] // heads
|
||||
#print(f"[DEBUG] Entering attention_sage. dim_per_head: {dim_per_head}")
|
||||
if dim_per_head not in [5, 64, 96, 128]:
|
||||
#print(f"[DEBUG] Sage Attention: dim_per_head {dim_per_head} not supported. Falling back.") # <--- ADD THIS
|
||||
if memory_management.flash_attention_enabled():
|
||||
#print("[DEBUG] Sage Attention: Falling back to Flash Attention.") # <--- ADD THIS
|
||||
return attention_flash(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
|
||||
elif memory_management.xformers_enabled():
|
||||
#print("[DEBUG] Sage Attention: Falling back to xFormers.") # <--- ADD THIS
|
||||
return attention_xformers(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
|
||||
#print("[DEBUG] Sage Attention: Falling back to PyTorch SDPA.") # <--- ADD THIS
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
|
||||
# sageattn doesn't work with sd1.5
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout="HND"
|
||||
@ -381,7 +387,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
#print("[DEBUG] Sage Attention: Calling sageattn library.")
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user