def patch_zimage_for_fp16(model):
    """Monkey-patch a Z-Image transformer so fp16 overflow does not propagate.

    Every ``FeedForward`` submodule of ``model`` gets a replacement
    ``_forward_silu_gating`` and every ``ZImageTransformerBlock`` submodule
    gets a replacement ``forward``. Both replacements sanitise float16
    intermediates: NaN becomes 0 and +/-inf becomes the largest finite
    float16 value. Tensors of any other dtype pass through untouched.

    The patch is applied per instance (bound methods are stored on the
    module objects), so the diffusers classes themselves are not modified.

    Args:
        model: Module whose ``modules()`` iterator is walked to find the
            submodules to patch.

    NOTE(review): ``patched_block_forward`` hard-codes the block signature
    ``(x, attn_mask, freqs_cis, adaln_input=None)``; confirm it still
    matches ``ZImageTransformerBlock.forward`` in the installed diffusers
    version.
    """
    # `import torch.nn.functional as F` binds only `F`; import `torch`
    # explicitly so the helpers below do not depend on a module-level import.
    import torch
    import torch.nn.functional as F
    from diffusers.models.transformers.transformer_z_image import FeedForward, ZImageTransformerBlock

    # Largest finite float16 magnitude (65504).
    fp16_max = torch.finfo(torch.float16).max

    def clamp_fp16(x):
        # Only float16 tensors are sanitised; other dtypes are returned as-is.
        if x.dtype == torch.float16:
            return torch.nan_to_num(x, nan=0.0, posinf=fp16_max, neginf=-fp16_max)
        return x

    def patched_forward_silu_gating(self, x1, x3):
        # The SiLU-gated product is sanitised before it is used further.
        return clamp_fp16(F.silu(x1) * x3)

    def patched_block_forward(self, x, attn_mask, freqs_cis, adaln_input=None):
        if self.modulation:
            assert adaln_input is not None
            # One projection is split into four equal chunks along dim 2:
            # scale/gate for the attention branch and for the MLP branch.
            scale_msa, gate_msa, scale_mlp, gate_mlp = (
                self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
            )
            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

            attn_out = self.attention(
                self.attention_norm1(x) * scale_msa,
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            # Branch outputs are sanitised before the post-norm and the
            # residual addition.
            x = x + gate_msa * self.attention_norm2(clamp_fp16(attn_out))
            x = x + gate_mlp * self.ffn_norm2(
                clamp_fp16(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
            )
        else:
            attn_out = self.attention(
                self.attention_norm1(x),
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + self.attention_norm2(clamp_fp16(attn_out))
            x = x + self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x))))
        return x

    # Single pass over the module tree; each patch is bound to the instance.
    for module in model.modules():
        if isinstance(module, FeedForward):
            module._forward_silu_gating = patched_forward_silu_gating.__get__(module, FeedForward)
        if isinstance(module, ZImageTransformerBlock):
            module.forward = patched_block_forward.__get__(module, ZImageTransformerBlock)
def fp16_fix(x):
    """Replace non-finite values in a float16 tensor with finite ones.

    For ``torch.float16`` input, NaN becomes 0 and +/-inf becomes the
    largest finite float16 magnitude (65504). Input of any other dtype is
    returned unchanged (the same object, not a copy).

    Args:
        x: Tensor to sanitise.

    Returns:
        A sanitised float16 tensor, or ``x`` itself when it is not float16.

    NOTE(review): mapping inf to the float16 maximum leaves no headroom, so
    a later addition to a saturated value can overflow again - confirm that
    this is acceptable for the callers.
    """
    if x.dtype != torch.float16:
        return x

    # Derive the bound from the dtype instead of repeating the literal 65504.
    fp16_max = torch.finfo(torch.float16).max
    return torch.nan_to_num(x, nan=0.0, posinf=fp16_max, neginf=-fp16_max)