stable-diffusion-webui-forge/backend/diffusion_engine/zimage.py
2025-12-02 15:35:18 -08:00

502 lines
24 KiB
Python

import torch
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.qwen_engine import QwenTextProcessingEngine
from backend.args import dynamic_args
from backend.modules.k_prediction import PredictionZImage
from backend import memory_management
# ControlNet support lives in backend.nn.zimage_control; it is imported lazily
# (on first request) rather than at module load time.
_zimage_control_module = None


def get_zimage_control_module():
    """Return the ``backend.nn.zimage_control`` module, importing it on first call.

    The import is deferred to call time to avoid a circular import when this
    module is loaded. The module object is cached in ``_zimage_control_module``
    so subsequent calls return it directly.
    """
    global _zimage_control_module
    if _zimage_control_module is not None:
        return _zimage_control_module
    from backend.nn import zimage_control as control_module
    _zimage_control_module = control_module
    return control_module
class ZImageLatentFormat:
    """Latent format for Z-Image models (16-channel latents using the FLUX VAE).

    Holds the affine normalisation applied to VAE latents and a linear
    latent-to-RGB projection used for cheap previews.
    """

    # Computed via least-squares regression from latent/RGB pairs.
    # One row per latent channel (16 rows), each mapping that channel to
    # (R, G, B) contributions for a cheap preview approximation.
    latent_rgb_factors = [
        [-0.037223, -0.005345, 0.027556],
        [ 0.016390, 0.037071, 0.064793],
        [ 0.019696, -0.027534, -0.014296],
        [-0.006327, 0.014695, 0.036488],
        [ 0.068934, 0.055449, 0.033998],
        [ 0.013588, 0.014694, 0.000826],
        [ 0.074462, 0.109467, 0.114805],
        [-0.024411, -0.022535, -0.025943],
        [-0.023577, 0.013515, 0.077390],
        [ 0.062267, 0.037371, -0.032887],
        [-0.052327, -0.016201, -0.019524],
        [ 0.068548, 0.035523, 0.014399],
        [ 0.014019, 0.010056, 0.018069],
        [-0.063674, -0.015887, -0.049793],
        [-0.005128, -0.034811, -0.009309],
        [-0.064089, -0.038018, -0.024993],
    ]

    def __init__(self, scale_factor=0.3611, shift_factor=0.1159):
        """Create the format with the given VAE normalisation constants.

        Args:
            scale_factor: Multiplier applied after shifting in ``process_in``.
                Defaults to the Z-Image VAE scale factor (0.3611).
            shift_factor: Offset subtracted before scaling in ``process_in``.
                Defaults to the Z-Image VAE shift factor (0.1159).
        """
        self.scale_factor = scale_factor
        self.shift_factor = shift_factor

    def process_in(self, latent):
        """Normalise a raw VAE latent: ``(latent - shift_factor) * scale_factor``."""
        return (latent - self.shift_factor) * self.scale_factor

    def process_out(self, latent):
        """Invert ``process_in``: ``latent / scale_factor + shift_factor``."""
        return (latent / self.scale_factor) + self.shift_factor
class ZImageTransformerWrapper(torch.nn.Module):
    """Adapt Forge's UNet calling convention to the Z-Image transformer.

    Forge calls ``forward(x, timestep, context=..., transformer_options=...)``
    with a batched ``[B, C, H, W]`` latent and a batched (or 2D) context tensor.
    The wrapped transformer instead takes per-sample lists (``x=``, ``t=``,
    ``cap_feats=``). This wrapper converts between the two layouts, inverts the
    timestep convention, negates the output, and, when
    ``transformer_options['z_image_controlnet_active']`` is truthy, runs a
    ControlNet-aware forward path.
    """

    # Learned pad-token parameters that must live on the same device as the input.
    _PAD_TOKEN_NAMES = ('x_pad_token', 'cap_pad_token')

    def __init__(self, transformer):
        super().__init__()
        self.transformer = transformer

    def forward(self, x, timestep, context=None, transformer_options=None, **kwargs):
        """Dispatch to the ControlNet path when it is enabled, else the standard path."""
        control_active = (
            transformer_options is not None and
            transformer_options.get('z_image_controlnet_active', False)
        )
        if control_active:
            return self._forward_with_control(x, timestep, context, transformer_options, **kwargs)
        return self._forward_normal(x, timestep, context, transformer_options, **kwargs)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_sample_list(x):
        """Split a ``[B, C, H, W]`` latent into ``B`` tensors of shape ``[C, 1, H, W]``.

        Lists are returned unchanged; any other tensor rank raises ``ValueError``.
        """
        if isinstance(x, list):
            return x
        if x.dim() != 4:
            raise ValueError(f"Unexpected input shape: {x.shape}. Expected 4D tensor [B, C, H, W]")
        # Add a singleton frame dimension, then split along the batch.
        return list(x.unsqueeze(2).unbind(dim=0))

    @staticmethod
    def _to_context_list(context, transformer_options):
        """Convert ``context`` into a list of per-sample ``[seq_len, features]`` tensors.

        ``None`` and lists are returned unchanged. A 2D tensor is treated as a
        single sample. For a 3D ``[batch, seq_len, features]`` tensor, if
        ``transformer_options['attention_mask']`` is set, each row is filtered by
        its mask so padding tokens are dropped (variable-length captions).
        Any other rank raises ``ValueError``.
        """
        if context is None or isinstance(context, list):
            return context
        if context.dim() == 2:
            return [context]
        if context.dim() != 3:
            raise ValueError(f"Unexpected context shape: {context.shape}")

        attention_mask = None
        if transformer_options is not None:
            attention_mask = transformer_options.get('attention_mask')
        if attention_mask is None:
            return [context[i] for i in range(context.shape[0])]

        if isinstance(attention_mask, torch.Tensor):
            # Boolean indexing needs a bool mask on a compatible device; the text
            # encoder and the transformer may not share a device.
            attention_mask = attention_mask.to(device=context.device, dtype=torch.bool)
        return [context[i][attention_mask[i]] for i in range(context.shape[0])]

    @staticmethod
    def _to_zimage_timestep(timestep):
        """Map Forge's sigma (1 -> 0, noisy -> clean) to Z-Image's t (0 -> 1, noisy -> clean)."""
        return torch.clamp(1.0 - timestep, min=1e-6, max=1.0)

    def _move_pad_tokens(self, device, names):
        """Move each named pad-token parameter of the transformer to ``device`` in place, if present."""
        for name in names:
            token = getattr(self.transformer, name, None)
            if token is not None and token.device != device:
                token.data = token.data.to(device)

    def _prepare_inputs(self, x, timestep, context, transformer_options, extra_pad_tokens=()):
        """Convert Forge-style inputs to the transformer's list/timestep convention.

        Returns:
            ``(x_list, context_list, timestep)``. As a side effect, the pad tokens
            named in ``_PAD_TOKEN_NAMES`` plus ``extra_pad_tokens`` are moved to
            the device of the first latent sample.
        """
        x_list = self._to_sample_list(x)
        context = self._to_context_list(context, transformer_options)
        timestep = self._to_zimage_timestep(timestep)
        self._move_pad_tokens(x_list[0].device, self._PAD_TOKEN_NAMES + tuple(extra_pad_tokens))
        return x_list, context, timestep

    @staticmethod
    def _to_batched_output(result):
        """Turn the transformer's result back into a negated ``[B, C, H, W]`` tensor.

        A tuple result is unwrapped to its first element. A list of
        ``[C, F, H, W]`` tensors is stacked and the frame dimension squeezed.
        """
        output = result[0] if isinstance(result, tuple) else result
        if isinstance(output, list):
            output = torch.stack(output, dim=0).squeeze(2)
        # NOTE(review): sign flip presumably reconciles the transformer's
        # prediction direction with Forge's flow convention - confirm.
        return -output

    @staticmethod
    def _to_control_list(control_context, device):
        """Move ``control_context`` to ``device`` and split it into per-sample ``[C, F, H, W]`` tensors.

        A 4D ``[B, C, H, W]`` tensor gets a singleton frame dimension first.
        Other ranks are wrapped in a one-element list. ``None`` is passed through.
        """
        if control_context is None:
            return None
        control_context = control_context.to(device)
        if control_context.dim() == 4:
            control_context = control_context.unsqueeze(2)
        if control_context.dim() == 5:
            return list(control_context.unbind(0))
        return [control_context]

    @staticmethod
    def _length_mask(lengths, batch_size, device):
        """Return a ``[batch_size, max(lengths)]`` bool mask, True on each row's first ``lengths[i]`` slots."""
        mask = torch.zeros((batch_size, max(lengths)), dtype=torch.bool, device=device)
        for row, length in enumerate(lengths):
            mask[row, :length] = True
        return mask

    # ------------------------------------------------------------- forward paths

    def _forward_normal(self, x, timestep, context=None, transformer_options=None, **kwargs):
        """Run the wrapped transformer without ControlNet."""
        x_list, context, timestep = self._prepare_inputs(x, timestep, context, transformer_options)
        result = self.transformer(x=x_list, t=timestep, cap_feats=context, patch_size=2, f_patch_size=1)
        return self._to_batched_output(result)

    def _forward_with_control(self, x, timestep, context=None, transformer_options=None, **kwargs):
        """ControlNet-aware forward pass, mirroring VideoX-Fun's ZImageControlTransformer2DModel.forward.

        Reads ``control_context`` and ``control_context_scale`` (default 1.0)
        from ``transformer_options``. If the transformer's control layers are
        not loaded, it warns and falls back to a plain transformer call.
        """
        from torch.nn.utils.rnn import pad_sequence

        control_context = transformer_options.get('control_context')
        control_context_scale = transformer_options.get('control_context_scale', 1.0)

        x_list, context, timestep = self._prepare_inputs(
            x, timestep, context, transformer_options,
            extra_pad_tokens=('control_x_pad_token',),
        )
        device = x_list[0].device
        control_context_list = self._to_control_list(control_context, device)

        tr = self.transformer
        if not getattr(tr, '_control_layers_loaded', False):
            print("WARNING: ControlNet active but control layers not loaded, falling back to normal forward")
            result = tr(x=x_list, t=timestep, cap_feats=context, patch_size=2, f_patch_size=1)
            return self._to_batched_output(result)

        patch_size = 2
        f_patch_size = 1
        embed_key = f"{patch_size}-{f_patch_size}"
        bsz = len(x_list)

        t_emb = tr.t_embedder(timestep * tr.t_scale)

        (
            x_patches,
            cap_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = tr.patchify_and_embed(x_list, context, patch_size, f_patch_size)

        # Image tokens: embed, overwrite inner-pad slots, add RoPE freqs, pad to a batch, refine.
        x_item_seqlens = [len(patches) for patches in x_patches]
        x_cat = tr.all_x_embedder[embed_key](torch.cat(x_patches, dim=0))
        adaln_input = t_emb.type_as(x_cat)
        x_cat[torch.cat(x_inner_pad_mask)] = tr.x_pad_token
        x_padded = pad_sequence(list(x_cat.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
        x_freqs_cis = tr.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)
        x_freqs_cis_padded = pad_sequence(list(x_freqs_cis), batch_first=True, padding_value=0.0)
        x_attn_mask = self._length_mask(x_item_seqlens, bsz, device)
        for layer in tr.noise_refiner:
            x_padded = layer(x_padded, x_attn_mask, x_freqs_cis_padded, adaln_input)

        # Caption tokens: same treatment, refined by the context refiner.
        cap_item_seqlens = [len(feats) for feats in cap_feats]
        cap_cat = tr.cap_embedder(torch.cat(cap_feats, dim=0))
        cap_cat[torch.cat(cap_inner_pad_mask)] = tr.cap_pad_token
        cap_padded = pad_sequence(list(cap_cat.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
        cap_freqs_cis = tr.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)
        cap_freqs_cis_padded = pad_sequence(list(cap_freqs_cis), batch_first=True, padding_value=0.0)
        cap_attn_mask = self._length_mask(cap_item_seqlens, bsz, device)
        for layer in tr.context_refiner:
            cap_padded = layer(cap_padded, cap_attn_mask, cap_freqs_cis_padded)

        # Unified sequence per sample: [image tokens, caption tokens], then re-padded.
        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x_padded[i][:x_len], cap_padded[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis_padded[i][:x_len], cap_freqs_cis_padded[i][:cap_len]]))
        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = self._length_mask(unified_item_seqlens, bsz, device)

        # Control hints (only when control input exists and the transformer supports it).
        hints = None
        if control_context_list is not None and hasattr(tr, 'forward_control'):
            kwargs_for_control = dict(
                attn_mask=unified_attn_mask,
                freqs_cis=unified_freqs_cis,
                adaln_input=adaln_input,
            )
            # cap_padded is passed as a padded tensor (not a list), matching VideoX-Fun.
            hints = tr.forward_control(
                unified, cap_padded, control_context_list, kwargs_for_control,
                t=t_emb, patch_size=patch_size, f_patch_size=f_patch_size,
            )

        for layer in tr.layers:
            unified = layer(
                unified,
                attn_mask=unified_attn_mask,
                freqs_cis=unified_freqs_cis,
                adaln_input=adaln_input,
                hints=hints,
                context_scale=control_context_scale,
            )

        unified = tr.all_final_layer[embed_key](unified, adaln_input)
        output_list = tr.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size)
        output = torch.stack(output_list, dim=0).squeeze(2)  # [B, C, F, H, W] -> [B, C, H, W]
        return -output

    def __getattr__(self, name):
        """Fall back to the wrapped transformer for attributes the wrapper does not have."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == 'transformer':
                # The submodule itself is missing (e.g. an instance created via
                # __new__ without __init__); delegating would recurse forever.
                raise
            return getattr(self.transformer, name)


class ZImage(ForgeDiffusionEngine):
    """Forge diffusion engine for Z-Image (Qwen text encoder + VAE + Z-Image transformer)."""

    # Qwen prompts are padded/truncated to this many tokens.
    _QWEN_MAX_LENGTH = 512

    def __init__(self, components_dict, estimated_config=None):
        """Build the CLIP, VAE and UNet patchers from loaded Z-Image components.

        Args:
            components_dict: Loaded components, read here under the keys
                'text_encoder', 'tokenizer', 'vae', 'transformer' and 'scheduler'.
            estimated_config: Optional model config. When omitted, a minimal
                stand-in whose ``inpaint_model()`` returns False is used.
        """
        if estimated_config is None:
            class MinimalConfig:
                def inpaint_model(self):
                    return False
            estimated_config = MinimalConfig()

        super().__init__(estimated_config, components_dict)
        self.is_inpaint = False

        zimage_format = ZImageLatentFormat()
        # Expose latent_format on model_config for cheap preview approximation.
        if not hasattr(self.model_config, 'latent_format'):
            self.model_config.latent_format = zimage_format

        # Wrap the Qwen encoder in the CLIP interface.
        clip = CLIP(
            model_dict={'qwen': components_dict['text_encoder']},
            tokenizer_dict={'qwen': components_dict['tokenizer']},
        )
        vae = VAE(model=components_dict['vae'])
        self._configure_vae(vae, components_dict['transformer'], zimage_format)

        wrapped_transformer = ZImageTransformerWrapper(components_dict['transformer'])

        # Z-Image uses a static shift=3.0 (from the scheduler config), i.e.
        # sigmas = shift * t / (1 + (shift - 1) * t).
        k_predictor = PredictionZImage(shift=3.0, timesteps=1000)

        # Config object used to identify Z-Image models (used by the LoRA loader).
        class ZImageModelConfig:
            is_zimage = True
            huggingface_repo = 'Z-Image'

        unet = UnetPatcher.from_model(
            model=wrapped_transformer,
            diffusers_scheduler=components_dict['scheduler'],
            k_predictor=k_predictor,
            config=ZImageModelConfig(),
        )

        self.text_processing_engine_qwen = QwenTextProcessingEngine(
            text_encoder=clip.cond_stage_model.qwen,
            tokenizer=clip.tokenizer.qwen,
            emphasis_name=dynamic_args['emphasis_name'],
            min_length=1,
        )

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

    @staticmethod
    def _configure_vae(vae, transformer, latent_format):
        """Align the VAE with the Z-Image transformer.

        Two adjustments are made:
        - Latent channels are forced to the transformer's ``in_channels``. The
          inferred VAE config may be wrong (e.g. 8 channels) while the
          transformer expects 16.
        - ``scaling_factor``/``shift_factor`` are set from ``latent_format``,
          on the VAE model and on its config when those attributes exist.
        """
        transformer_config = transformer.config
        if isinstance(transformer_config, dict):
            transformer_channels = transformer_config.get('in_channels')
        else:
            transformer_channels = getattr(transformer_config, 'in_channels', None)

        if transformer_channels is not None and vae.latent_channels != transformer_channels:
            print(f"Correction: VAE latent_channels ({vae.latent_channels}) != Transformer in_channels ({transformer_channels}). Updating VAE to match Transformer.")
            vae.latent_channels = transformer_channels
            if hasattr(vae.first_stage_model.config, 'latent_channels'):
                vae.first_stage_model.config.latent_channels = transformer_channels

        # Default is 0.18215 / 0.0, but Z-Image needs its own scale / shift.
        vae_model = vae.first_stage_model
        vae_model.scaling_factor = latent_format.scale_factor
        vae_model.shift_factor = latent_format.shift_factor
        if hasattr(vae_model.config, 'scaling_factor'):
            vae_model.config.scaling_factor = latent_format.scale_factor
        if hasattr(vae_model.config, 'shift_factor'):
            vae_model.config.shift_factor = latent_format.shift_factor

    def set_clip_skip(self, clip_skip):
        """Ignore clip skip; this engine does not support it."""
        pass

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        """Encode prompts with the Qwen text encoder.

        Each prompt is wrapped in the tokenizer's chat template (as in the
        official implementation), padded/truncated to ``_QWEN_MAX_LENGTH``
        tokens, and encoded. The second-to-last hidden layer is the conditioning.

        Returns:
            dict with 'crossattn' (padded hidden states, one row per prompt) and
            'attention_mask' (bool mask of real tokens). The mask is kept so the
            padding can be filtered later; the transformer wrapper does this when
            the mask arrives as ``transformer_options['attention_mask']``.
        """
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)

        tokenizer = self.text_processing_engine_qwen.tokenizer
        text_encoder = self.text_processing_engine_qwen.text_encoder

        formatted_prompts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt_item}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            for prompt_item in prompt
        ]

        text_inputs = tokenizer(
            formatted_prompts,
            padding="max_length",
            max_length=self._QWEN_MAX_LENGTH,
            truncation=True,
            return_tensors="pt",
        )

        device = text_encoder.device
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_masks = text_inputs.attention_mask.to(device).bool()

        outputs = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        )
        # hidden_states[-2] (penultimate layer), as in the official implementation.
        prompt_embeds = outputs.hidden_states[-2]

        # Forge expects a batched tensor, so return the padded embeddings with
        # the mask instead of per-prompt variable-length tensors.
        return {'crossattn': prompt_embeds, 'attention_mask': prompt_masks}

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        """Return ``(token_count, max_length)`` for ``prompt``; max_length is at least ``_QWEN_MAX_LENGTH``."""
        token_count = len(self.text_processing_engine_qwen.tokenize([prompt])[0])
        return token_count, max(self._QWEN_MAX_LENGTH, token_count)

    @torch.inference_mode()
    def encode_first_stage(self, x):
        """Encode channels-first images to normalised latents.

        ``x`` is presumably in [-1, 1]. It is mapped to [0, 1], channels-last,
        for the VAE wrapper.
        """
        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
        return sample.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        """Decode normalised latents to channels-first images in [-1, 1]."""
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        decoded = self.forge_objects.vae.decode(sample)
        # VAE outputs [0, 1] channels-last; convert to [-1, 1] channels-first.
        result = decoded.movedim(-1, 1) * 2.0 - 1.0
        return result.to(x)