mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-21 21:14:23 +08:00
merge experimental sage pr
This commit is contained in:
commit
f362ec6638
@ -31,6 +31,8 @@ attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--attention-split", action="store_true")
|
||||
attn_group.add_argument("--attention-quad", action="store_true")
|
||||
attn_group.add_argument("--attention-pytorch", action="store_true")
|
||||
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
||||
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
|
||||
|
||||
upcast = parser.add_mutually_exclusive_group()
|
||||
upcast.add_argument("--force-upcast-attention", action="store_true")
|
||||
|
||||
@ -18,6 +18,22 @@ if memory_management.xformers_enabled():
|
||||
except:
|
||||
pass
|
||||
|
||||
if memory_management.sage_attention_enabled():
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except ModuleNotFoundError:
|
||||
print(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
exit(-1)
|
||||
|
||||
if memory_management.flash_attention_enabled():
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except ModuleNotFoundError:
|
||||
print(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
import backend.operations
|
||||
ops = backend.operations.ForgeOperations
|
||||
|
||||
FORCE_UPCAST_ATTENTION_DTYPE = memory_management.force_upcast_attention_dtype()
|
||||
|
||||
@ -338,6 +354,102 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
)
|
||||
return out
|
||||
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
# sageattn doesn't work with sd1.5
|
||||
if q.shape[-1] // heads not in [64, 96, 128]:
|
||||
if memory_management.flash_attention_enabled():
|
||||
return attention_flash(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
|
||||
elif memory_management.xformers_enabled():
|
||||
return attention_xformers(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout="HND"
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
tensor_layout="NHD"
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
if skip_output_reshape:
|
||||
out = out.transpose(1, 2)
|
||||
else:
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
|
||||
|
||||
|
||||
@flash_attn_wrapper.register_fake
|
||||
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
|
||||
# Output shape is the same as q
|
||||
return q.new_empty(q.shape)
|
||||
except AttributeError as error:
|
||||
FLASH_ATTN_ERROR = error
|
||||
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
try:
|
||||
assert mask is None
|
||||
out = flash_attn_wrapper(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
).transpose(1, 2)
|
||||
except Exception as e:
|
||||
print(f"Flash Attention failed, using default SDPA: {e}")
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
def slice_attention_single_head_spatial(q, k, v):
|
||||
r1 = torch.zeros_like(k, device=q.device)
|
||||
@ -427,7 +539,13 @@ def pytorch_attention_single_head_spatial(q, k, v):
|
||||
return out
|
||||
|
||||
|
||||
if memory_management.xformers_enabled():
|
||||
if memory_management.sage_attention_enabled():
|
||||
print("Using sage attention")
|
||||
attention_function = attention_sage
|
||||
elif memory_management.flash_attention_enabled():
|
||||
print("Using Flash Attention")
|
||||
attention_function = attention_flash
|
||||
elif memory_management.xformers_enabled():
|
||||
print("Using xformers cross attention")
|
||||
attention_function = attention_xformers
|
||||
elif memory_management.pytorch_attention_enabled():
|
||||
|
||||
@ -955,6 +955,11 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
else:
|
||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
||||
|
||||
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
def flash_attention_enabled():
|
||||
return args.use_flash_attention
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
|
||||
@ -159,9 +159,53 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
|
||||
process.modified_noise = original_noise + self.latent.to(original_noise) / sigma_max.to(original_noise)
|
||||
return cond, mask
|
||||
|
||||
class PreprocessorInpaintNoobAIXL(Preprocessor): # support noob ctrlnet inpaint model https://civitai.com/models/1376234/noobai-inpainting-controlnet
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = 'inpaint_noobai_xl'
|
||||
self.tags = ['Inpaint']
|
||||
self.model_filename_filters = ['inpaint', 'noobai']
|
||||
self.slider_resolution = PreprocessorParameter(visible=False)
|
||||
self.fill_mask_with_one_when_resize_and_fill = True
|
||||
self.expand_mask_when_resize_and_fill = True
|
||||
|
||||
def __call__(self, input_image, resolution=512, slider_1=None, slider_2=None, slider_3=None, input_mask=None, **kwargs):
|
||||
if input_mask is None:
|
||||
return input_image
|
||||
|
||||
if not isinstance(input_image, np.ndarray):
|
||||
input_image = np.array(input_image)
|
||||
if not isinstance(input_mask, np.ndarray):
|
||||
input_mask = np.array(input_mask)
|
||||
|
||||
mask = input_mask.astype(np.float32) / 255.0
|
||||
mask = (mask > 0.5).astype(np.float32)
|
||||
|
||||
# Create a copy of the input image
|
||||
result = input_image.copy()
|
||||
|
||||
# Convert mask to proper shape if needed
|
||||
if mask.ndim == 2:
|
||||
mask = np.expand_dims(mask, axis=-1)
|
||||
if mask.shape[-1] == 1:
|
||||
mask = np.repeat(mask, 3, axis=-1)
|
||||
|
||||
mask_indices = mask > 0.5
|
||||
result[mask_indices] = 0.0
|
||||
|
||||
return result
|
||||
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
mask = mask.round()
|
||||
mixed_cond = cond.clone()
|
||||
mixed_cond = mixed_cond * (1.0 - mask)
|
||||
|
||||
return mixed_cond, None
|
||||
|
||||
add_supported_preprocessor(PreprocessorInpaint())
|
||||
|
||||
add_supported_preprocessor(PreprocessorInpaintOnly())
|
||||
|
||||
add_supported_preprocessor(PreprocessorInpaintLama())
|
||||
|
||||
add_supported_preprocessor(PreprocessorInpaintNoobAIXL())
|
||||
|
||||
@ -40,6 +40,23 @@ gradio_hf_hub_themes = [
|
||||
"NoCrypt/miku"
|
||||
]
|
||||
|
||||
# automaticly add local themes
|
||||
theme_dir = "tmp/gradio_themes"
|
||||
json_files = []
|
||||
|
||||
try:
|
||||
if os.path.exists(theme_dir):
|
||||
json_files = [f for f in os.listdir(theme_dir) if f.endswith('.json')]
|
||||
else:
|
||||
print(f"Directory {theme_dir} does not exist. No new themes will be added.")
|
||||
except OSError as e:
|
||||
print(f"Error accessing directory {theme_dir}: {e}. No new themes will be added.")
|
||||
|
||||
for json_file in json_files:
|
||||
theme_name = json_file.replace('.json', '')
|
||||
theme_name = theme_name.replace('_', '/')
|
||||
if theme_name not in gradio_hf_hub_themes:
|
||||
gradio_hf_hub_themes.append(theme_name)
|
||||
|
||||
def reload_gradio_theme(theme_name=None):
|
||||
if not theme_name:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user