From f07e7b3ea8510f93f6ce6a80802acf5fc43cc276 Mon Sep 17 00:00:00 2001 From: maybleMyers Date: Thu, 4 Dec 2025 00:58:49 -0800 Subject: [PATCH] fix fp16 controlnets --- backend/nn/zimage_control.py | 16 +++++++++++----- modules_forge/supported_controlnet.py | 6 ++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/backend/nn/zimage_control.py b/backend/nn/zimage_control.py index c9628898..0f0667c0 100644 --- a/backend/nn/zimage_control.py +++ b/backend/nn/zimage_control.py @@ -18,6 +18,12 @@ ADALN_EMBED_DIM = 256 SEQ_MULTI_OF = 32 +def clamp_fp16(x): + if x.dtype == torch.float16: + return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x + + class ZSingleStreamAttnProcessor: """ Processor for Z-Image single stream attention. @@ -94,7 +100,7 @@ class FeedForward(nn.Module): self.w3 = nn.Linear(dim, hidden_dim, bias=False) def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + return self.w2(clamp_fp16(F.silu(self.w1(x)) * self.w3(x))) class ZImageTransformerBlock(nn.Module): @@ -157,9 +163,9 @@ class ZImageTransformerBlock(nn.Module): attention_mask=attn_mask, freqs_cis=freqs_cis, ) - x = x + gate_msa * self.attention_norm2(attn_out) + x = x + gate_msa * self.attention_norm2(clamp_fp16(attn_out)) x = x + gate_mlp * self.ffn_norm2( - self.feed_forward(self.ffn_norm1(x) * scale_mlp) + clamp_fp16(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) ) else: attn_out = self.attention( @@ -167,8 +173,8 @@ class ZImageTransformerBlock(nn.Module): attention_mask=attn_mask, freqs_cis=freqs_cis, ) - x = x + self.attention_norm2(attn_out) - x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + x = x + self.attention_norm2(clamp_fp16(attn_out)) + x = x + self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x)))) return x diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py index 8d275f78..d913c7b4 100644 --- a/modules_forge/supported_controlnet.py +++ b/modules_forge/supported_controlnet.py @@ -328,15 +328,13 @@ class ZImageControlNetPatcher(ControlModelPatcher): def _ensure_control_layers_loaded(self, sd_model, unet): """Load control weights into transformer if not already loaded""" - if self._control_loaded: - return - try: # Get the wrapped transformer wrapped_model = unet.model.diffusion_model transformer = wrapped_model.transformer if hasattr(wrapped_model, 'transformer') else wrapped_model - # Check if already loaded + # Always check if the CURRENT transformer has control layers loaded + # (model may have been reloaded, creating a new transformer without control layers) if hasattr(transformer, '_control_layers_loaded') and transformer._control_layers_loaded: self._control_loaded = True return