Fix fp16 ControlNets: replace NaN/Inf activations in the Z-Image transformer blocks with finite float16 values

This commit is contained in:
maybleMyers 2025-12-04 00:58:49 -08:00
parent fb08f0f592
commit f07e7b3ea8
2 changed files with 13 additions and 9 deletions

View File

@ -18,6 +18,12 @@ ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
def clamp_fp16(x):
    """Make a float16 tensor finite; pass every other dtype through untouched.

    NaN becomes 0.0 and +/-inf become the largest-magnitude finite float16
    values (+/-65504). Tensors of any other dtype are returned as the very
    same object, without copying.

    Args:
        x: A ``torch.Tensor`` of any dtype.

    Returns:
        A sanitized tensor when ``x`` is float16, otherwise ``x`` itself.
    """
    if x.dtype != torch.float16:
        return x
    # Read the bound from finfo instead of repeating the 65504 literal twice.
    fp16_max = torch.finfo(torch.float16).max
    return torch.nan_to_num(x, nan=0.0, posinf=fp16_max, neginf=-fp16_max)
class ZSingleStreamAttnProcessor:
"""
Processor for Z-Image single stream attention.
@ -94,7 +100,7 @@ class FeedForward(nn.Module):
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
    """Apply the gated feed-forward projection to ``x``.

    Computes ``w2(silu(w1(x)) * w3(x))``, passing the gated product through
    ``clamp_fp16`` before the output projection.

    Args:
        x: Input accepted by the ``w1`` and ``w3`` linear layers.

    Returns:
        The output of ``w2`` applied to the sanitized gated product.
    """
    gated = F.silu(self.w1(x)) * self.w3(x)
    # A single return: the earlier unclamped return made the clamped one
    # unreachable, so float16 NaN/Inf values were never sanitized here.
    return self.w2(clamp_fp16(gated))
class ZImageTransformerBlock(nn.Module):
@ -157,9 +163,9 @@ class ZImageTransformerBlock(nn.Module):
attention_mask=attn_mask,
freqs_cis=freqs_cis,
)
x = x + gate_msa * self.attention_norm2(attn_out)
x = x + gate_msa * self.attention_norm2(clamp_fp16(attn_out))
x = x + gate_mlp * self.ffn_norm2(
self.feed_forward(self.ffn_norm1(x) * scale_mlp)
clamp_fp16(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
)
else:
attn_out = self.attention(
@ -167,8 +173,8 @@ class ZImageTransformerBlock(nn.Module):
attention_mask=attn_mask,
freqs_cis=freqs_cis,
)
x = x + self.attention_norm2(attn_out)
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
x = x + self.attention_norm2(clamp_fp16(attn_out))
x = x + self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x))))
return x

View File

@ -328,15 +328,13 @@ class ZImageControlNetPatcher(ControlModelPatcher):
def _ensure_control_layers_loaded(self, sd_model, unet):
"""Load control weights into transformer if not already loaded"""
if self._control_loaded:
return
try:
# Get the wrapped transformer
wrapped_model = unet.model.diffusion_model
transformer = wrapped_model.transformer if hasattr(wrapped_model, 'transformer') else wrapped_model
# Check if already loaded
# Always check if the CURRENT transformer has control layers loaded
# (model may have been reloaded, creating a new transformer without control layers)
if hasattr(transformer, '_control_layers_loaded') and transformer._control_layers_loaded:
self._control_loaded = True
return