merge experimental sage pr

This commit is contained in:
maybleMyers 2025-05-25 18:23:12 -07:00
commit f362ec6638
5 changed files with 187 additions and 1 deletions

View File

@ -31,6 +31,8 @@ attn_group = parser.add_mutually_exclusive_group()
# Attention backend selection flags. They are all registered on `attn_group`,
# a mutually exclusive argparse group, so at most one of them can be passed on
# the command line.
attn_group.add_argument("--attention-split", action="store_true")
attn_group.add_argument("--attention-quad", action="store_true")
attn_group.add_argument("--attention-pytorch", action="store_true")
# Backends that depend on optional third-party packages.
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
# Separate mutually exclusive group for the attention upcast options.
upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true")

View File

@ -18,6 +18,22 @@ if memory_management.xformers_enabled():
except:
pass
# Optional attention backends: each package is imported only when its flag is
# enabled. A missing package prints the install command and terminates the
# process with exit status -1.
if memory_management.sage_attention_enabled():
    try:
        from sageattention import sageattn
    except ModuleNotFoundError:
        print(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
        # sys.exit instead of the bare exit(): the latter is injected by the
        # `site` module for interactive use and is not guaranteed to exist.
        sys.exit(-1)

if memory_management.flash_attention_enabled():
    try:
        from flash_attn import flash_attn_func
    except ModuleNotFoundError:
        print(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
        sys.exit(-1)
import backend.operations
# Short alias for the operations class exposed by backend.operations.
ops = backend.operations.ForgeOperations
# Evaluated once at import time from memory_management.
FORCE_UPCAST_ATTENTION_DTYPE = memory_management.force_upcast_attention_dtype()
@ -338,6 +354,102 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
)
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Run attention through ``sageattn``, falling back for unsupported head sizes.

    Args:
        q, k, v: Query/key/value tensors. With ``skip_reshape=False`` they are
            unpacked as 3-D ``(b, _, heads * dim_head)`` and viewed as
            ``(b, -1, heads, dim_head)`` ("NHD" layout). With
            ``skip_reshape=True`` they are unpacked as 4-D
            ``(b, _, _, dim_head)`` and handed over as "HND".
        heads: Number of attention heads.
        mask: Optional attention mask; 2-D and 3-D masks get leading
            singleton axes added until they are 4-D.
        attn_precision: Only forwarded to the fallback implementations.
        skip_reshape: Inputs are already split per head (see above).
        skip_output_reshape: Return the result as ``(b, heads, -1, dim_head)``
            instead of merging it to ``(b, -1, heads * dim_head)``.

    Returns:
        The attention output tensor.
    """
    # With skip_reshape the last axis already holds the per-head size; only the
    # packed layout has to be divided by the number of heads.
    head_dim = q.shape[-1] if skip_reshape else q.shape[-1] // heads

    # sageattn doesn't work with sd1.5
    if head_dim not in (64, 96, 128):
        if memory_management.flash_attention_enabled():
            return attention_flash(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape)
        # NOTE(review): skip_output_reshape is not forwarded below because the
        # signatures of attention_xformers / attention_pytorch are not visible
        # here - confirm whether they accept it and forward it if so.
        elif memory_management.xformers_enabled():
            return attention_xformers(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
        return attention_pytorch(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)

    if skip_reshape:
        b, _, _, dim_head = q.shape
        tensor_layout = "HND"
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head),
            (q, k, v),
        )
        tensor_layout = "NHD"

    if mask is not None:
        # add a batch dimension if there isn't already one
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a heads dimension if there isn't already one
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)

    if tensor_layout == "HND":
        if not skip_output_reshape:
            out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    else:
        if skip_output_reshape:
            # NHD -> HND so the caller gets the heads axis first.
            out = out.transpose(1, 2)
        else:
            out = out.reshape(b, -1, heads * dim_head)
    return out
# Register flash_attn_func as a torch custom op. When the torch build has no
# torch.library.custom_op (AttributeError), a stand-in is defined instead that
# reports the original error as soon as it is called.
try:
    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        """Forward to ``flash_attn_func`` with the given dropout/causal settings."""
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)

    @flash_attn_wrapper.register_fake
    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
        # Output shape is the same as q
        return q.new_empty(q.shape)
except AttributeError as error:
    FLASH_ATTN_ERROR = error

    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        """Stand-in used when the custom op could not be registered; always raises."""
        # Explicit raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would make this silently return None.
        raise AssertionError(f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}")
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Run attention through ``flash_attn_wrapper``, falling back to torch SDPA.

    The flash path is only attempted when no mask is given, because the
    wrapper is called without any mask argument. Masked calls, and calls where
    the flash path raises, use
    ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        q, k, v: Query/key/value tensors. With ``skip_reshape=False`` they are
            unpacked as 3-D ``(b, _, heads * dim_head)`` and rearranged to
            ``(b, heads, -1, dim_head)``; with ``skip_reshape=True`` they are
            unpacked as 4-D ``(b, _, _, dim_head)`` and used as-is.
        heads: Number of attention heads.
        mask: Optional attention mask; 2-D and 3-D masks get leading
            singleton axes added until they are 4-D.
        attn_precision: Accepted for signature parity; not used here.
        skip_reshape: Inputs are already split per head (see above).
        skip_output_reshape: Return ``(b, heads, -1, dim_head)`` instead of
            merging heads into ``(b, -1, heads * dim_head)``.

    Returns:
        The attention output tensor.
    """
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
            (q, k, v),
        )

    if mask is not None:
        # add a batch dimension if there isn't already one
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a heads dimension if there isn't already one
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    out = None
    if mask is None:
        # Explicit branch instead of `assert mask is None` inside the try: an
        # assert is stripped under `python -O`, which would silently drop the
        # mask, and otherwise logged a bogus failure on every masked call.
        try:
            # The wrapper is fed tensors with axes 1 and 2 swapped, and its
            # result is swapped back to the layout used below.
            out = flash_attn_wrapper(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                dropout_p=0.0,
                causal=False,
            ).transpose(1, 2)
        except Exception as e:
            print(f"Flash Attention failed, using default SDPA: {e}")

    if out is None:
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

    if not skip_output_reshape:
        out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    return out
def slice_attention_single_head_spatial(q, k, v):
r1 = torch.zeros_like(k, device=q.device)
@ -427,7 +539,13 @@ def pytorch_attention_single_head_spatial(q, k, v):
return out
if memory_management.xformers_enabled():
if memory_management.sage_attention_enabled():
print("Using sage attention")
attention_function = attention_sage
elif memory_management.flash_attention_enabled():
print("Using Flash Attention")
attention_function = attention_flash
elif memory_management.xformers_enabled():
print("Using xformers cross attention")
attention_function = attention_xformers
elif memory_management.pytorch_attention_enabled():

View File

@ -955,6 +955,11 @@ def cast_to_device(tensor, device, dtype, copy=False):
else:
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def sage_attention_enabled():
    """Return the parsed value of the ``--use-sage-attention`` flag."""
    return args.use_sage_attention
def flash_attention_enabled():
    """Return the parsed value of the ``--use-flash-attention`` flag."""
    return args.use_flash_attention
def xformers_enabled():
global directml_enabled

View File

@ -159,9 +159,53 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
process.modified_noise = original_noise + self.latent.to(original_noise) / sigma_max.to(original_noise)
return cond, mask
class PreprocessorInpaintNoobAIXL(Preprocessor):  # support noob ctrlnet inpaint model https://civitai.com/models/1376234/noobai-inpainting-controlnet
    """Inpaint preprocessor that blacks out the masked region of the input.

    ``__call__`` zeroes the masked pixels of the control image;
    ``process_before_every_sampling`` multiplies the conditioning by the
    inverse of the rounded mask and returns no mask.
    """

    def __init__(self):
        super().__init__()
        self.name = 'inpaint_noobai_xl'
        self.tags = ['Inpaint']
        self.model_filename_filters = ['inpaint', 'noobai']
        self.slider_resolution = PreprocessorParameter(visible=False)
        self.fill_mask_with_one_when_resize_and_fill = True
        self.expand_mask_when_resize_and_fill = True

    def __call__(self, input_image, resolution=512, slider_1=None, slider_2=None, slider_3=None, input_mask=None, **kwargs):
        """Return a copy of ``input_image`` with the masked pixels set to zero.

        Args:
            input_image: Image as an ndarray or anything ``np.array`` accepts.
            input_mask: Mask on a 0-255 scale; values above 127.5 count as
                masked. ``None`` returns ``input_image`` untouched.
            resolution, slider_1, slider_2, slider_3, **kwargs: Unused.

        Returns:
            The masked copy, or the original object when there is no mask.
        """
        if input_mask is None:
            return input_image

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image)
        if not isinstance(input_mask, np.ndarray):
            input_mask = np.array(input_mask)

        masked = (input_mask.astype(np.float32) / 255.0) > 0.5

        # A single-channel mask is reduced to 2-D so the boolean index selects
        # whole pixels. That zeroes every channel of a masked pixel whatever
        # the image's channel count is (the 3-channel result is unchanged).
        if masked.ndim == 3 and masked.shape[-1] == 1:
            masked = masked[..., 0]

        result = input_image.copy()
        result[masked] = 0
        return result

    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        """Zero the conditioning where the rounded mask is 1.

        Returns:
            ``(masked_cond, None)``.
        """
        mask = mask.round()
        # The multiplication already allocates a new tensor, so `cond` itself
        # is left unmodified without an explicit clone.
        mixed_cond = cond * (1.0 - mask)
        return mixed_cond, None
# Register one instance of each inpaint preprocessor.
add_supported_preprocessor(PreprocessorInpaint())
add_supported_preprocessor(PreprocessorInpaintOnly())
add_supported_preprocessor(PreprocessorInpaintLama())
add_supported_preprocessor(PreprocessorInpaintNoobAIXL())

View File

@ -40,6 +40,23 @@ gradio_hf_hub_themes = [
"NoCrypt/miku"
]
# automatically add local themes
def list_theme_json_files(theme_dir):
    """Return the sorted ``.json`` file names found directly in *theme_dir*.

    A missing or unreadable directory is reported on stdout and yields an
    empty list. Sorting keeps the resulting theme order stable across runs,
    since ``os.listdir`` returns entries in arbitrary order.
    """
    try:
        if not os.path.exists(theme_dir):
            print(f"Directory {theme_dir} does not exist. No new themes will be added.")
            return []
        return sorted(f for f in os.listdir(theme_dir) if f.endswith('.json'))
    except OSError as e:
        print(f"Error accessing directory {theme_dir}: {e}. No new themes will be added.")
        return []


def theme_name_from_file(json_file):
    """Turn a theme file name into a theme name.

    Only the trailing ``.json`` is removed (an earlier ``.json`` inside the
    name is kept), then every ``_`` becomes ``/``.
    """
    # NOTE(review): mapping every "_" to "/" presumably mirrors how the files
    # were named when saved; verify against the code that writes them.
    stem = json_file[:-len('.json')] if json_file.endswith('.json') else json_file
    return stem.replace('_', '/')


theme_dir = "tmp/gradio_themes"
json_files = list_theme_json_files(theme_dir)

for json_file in json_files:
    theme_name = theme_name_from_file(json_file)
    if theme_name not in gradio_hf_hub_themes:
        gradio_hf_hub_themes.append(theme_name)
def reload_gradio_theme(theme_name=None):
if not theme_name: