def convert_comfy_zimage_state_dict(state_dict):
    """
    Convert a ComfyUI Z-Image state dict to the Diffusers key layout.

    The conversion only runs when the ComfyUI layout is detected, i.e. the
    dict contains 'x_embedder.weight' and does not contain
    'all_x_embedder.2-1.weight'. Otherwise the very same dict object is
    returned untouched.

    Key differences handled:
    - x_embedder.{weight,bias}  -> all_x_embedder.2-1.{weight,bias}
    - final_layer.*             -> all_final_layer.2-1.*
    - attention.q_norm/k_norm   -> attention.norm_q/norm_k
    - attention.out.weight      -> attention.to_out.0.weight
    - fused attention.qkv.weight -> separate to_q / to_k / to_v weights

    Args:
        state_dict: Mapping of parameter name to tensor. It is not mutated.

    Returns:
        The input dict itself when no conversion is needed, otherwise a new
        dict with converted keys. Unrecognised keys are copied as-is.

    Raises:
        ValueError: If a fused QKV weight cannot be split into three equal
            parts along dim 0.
    """
    # Detect ComfyUI format
    if 'x_embedder.weight' not in state_dict or 'all_x_embedder.2-1.weight' in state_dict:
        return state_dict  # Already in diffusers format or unknown format

    print("[Z-Image] Detected ComfyUI format, converting to Diffusers format...")

    # Fused QKV weights in attention blocks; group 1 is the block prefix.
    qkv_pattern = re.compile(
        r'^(noise_refiner\.\d+|context_refiner\.\d+|layers\.\d+)\.attention\.qkv\.weight$'
    )

    # Exact-key renames for the patch embedder.
    embedder_renames = {
        'x_embedder.weight': 'all_x_embedder.2-1.weight',
        'x_embedder.bias': 'all_x_embedder.2-1.bias',
    }

    # Substring renames inside attention blocks, tried in this order.
    attention_renames = (
        ('.attention.q_norm.weight', '.attention.norm_q.weight'),
        ('.attention.k_norm.weight', '.attention.norm_k.weight'),
        ('.attention.out.weight', '.attention.to_out.0.weight'),
    )

    final_layer_prefix = 'final_layer.'

    new_state_dict = {}
    converted_count = 0
    qkv_split_count = 0

    for key, value in state_dict.items():
        # 1. Embedder names
        if key in embedder_renames:
            new_state_dict[embedder_renames[key]] = value
            converted_count += 1
            continue

        # 2. final_layer names: rewrite only the leading prefix, so a later
        #    occurrence of the same substring in the key is left alone.
        if key.startswith(final_layer_prefix):
            new_key = 'all_final_layer.2-1.' + key[len(final_layer_prefix):]
            new_state_dict[new_key] = value
            converted_count += 1
            continue

        # 3./4. q_norm/k_norm and out.weight renames (first match wins)
        renamed_key = None
        for old, new in attention_renames:
            if old in key:
                renamed_key = key.replace(old, new)
                break
        if renamed_key is not None:
            new_state_dict[renamed_key] = value
            converted_count += 1
            continue

        # 5. Split fused QKV weights into separate Q, K, V
        qkv_match = qkv_pattern.match(key)
        if qkv_match is not None:
            # QKV is fused as [Q, K, V] along dim 0, so the weight is expected
            # to be [3 * hidden_dim, hidden_dim] (e.g. [11520, 3840]).
            # torch.chunk does not fail on a size that is not a multiple of 3:
            # it returns uneven or fewer chunks, so validate explicitly.
            fused_rows = value.shape[0]
            if fused_rows % 3 != 0:
                raise ValueError(
                    f"[Z-Image] Cannot split fused QKV weight '{key}': "
                    f"dim 0 has size {fused_rows}, which is not divisible by 3"
                )

            q_weight, k_weight, v_weight = torch.chunk(value, 3, dim=0)

            prefix = f'{qkv_match.group(1)}.attention'
            new_state_dict[f'{prefix}.to_q.weight'] = q_weight
            new_state_dict[f'{prefix}.to_k.weight'] = k_weight
            new_state_dict[f'{prefix}.to_v.weight'] = v_weight

            qkv_split_count += 1
            continue  # Don't add the original qkv key

        # Anything else is carried over unchanged.
        new_state_dict[key] = value

    print(f"[Z-Image] Converted {converted_count} keys, split {qkv_split_count} fused QKV weights")

    return new_state_dict