mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-30 21:09:33 +08:00
502 lines
24 KiB
Python
502 lines
24 KiB
Python
import torch
|
|
|
|
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
|
from backend.patcher.clip import CLIP
|
|
from backend.patcher.vae import VAE
|
|
from backend.patcher.unet import UnetPatcher
|
|
from backend.text_processing.qwen_engine import QwenTextProcessingEngine
|
|
from backend.args import dynamic_args
|
|
from backend.modules.k_prediction import PredictionZImage
|
|
from backend import memory_management
|
|
|
|
# Import control components (lazy loaded when needed)
|
|
_zimage_control_module = None


def get_zimage_control_module():
    """Return ``backend.nn.zimage_control``, importing it on first use.

    The import happens at call time rather than at module load to avoid
    circular imports. The module object is cached in a module-level global,
    so later calls return the cached object without re-importing.
    """
    global _zimage_control_module

    # Fast path: the module was already imported by an earlier call.
    if _zimage_control_module is not None:
        return _zimage_control_module

    from backend.nn import zimage_control as loaded_module

    _zimage_control_module = loaded_module
    return loaded_module
|
|
|
|
|
|
class ZImageLatentFormat:
    """Latent format for Z-Image models (16-channel latents using FLUX VAE).

    Holds the latent->RGB preview matrix and the affine scale/shift applied
    to latents by ``process_in`` / ``process_out``.
    """

    # Computed via least squares regression from latent/RGB pairs.
    # Maps 16-channel latents to RGB for cheap preview approximation:
    # one row per latent channel, one column per RGB channel.
    latent_rgb_factors = [
        [-0.037223, -0.005345,  0.027556],
        [ 0.016390,  0.037071,  0.064793],
        [ 0.019696, -0.027534, -0.014296],
        [-0.006327,  0.014695,  0.036488],
        [ 0.068934,  0.055449,  0.033998],
        [ 0.013588,  0.014694,  0.000826],
        [ 0.074462,  0.109467,  0.114805],
        [-0.024411, -0.022535, -0.025943],
        [-0.023577,  0.013515,  0.077390],
        [ 0.062267,  0.037371, -0.032887],
        [-0.052327, -0.016201, -0.019524],
        [ 0.068548,  0.035523,  0.014399],
        [ 0.014019,  0.010056,  0.018069],
        [-0.063674, -0.015887, -0.049793],
        [-0.005128, -0.034811, -0.009309],
        [-0.064089, -0.038018, -0.024993],
    ]

    def __init__(self, *, scale_factor=0.3611, shift_factor=0.1159):
        """Create the latent format.

        Args:
            scale_factor: Multiplier applied by ``process_in`` (and divided
                out by ``process_out``). Defaults to the Z-Image VAE scale
                factor.
            shift_factor: Offset subtracted by ``process_in`` (and added back
                by ``process_out``). Defaults to the Z-Image VAE shift factor.
        """
        self.scale_factor = scale_factor
        self.shift_factor = shift_factor

    def process_in(self, latent):
        """Return ``(latent - shift_factor) * scale_factor``."""
        return (latent - self.shift_factor) * self.scale_factor

    def process_out(self, latent):
        """Invert ``process_in``: return ``latent / scale_factor + shift_factor``."""
        return (latent / self.scale_factor) + self.shift_factor
|
|
|
|
|
|
class ZImage(ForgeDiffusionEngine):
    """Forge diffusion engine for Z-Image checkpoints.

    Wires the Qwen text encoder, the VAE and the Z-Image transformer found in
    ``components_dict`` into Forge's ``CLIP`` / ``VAE`` / ``UnetPatcher``
    objects and exposes the conditioning and VAE encode/decode entry points.
    """

    def __init__(self, components_dict, estimated_config=None):
        """Build the Forge objects for a Z-Image model.

        Args:
            components_dict: Mapping that provides the ``'text_encoder'``,
                ``'tokenizer'``, ``'vae'``, ``'transformer'`` and
                ``'scheduler'`` entries read below.
            estimated_config: Model config handed to the base engine. When
                ``None``, a minimal stand-in whose ``inpaint_model()`` returns
                ``False`` is used instead.
        """
        # Create minimal config if not provided
        if estimated_config is None:
            class MinimalConfig:
                def inpaint_model(self):
                    return False

            estimated_config = MinimalConfig()

        super().__init__(estimated_config, components_dict)
        self.is_inpaint = False

        # Add latent_format to model_config for cheap preview approximation
        if not hasattr(self.model_config, 'latent_format'):
            self.model_config.latent_format = ZImageLatentFormat()

        # Wrap Qwen encoder in CLIP interface
        clip = CLIP(
            model_dict={'qwen': components_dict['text_encoder']},
            tokenizer_dict={'qwen': components_dict['tokenizer']},
        )

        vae = VAE(model=components_dict['vae'])

        # Ensure VAE latent channels match Transformer input channels.
        # This is necessary because the VAE config might be inferred
        # incorrectly (e.g. 8 channels) while the Transformer expects 16.
        transformer_config = components_dict['transformer'].config
        transformer_channels = None
        if isinstance(transformer_config, dict):
            transformer_channels = transformer_config.get('in_channels')
        elif hasattr(transformer_config, 'in_channels'):
            transformer_channels = transformer_config.in_channels

        if transformer_channels is not None and vae.latent_channels != transformer_channels:
            print(f"Correction: VAE latent_channels ({vae.latent_channels}) != Transformer in_channels ({transformer_channels}). Updating VAE to match Transformer.")
            vae.latent_channels = transformer_channels
            if hasattr(vae.first_stage_model.config, 'latent_channels'):
                vae.first_stage_model.config.latent_channels = transformer_channels

        # Set Z-Image specific VAE scaling factors.
        # Default is 0.18215 / 0.0, but Z-Image needs 0.3611 / 0.1159.
        first_stage = vae.first_stage_model
        first_stage.scaling_factor = 0.3611
        first_stage.shift_factor = 0.1159
        if hasattr(first_stage.config, 'scaling_factor'):
            first_stage.config.scaling_factor = 0.3611
        if hasattr(first_stage.config, 'shift_factor'):
            first_stage.config.shift_factor = 0.1159

        # Wrap the transformer to adapt parameter names
        wrapped_transformer = ZImageTransformerWrapper(components_dict['transformer'])

        # Z-Image uses static shift=3.0 (from scheduler config)
        # This matches the formula: sigmas = shift * t / (1 + (shift-1) * t)
        k_predictor = PredictionZImage(
            shift=3.0,
            timesteps=1000
        )

        # Create config object for Z-Image identification (used by LoRA loader)
        class ZImageModelConfig:
            is_zimage = True
            huggingface_repo = 'Z-Image'

        unet = UnetPatcher.from_model(
            model=wrapped_transformer,
            diffusers_scheduler=components_dict['scheduler'],
            k_predictor=k_predictor,
            config=ZImageModelConfig()
        )

        self.text_processing_engine_qwen = QwenTextProcessingEngine(
            text_encoder=clip.cond_stage_model.qwen,
            tokenizer=clip.tokenizer.qwen,
            emphasis_name=dynamic_args['emphasis_name'],
            min_length=1
        )

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

    def set_clip_skip(self, clip_skip):
        """Accept a clip-skip value and ignore it (no-op for this engine)."""
        pass

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        """Encode prompts with the Qwen text encoder.

        Each prompt is wrapped as a single user message through the
        tokenizer's chat template, tokenized to a fixed length of 512 and run
        through the text encoder.

        Args:
            prompt: Prompt strings to encode.

        Returns:
            A dict with ``'crossattn'`` (the second-to-last hidden state, still
            padded) and ``'attention_mask'`` (boolean mask over the tokens).
            Padding is left in place; the transformer wrapper drops masked
            positions when it receives the mask.
        """
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)

        tokenizer = self.text_processing_engine_qwen.tokenizer
        text_encoder = self.text_processing_engine_qwen.text_encoder

        # Format prompts with chat template like official implementation
        formatted_prompts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt_item}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            for prompt_item in prompt
        ]

        text_inputs = tokenizer(
            formatted_prompts,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )

        device = text_encoder.device
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_masks = text_inputs.attention_mask.to(device).bool()

        # Encode (use hidden_states[-2] like official implementation)
        outputs = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        )
        prompt_embeds = outputs.hidden_states[-2]

        # For Forge backend compatibility a batched tensor is returned
        # together with its mask; the per-sample variable-length split is
        # done later by the transformer wrapper.
        return {'crossattn': prompt_embeds, 'attention_mask': prompt_masks}

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        """Return ``(token_count, limit)`` for the UI token counter.

        ``limit`` is 512, or ``token_count`` when the prompt is longer.
        """
        tokens = self.text_processing_engine_qwen.tokenize([prompt])[0]
        token_count = len(tokens)
        return token_count, max(512, token_count)

    @torch.inference_mode()
    def encode_first_stage(self, x):
        """Encode an image batch to latents.

        ``x`` is moved from channels-at-dim-1 to channels-last and rescaled
        with ``* 0.5 + 0.5`` before VAE encoding; the result goes through the
        first-stage model's ``process_in`` and is cast back to match ``x``.
        """
        pixels = x.movedim(1, -1) * 0.5 + 0.5
        sample = self.forge_objects.vae.encode(pixels)
        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
        return sample.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        """Decode latents to an image batch.

        Applies the first-stage model's ``process_out``, decodes with the VAE,
        moves channels back to dim 1 and rescales with ``* 2.0 - 1.0``. The
        result is cast to match ``x``.
        """
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        decoded = self.forge_objects.vae.decode(sample)
        # VAE outputs [0, 1], convert to [-1, 1]
        result = decoded.movedim(-1, 1) * 2.0 - 1.0
        return result.to(x)


class ZImageTransformerWrapper(torch.nn.Module):
    """Adapter that gives a Z-Image transformer the ``(x, timestep, context)`` call shape.

    The wrapper converts batched tensors into the per-sample lists passed to
    the transformer, inverts the timestep, and negates the model output.
    Attributes not found on the wrapper are looked up on the wrapped
    transformer.
    """

    def __init__(self, transformer):
        super().__init__()
        self.transformer = transformer

    def forward(self, x, timestep, context=None, transformer_options=None, **kwargs):
        """Dispatch to the ControlNet or the plain forward path.

        The control path is taken only when ``transformer_options`` is given
        and its ``'z_image_controlnet_active'`` entry is truthy.
        """
        control_active = (
            transformer_options is not None
            and transformer_options.get('z_image_controlnet_active', False)
        )
        if control_active:
            return self._forward_with_control(x, timestep, context, transformer_options, **kwargs)
        return self._forward_normal(x, timestep, context, transformer_options, **kwargs)

    # ------------------------------------------------------------------
    # Helpers shared by both forward paths
    # ------------------------------------------------------------------

    @staticmethod
    def _latents_to_list(x, shape_hint=""):
        """Turn a ``[B, C, H, W]`` tensor into a list of ``[C, 1, H, W]`` tensors.

        A list input is returned unchanged. Any other tensor rank raises
        ``ValueError``; ``shape_hint`` is appended to the error message.
        """
        if isinstance(x, list):
            return x
        if len(x.shape) != 4:
            raise ValueError(f"Unexpected input shape: {x.shape}{shape_hint}")
        # Add the frame dimension, then split along the batch dimension.
        return list(x.unsqueeze(2).unbind(dim=0))

    @staticmethod
    def _context_to_list(context, transformer_options, strict):
        """Split a non-list context into per-sample tensors.

        * 3D ``[batch, seq_len, features]``: one tensor per sample. When
          ``transformer_options`` carries an ``'attention_mask'``, only the
          positions selected by the mask are kept, giving variable-length
          embeddings.
        * 2D ``[seq_len, features]``: wrapped in a one-element list.
        * Any other rank: ``ValueError`` when ``strict`` is true, otherwise
          the context is returned unchanged.
        """
        rank = len(context.shape)
        if rank == 2:
            return [context]
        if rank != 3:
            if strict:
                raise ValueError(f"Unexpected context shape: {context.shape}")
            return context

        attention_mask = None
        if transformer_options is not None and 'attention_mask' in transformer_options:
            attention_mask = transformer_options['attention_mask']

        if attention_mask is None:
            return [context[i] for i in range(context.shape[0])]

        # Cast to bool on the sample's device: indexing with a 0/1 integer
        # mask would gather rows 0 and 1 instead of filtering positions.
        return [
            sample[attention_mask[i].to(device=sample.device, dtype=torch.bool)]
            for i, sample in enumerate(context)
        ]

    @staticmethod
    def _invert_timestep(timestep):
        """Return ``1 - timestep`` clamped to ``[1e-6, 1.0]``.

        Forge uses sigma in [1->0] (noisy->clean) but Z-Image expects t in
        [0->1] (noisy->clean).
        """
        return torch.clamp(1.0 - timestep, min=1e-6, max=1.0)

    def _sync_pad_tokens(self, device, names):
        """Move the named pad-token tensors of the transformer to ``device``.

        Names the transformer does not have are skipped.
        """
        for name in names:
            token = getattr(self.transformer, name, None)
            if token is not None and token.device != device:
                token.data = token.data.to(device)

    def _run_transformer(self, x_list, timestep, context):
        """Call the wrapped transformer and return its negated, batched output.

        The first element is used when the transformer returns a tuple. A
        list of per-sample ``[C, F, H, W]`` tensors is stacked and the frame
        dimension squeezed out.
        """
        result = self.transformer(x=x_list, t=timestep, cap_feats=context, patch_size=2, f_patch_size=1)
        output = result[0] if isinstance(result, tuple) else result
        if isinstance(output, list):
            output = torch.stack(output, dim=0).squeeze(2)
        # NOTE(review): the sign flip presumably adapts the model's
        # prediction convention to Forge's predictor - confirm.
        return -output

    @staticmethod
    def _length_mask(seqlens, batch_size, device):
        """Build a ``[batch_size, max(seqlens)]`` bool mask, true for the first ``seqlens[i]`` positions of row ``i``."""
        mask = torch.zeros((batch_size, max(seqlens)), dtype=torch.bool, device=device)
        for row, seq_len in enumerate(seqlens):
            mask[row, :seq_len] = True
        return mask

    def _embed_and_pad(self, tokens, pos_ids, pad_masks, embedder, pad_token, batch_size, device):
        """Embed per-sample token lists and pad them into one batch.

        Returns ``(padded, freqs_cis_padded, attn_mask, seqlens)`` where
        ``seqlens`` holds the unpadded length of every sample.
        """
        from torch.nn.utils.rnn import pad_sequence

        seqlens = [len(item) for item in tokens]

        embedded = embedder(torch.cat(tokens, dim=0))
        # Positions flagged by the pad masks are overwritten with the pad token.
        embedded[torch.cat(pad_masks)] = pad_token
        freqs_cis = self.transformer.rope_embedder(torch.cat(pos_ids, dim=0))

        padded = pad_sequence(list(embedded.split(seqlens, dim=0)), batch_first=True, padding_value=0.0)
        freqs_cis_padded = pad_sequence(list(freqs_cis.split(seqlens, dim=0)), batch_first=True, padding_value=0.0)
        attn_mask = self._length_mask(seqlens, batch_size, device)
        return padded, freqs_cis_padded, attn_mask, seqlens

    # ------------------------------------------------------------------
    # Forward paths
    # ------------------------------------------------------------------

    def _forward_normal(self, x, timestep, context=None, transformer_options=None, **kwargs):
        """Standard forward path (no ControlNet).

        Raises:
            ValueError: If ``x`` is a non-list tensor that is not 4D, or if
                ``context`` is a non-list tensor that is neither 2D nor 3D.
        """
        x = self._latents_to_list(x, ". Expected 4D tensor [B, C, H, W]")

        # Convert context to list format, filtering by attention mask if available
        if context is not None and not isinstance(context, list):
            context = self._context_to_list(context, transformer_options, strict=True)

        timestep = self._invert_timestep(timestep)

        # Ensure pad tokens are on the same device as input
        self._sync_pad_tokens(x[0].device, ('x_pad_token', 'cap_pad_token'))

        return self._run_transformer(x, timestep, context)

    def _forward_with_control(self, x, timestep, context=None, transformer_options=None, **kwargs):
        """Forward path with ControlNet - matches VideoX-Fun behavior.

        Reads ``'control_context'`` and ``'control_context_scale'`` from
        ``transformer_options``. When the transformer has no control layers
        loaded, a warning is printed and the plain transformer call is used.

        Raises:
            ValueError: If ``x`` is a non-list tensor that is not 4D.
        """
        from torch.nn.utils.rnn import pad_sequence

        # Get control parameters from transformer_options
        control_context = transformer_options.get('control_context')  # [B, C, F, H, W]
        control_context_scale = transformer_options.get('control_context_scale', 1.0)

        x_list = self._latents_to_list(x)

        # Convert context to list format
        if context is not None and not isinstance(context, list):
            context = self._context_to_list(context, transformer_options, strict=False)
        elif context is None:
            # NOTE(review): a missing context is wrapped as [None]; confirm
            # the transformer accepts that.
            context = [None]

        timestep = self._invert_timestep(timestep)

        # Ensure pad tokens (including the control one, if present) are on
        # the same device as the input.
        device = x_list[0].device
        self._sync_pad_tokens(device, ('x_pad_token', 'cap_pad_token', 'control_x_pad_token'))

        # Prepare control context as list
        control_context_list = None
        if control_context is not None:
            control_context = control_context.to(device)
            if control_context.dim() == 4:  # [B, C, H, W]
                control_context = control_context.unsqueeze(2)  # add frame dimension
            if control_context.dim() == 5:  # [B, C, F, H, W]
                control_context_list = list(control_context.unbind(0))  # List of [C, F, H, W]
            else:
                control_context_list = [control_context]

        # Check if transformer has control layers loaded
        if not getattr(self.transformer, '_control_layers_loaded', False):
            print("WARNING: ControlNet active but control layers not loaded, falling back to normal forward")
            return self._run_transformer(x_list, timestep, context)

        # === Control-enabled forward path ===
        # This mirrors VideoX-Fun's ZImageControlTransformer2DModel.forward
        patch_size = 2
        f_patch_size = 1
        patch_key = f"{patch_size}-{f_patch_size}"
        bsz = len(x_list)

        # Get timestep embedding
        t_emb = self.transformer.t_embedder(timestep * self.transformer.t_scale)

        # Patchify and embed
        (
            x_patches,
            cap_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = self.transformer.patchify_and_embed(x_list, context, patch_size, f_patch_size)

        # Process x through embedder and noise refiner
        x_padded, x_freqs_cis_padded, x_attn_mask, x_item_seqlens = self._embed_and_pad(
            x_patches,
            x_pos_ids,
            x_inner_pad_mask,
            self.transformer.all_x_embedder[patch_key],
            self.transformer.x_pad_token,
            bsz,
            device,
        )
        # Match the timestep embedding to the dtype/device of the image tokens.
        adaln_input = t_emb.type_as(x_padded)
        for layer in self.transformer.noise_refiner:
            x_padded = layer(x_padded, x_attn_mask, x_freqs_cis_padded, adaln_input)

        # Process caption features through embedder and context refiner
        cap_padded, cap_freqs_cis_padded, cap_attn_mask, cap_item_seqlens = self._embed_and_pad(
            cap_feats,
            cap_pos_ids,
            cap_inner_pad_mask,
            self.transformer.cap_embedder,
            self.transformer.cap_pad_token,
            bsz,
            device,
        )
        for layer in self.transformer.context_refiner:
            cap_padded = layer(cap_padded, cap_attn_mask, cap_freqs_cis_padded)

        # Unify x and caption: per sample, the unpadded image tokens
        # followed by the unpadded caption tokens.
        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x_padded[i][:x_len], cap_padded[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis_padded[i][:x_len], cap_freqs_cis_padded[i][:cap_len]]))

        unified_item_seqlens = [cap_len + x_len for cap_len, x_len in zip(cap_item_seqlens, x_item_seqlens)]
        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = self._length_mask(unified_item_seqlens, bsz, device)

        # Generate control hints
        hints = None
        if control_context_list is not None and hasattr(self.transformer, 'forward_control'):
            # Build kwargs exactly like VideoX-Fun
            kwargs_for_control = dict(
                attn_mask=unified_attn_mask,
                freqs_cis=unified_freqs_cis,
                adaln_input=adaln_input,
            )
            # Pass cap_padded (TENSOR, not list) - matches VideoX-Fun exactly
            hints = self.transformer.forward_control(
                unified, cap_padded, control_context_list, kwargs_for_control,
                t=t_emb, patch_size=patch_size, f_patch_size=f_patch_size
            )

        # Forward through main layers with hints
        for layer in self.transformer.layers:
            unified = layer(
                unified,
                attn_mask=unified_attn_mask,
                freqs_cis=unified_freqs_cis,
                adaln_input=adaln_input,
                hints=hints,
                context_scale=control_context_scale
            )

        # Final layer
        unified = self.transformer.all_final_layer[patch_key](unified, adaln_input)
        output_list = self.transformer.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size)

        # Convert output to tensor: stack to [B, C, F, H, W], drop F.
        output = torch.stack(output_list, dim=0).squeeze(2)
        return -output

    def __getattr__(self, name):
        """Fall back to the wrapped transformer for attributes the wrapper lacks."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If 'transformer' itself is missing (e.g. the module is only
            # partially initialised), delegating would recurse forever.
            if name == 'transformer':
                raise
            return getattr(self.transformer, name)
|