mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-04 21:05:48 +08:00
fix fp16 controlnets
This commit is contained in:
parent
fb08f0f592
commit
f07e7b3ea8
@ -18,6 +18,12 @@ ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
|
||||
|
||||
def clamp_fp16(x):
    """Replace non-finite values in a float16 tensor with finite ones.

    Args:
        x: A torch tensor of any dtype.

    Returns:
        ``x`` itself (the same object) when its dtype is not
        ``torch.float16``. Otherwise a tensor in which NaN is replaced by
        0.0, +inf by 65504 and -inf by -65504; finite values are left as
        they are.
    """
    if x.dtype != torch.float16:
        return x
    # 65504 is the largest finite magnitude representable in float16, so
    # overflowed activations are pinned to the edge of the representable range.
    return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
|
||||
class ZSingleStreamAttnProcessor:
|
||||
"""
|
||||
Processor for Z-Image single stream attention.
|
||||
@ -94,7 +100,7 @@ class FeedForward(nn.Module):
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
return self.w2(clamp_fp16(F.silu(self.w1(x)) * self.w3(x)))
|
||||
|
||||
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
@ -157,9 +163,9 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attention_mask=attn_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
x = x + gate_msa * self.attention_norm2(clamp_fp16(attn_out))
|
||||
x = x + gate_mlp * self.ffn_norm2(
|
||||
self.feed_forward(self.ffn_norm1(x) * scale_mlp)
|
||||
clamp_fp16(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
|
||||
)
|
||||
else:
|
||||
attn_out = self.attention(
|
||||
@ -167,8 +173,8 @@ class ZImageTransformerBlock(nn.Module):
|
||||
attention_mask=attn_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
||||
x = x + self.attention_norm2(clamp_fp16(attn_out))
|
||||
x = x + self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x))))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@ -328,15 +328,13 @@ class ZImageControlNetPatcher(ControlModelPatcher):
|
||||
|
||||
def _ensure_control_layers_loaded(self, sd_model, unet):
|
||||
"""Load control weights into transformer if not already loaded"""
|
||||
if self._control_loaded:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get the wrapped transformer
|
||||
wrapped_model = unet.model.diffusion_model
|
||||
transformer = wrapped_model.transformer if hasattr(wrapped_model, 'transformer') else wrapped_model
|
||||
|
||||
# Check if already loaded
|
||||
# Always check if the CURRENT transformer has control layers loaded
|
||||
# (model may have been reloaded, creating a new transformer without control layers)
|
||||
if hasattr(transformer, '_control_layers_loaded') and transformer._control_layers_loaded:
|
||||
self._control_loaded = True
|
||||
return
|
||||
|
||||
Loading…
Reference in New Issue
Block a user