# Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
# (synced 2026-06-04 21:05:48 +08:00; 962 lines, 45 KiB, Python)
import os
|
|
import re
|
|
import torch
|
|
import logging
|
|
import importlib
|
|
|
|
import backend.args
|
|
import huggingface_guess
|
|
|
|
|
|
def patch_zimage_for_fp16(model):
    """Patch a Z-Image transformer in place so non-finite fp16 values are clamped.

    Every ``FeedForward`` instance found in ``model.modules()`` gets a
    replacement ``_forward_silu_gating`` and every ``ZImageTransformerBlock``
    instance gets a replacement ``forward``.  Both replacements pass
    intermediate results through ``clamp_fp16`` before they are normalised.
    The clamp is a no-op for tensors whose dtype is not ``torch.float16``.

    Args:
        model: object exposing ``modules()`` (iterated to find the diffusers
            Z-Image sub-modules to patch).

    Returns:
        None. The modules are modified in place.
    """
    import torch.nn.functional as F
    from diffusers.models.transformers.transformer_z_image import FeedForward, ZImageTransformerBlock

    fp16_limit = 65504  # value substituted for +/-inf in float16 tensors

    def clamp_fp16(x):
        # Only float16 tensors are touched: NaN -> 0, +/-inf -> +/-65504.
        if x.dtype != torch.float16:
            return x
        return torch.nan_to_num(x, nan=0.0, posinf=fp16_limit, neginf=-fp16_limit)

    def patched_forward_silu_gating(self, x1, x3):
        return clamp_fp16(F.silu(x1) * x3)

    def patched_block_forward(self, x, attn_mask, freqs_cis, adaln_input=None):
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

            attn_out = self.attention(
                self.attention_norm1(x) * scale_msa,
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            # Clamp before the post-norms so an overflowed activation cannot
            # reach the norm layers as inf/NaN.
            x = x + gate_msa * self.attention_norm2(clamp_fp16(attn_out))
            x = x + gate_mlp * self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x) * scale_mlp)))
        else:
            attn_out = self.attention(
                self.attention_norm1(x),
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + self.attention_norm2(clamp_fp16(attn_out))
            x = x + self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x))))
        return x

    # One pass is enough: the two isinstance checks are independent, and
    # binding plain methods on an instance does not change what modules()
    # yields.
    for module in model.modules():
        if isinstance(module, FeedForward):
            module._forward_silu_gating = patched_forward_silu_gating.__get__(module, FeedForward)
        if isinstance(module, ZImageTransformerBlock):
            module.forward = patched_block_forward.__get__(module, ZImageTransformerBlock)
|
|
|
|
|
|
def convert_comfy_zimage_state_dict(state_dict):
    """
    Convert ComfyUI Z-Image state dict format to Diffusers format.
    Only applies if ComfyUI format is detected, otherwise returns unchanged.

    Key differences:
    - x_embedder -> all_x_embedder.2-1
    - final_layer -> all_final_layer.2-1
    - Fused qkv.weight -> separate to_q, to_k, to_v weights
    - q_norm/k_norm -> norm_q/norm_k
    - out.weight -> to_out.0.weight

    Args:
        state_dict: mapping of parameter name to tensor.

    Returns:
        The input mapping itself when it has no ``x_embedder.weight`` key or
        already has ``all_x_embedder.2-1.weight``; otherwise a new dict with
        the renamed / split entries. The input mapping is not modified.
    """
    # Detect ComfyUI format
    if 'x_embedder.weight' not in state_dict or 'all_x_embedder.2-1.weight' in state_dict:
        return state_dict  # Already in diffusers format or unknown format

    print("[Z-Image] Detected ComfyUI format, converting to Diffusers format...")

    embedder_renames = {
        'x_embedder.weight': 'all_x_embedder.2-1.weight',
        'x_embedder.bias': 'all_x_embedder.2-1.bias',
    }
    # Checked in this order, first match wins (same order as the key list in
    # the docstring).
    attention_renames = (
        ('.attention.q_norm.weight', '.attention.norm_q.weight'),
        ('.attention.k_norm.weight', '.attention.norm_k.weight'),
        ('.attention.out.weight', '.attention.to_out.0.weight'),
    )
    final_layer_prefix = 'final_layer.'

    # Pattern for fused QKV weights in attention blocks
    qkv_pattern = re.compile(r'^(noise_refiner\.\d+|context_refiner\.\d+|layers\.\d+)\.attention\.qkv\.weight$')

    new_state_dict = {}
    converted_count = 0
    qkv_split_count = 0

    for key, value in state_dict.items():
        # 1. Convert embedder names
        if key in embedder_renames:
            new_state_dict[embedder_renames[key]] = value
            converted_count += 1
            continue

        # 2. Convert final_layer names. Only the leading prefix is rewritten;
        #    a later "final_layer." inside the same key is left untouched.
        if key.startswith(final_layer_prefix):
            new_state_dict['all_final_layer.2-1.' + key[len(final_layer_prefix):]] = value
            converted_count += 1
            continue

        # 3./4. Convert q_norm/k_norm/out names inside attention blocks
        renamed = None
        for old_part, new_part in attention_renames:
            if old_part in key:
                renamed = key.replace(old_part, new_part)
                break
        if renamed is not None:
            new_state_dict[renamed] = value
            converted_count += 1
            continue

        # 5. Split fused QKV weights into separate Q, K, V
        if qkv_pattern.match(key):
            prefix = key.removesuffix('.qkv.weight')

            # QKV is fused as [Q, K, V] along dim 0
            # Shape is [3 * hidden_dim, hidden_dim] = [11520, 3840]
            q_weight, k_weight, v_weight = torch.chunk(value, 3, dim=0)

            new_state_dict[f'{prefix}.to_q.weight'] = q_weight
            new_state_dict[f'{prefix}.to_k.weight'] = k_weight
            new_state_dict[f'{prefix}.to_v.weight'] = v_weight

            qkv_split_count += 1
            continue  # Don't add the original qkv key

        # Anything else is copied through under its original name.
        new_state_dict[key] = value

    print(f"[Z-Image] Converted {converted_count} keys, split {qkv_split_count} fused QKV weights")

    return new_state_dict
|
|
|
|
from diffusers import DiffusionPipeline
|
|
from transformers import modeling_utils
|
|
|
|
from backend import memory_management
|
|
from backend.utils import read_arbitrary_config, load_torch_file, beautiful_print_gguf_state_dict_statics
|
|
from backend.state_dict import try_filter_state_dict, load_state_dict
|
|
from backend.operations import using_forge_operations
|
|
from backend.nn.vae import IntegratedAutoencoderKL
|
|
from backend.nn.clip import IntegratedCLIP
|
|
from backend.nn.unet import IntegratedUNet2DConditionModel
|
|
|
|
from backend.diffusion_engine.sd15 import StableDiffusion
|
|
from backend.diffusion_engine.sd20 import StableDiffusion2
|
|
from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
|
|
from backend.diffusion_engine.sd35 import StableDiffusion3
|
|
from backend.diffusion_engine.flux import Flux
|
|
from backend.diffusion_engine.chroma import Chroma
|
|
from backend.diffusion_engine.chroma_dct import ChromaDCT
|
|
from backend.diffusion_engine.zimage import ZImage
|
|
|
|
|
|
# Diffusion engine classes this loader knows about. The order is kept exactly
# as it was (note the SDXL refiner precedes SDXL, and ZImage precedes Flux).
# NOTE(review): presumably earlier entries win when several match - confirm
# against the code that iterates this list.
possible_models = [
    StableDiffusion,
    StableDiffusion2,
    StableDiffusionXLRefiner,
    StableDiffusionXL,
    StableDiffusion3,
    Chroma,
    ChromaDCT,
    ZImage,
    Flux,
]


# Only let errors from the "diffusers" logger through.
logging.getLogger("diffusers").setLevel(logging.ERROR)

# Directory containing this module.
dir_path = os.path.dirname(__file__)
|
|
|
|
|
|
def _zimage_dtype_override(option_name, setting_label):
    """Return the torch dtype selected by a Z-Image precision option, or None.

    Reads ``option_name`` from ``modules.shared.opts`` (default 'Automatic').
    'Automatic' and any value other than 'bfloat16' / 'float16' / 'float32'
    yield ``None``.  If the option cannot be read, a warning naming
    ``setting_label`` is printed and ``None`` is returned.
    """
    dtype_map = {
        'bfloat16': torch.bfloat16,
        'float16': torch.float16,
        'float32': torch.float32,
    }
    try:
        from modules import shared
        choice = getattr(shared.opts, option_name, 'Automatic')
        return dtype_map.get(choice)
    except Exception as e:
        print(f'Warning: Could not read {setting_label} setting: {e}')
        return None


def _load_vae(guess, config_path, state_dict):
    """Build an IntegratedAutoencoderKL from ``config_path`` and load ``state_dict`` into it."""
    assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have VAE state dict!'

    config = IntegratedAutoencoderKL.load_config(config_path)

    # Check for Z-Image specific VAE precision
    vae_dtype = memory_management.vae_dtype()
    if getattr(guess, 'is_zimage', False):
        override = _zimage_dtype_override('z_vae_dtype', 'Z-Image VAE precision')
        if override is not None:
            vae_dtype = override
            print(f'Z-Image VAE: Using user-specified dtype: {vae_dtype}')

    with using_forge_operations(device=memory_management.cpu, dtype=vae_dtype):
        model = IntegratedAutoencoderKL.from_config(config)

    if 'decoder.up_blocks.0.resnets.0.norm1.weight' in state_dict:  # diffusers format
        state_dict = huggingface_guess.diffusers_convert.convert_vae_state_dict(state_dict)
    load_state_dict(model, state_dict, ignore_start='loss.')
    return model


def _load_clip_text_encoder(config_path, cls_name, state_dict):
    """Build an IntegratedCLIP text encoder from ``config_path`` and load ``state_dict`` into it."""
    assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have CLIP state dict!'

    from transformers import CLIPTextConfig, CLIPTextModel
    config = CLIPTextConfig.from_pretrained(config_path)

    to_args = dict(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype())

    with modeling_utils.no_init_weights():
        with using_forge_operations(**to_args, manual_cast_enabled=True):
            model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True).to(**to_args)

    load_state_dict(model, state_dict, ignore_errors=[
        'transformer.text_projection.weight',
        'transformer.text_model.embeddings.position_ids',
        'logit_scale'
    ], log_name=cls_name)

    return model


def _load_t5_text_encoder(config_path, cls_name, state_dict):
    """Build an IntegratedT5 encoder, keeping the checkpoint's own dtype when it is a recognised one."""
    assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have T5 state dict!'

    from backend.nn.t5 import IntegratedT5
    config = read_arbitrary_config(config_path)

    pre_quant_types = ['nf4', 'fp4', 'gguf']
    storage_dtype = memory_management.text_encoder_dtype()
    state_dict_dtype = memory_management.state_dict_dtype(state_dict)

    if state_dict_dtype in [torch.float32, torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
        print(f'Using Detected T5 Data Type: {state_dict_dtype}')
        storage_dtype = state_dict_dtype
        if state_dict_dtype in pre_quant_types:
            print('Using pre-quant state dict!')
            if state_dict_dtype in ['gguf']:
                beautiful_print_gguf_state_dict_statics(state_dict)
    else:
        print(f'Using Default T5 Data Type: {storage_dtype}')

    if storage_dtype in pre_quant_types:
        # Pre-quantised weights: the ops are created with the default text
        # encoder dtype and the quantisation kind is passed as bnb_dtype.
        operations = using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype(), manual_cast_enabled=False, bnb_dtype=storage_dtype)
    else:
        operations = using_forge_operations(device=memory_management.cpu, dtype=storage_dtype, manual_cast_enabled=True)

    with modeling_utils.no_init_weights():
        with operations:
            model = IntegratedT5(config)

    load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])

    return model


def _load_qwen3_text_encoder(guess, config_path, cls_name, state_dict):
    """Build the transformers Qwen3 model named ``cls_name`` and load ``state_dict`` into it."""
    assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have Qwen3 text encoder state dict!'

    text_encoder_dtype = memory_management.text_encoder_dtype()

    # Check for Z-Image specific text encoder precision
    if getattr(guess, 'is_zimage', False):
        override = _zimage_dtype_override('z_text_encoder_dtype', 'Z-Image text encoder precision')
        if override is not None:
            text_encoder_dtype = override
            print(f'Z-Image Text Encoder: Using user-specified dtype: {text_encoder_dtype}')

    from transformers import Qwen3Config
    config = Qwen3Config.from_pretrained(config_path)
    cls = getattr(importlib.import_module('transformers'), cls_name)
    with modeling_utils.no_init_weights():
        model = cls(config)
    model = model.to(dtype=text_encoder_dtype)

    # Strip 'model.' prefix from state_dict keys if present. removeprefix only
    # touches a leading 'model.', so keys that merely contain that substring
    # further in are left intact.
    if any(k.startswith('model.') for k in state_dict):
        state_dict = {k.removeprefix('model.'): v for k, v in state_dict.items()}
    load_state_dict(model, state_dict, log_name=cls_name)
    return model


def _load_diffusion_model(guess, cls_name, state_dict):
    """Build the UNet / transformer named ``cls_name``, load its weights and tag it with dtype/device info."""
    assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'

    is_zimage = cls_name == 'ZImageTransformer2DModel'
    pre_quant_types = ['nf4', 'fp4', 'gguf']

    model_loader = None
    if cls_name == 'UNet2DConditionModel':
        model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
    elif cls_name == 'FluxTransformer2DModel':
        from backend.nn.flux import IntegratedFluxTransformer2DModel
        model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
    elif cls_name == 'ChromaTransformer2DModel':
        from backend.nn.chroma import IntegratedChromaTransformer2DModel
        model_loader = lambda c: IntegratedChromaTransformer2DModel(**c)
    elif cls_name == 'ChromaDCTTransformer2DModel':
        from backend.nn.model_dct import IntegratedChromaDCTTransformer2DModel
        model_loader = lambda c: IntegratedChromaDCTTransformer2DModel(**c)
    elif is_zimage:
        # Load ZImageTransformer2DModel directly from diffusers
        from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
        model_loader = lambda c: ZImageTransformer2DModel(**c)
    elif cls_name == 'SD3Transformer2DModel':
        from backend.nn.mmditx import MMDiTX
        model_loader = lambda c: MMDiTX(**c)

    unet_config = guess.unet_config.copy()
    state_dict_parameters = memory_management.state_dict_parameters(state_dict)
    state_dict_dtype = memory_management.state_dict_dtype(state_dict)

    storage_dtype = memory_management.unet_dtype(model_params=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)

    unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype')

    if unet_storage_dtype_overwrite is not None:
        storage_dtype = unet_storage_dtype_overwrite
    elif state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
        print(f'Using Detected UNet Type: {state_dict_dtype}')
        storage_dtype = state_dict_dtype
        if state_dict_dtype in pre_quant_types:
            print('Using pre-quant state dict!')
            if state_dict_dtype in ['gguf']:
                beautiful_print_gguf_state_dict_statics(state_dict)

    # Z-Image specific precision settings (highest priority). The option is
    # read once and reused for both storage and computation dtype.
    # NOTE(review): this also overrides a storage dtype detected from a
    # pre-quantised (nf4/fp4/gguf) checkpoint - confirm that is intended.
    zimage_dtype = None
    if is_zimage:
        zimage_dtype = _zimage_dtype_override('z_transformer_dtype', 'Z-Image precision')
        if zimage_dtype is not None:
            storage_dtype = zimage_dtype
            print(f'Z-Image Transformer: Using user-specified dtype: {storage_dtype}')

    load_device = memory_management.get_torch_device()
    computation_dtype = memory_management.get_computation_dtype(load_device, parameters=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)

    # For Z-Image, if user specified a precision, also use it for computation
    if zimage_dtype is not None:
        computation_dtype = zimage_dtype

    offload_device = memory_management.unet_offload_device()

    if storage_dtype in pre_quant_types:
        initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=computation_dtype)
        with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
            model = model_loader(unet_config)
    else:
        initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=storage_dtype)
        need_manual_cast = storage_dtype != computation_dtype
        to_args = dict(device=initial_device, dtype=storage_dtype)

        with using_forge_operations(**to_args, manual_cast_enabled=need_manual_cast):
            model = model_loader(unet_config).to(**to_args)

    # Convert ComfyUI Z-Image format to Diffusers format if needed
    if is_zimage:
        state_dict = convert_comfy_zimage_state_dict(state_dict)

    load_state_dict(model, state_dict)

    if is_zimage:
        patch_zimage_for_fp16(model)

    if hasattr(model, '_internal_dict'):
        model._internal_dict = unet_config
    else:
        model.config = unet_config

    model.storage_dtype = storage_dtype
    model.computation_dtype = computation_dtype
    model.load_device = load_device
    model.initial_device = initial_device
    model.offload_device = offload_device

    # Apply RamTorch for Chroma models if enabled
    if cls_name == 'ChromaTransformer2DModel' and backend.args.args.use_ramtorch_chroma:
        print("[RamTorch] Enabling RamTorch memory management for Chroma model...")
        from backend.ramtorch_integration import replace_linear_with_bouncing, configure_ramtorch_for_chroma

        # Configure RamTorch for inference-only use
        configure_ramtorch_for_chroma(
            memory_threshold=0.8,  # Use RamTorch when VRAM usage exceeds 80%
            prefetch_enabled=True,  # Enable block prefetching for better performance
            enable_zero=False  # No ZeRO optimizer needed for inference
        )

        # Replace Linear layers with CPU-bouncing versions
        model = replace_linear_with_bouncing(
            model,
            device=str(load_device),
            enable_ramtorch=True
        )

        print("[RamTorch] Chroma model configured for CPU-GPU weight bouncing")

    return model


def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
    """Instantiate one pipeline component and load its weights.

    Args:
        guess: model config object; ``unet_config`` and
            ``supported_inference_dtypes`` are read for diffusion models and
            an optional ``is_zimage`` flag for the VAE / Qwen3 encoder.
        component_name: pipeline entry name (e.g. 'scheduler', 'tokenizer',
            'vae'); also the sub-folder of ``repo_path`` holding its config.
        lib_name: library the class comes from; only 'transformers' and
            'diffusers' are handled.
        cls_name: class name listed for the component.
        repo_path: folder containing the per-component config folders.
        state_dict: weights for the component (not used for schedulers and
            tokenizers).

    Returns:
        The loaded component, or ``None`` for 'feature_extractor' /
        'safety_checker' and for anything not recognised (the latter also
        prints a 'Skipped' line).

    Raises:
        AssertionError: a weight-bearing component was requested without a
            dict state dict holding more than 16 entries.
    """
    config_path = os.path.join(repo_path, component_name)

    if component_name in ['feature_extractor', 'safety_checker']:
        return None

    if lib_name in ['transformers', 'diffusers']:
        if component_name in ['scheduler']:
            cls = getattr(importlib.import_module(lib_name), cls_name)
            return cls.from_pretrained(config_path)

        if component_name.startswith('tokenizer'):
            cls = getattr(importlib.import_module(lib_name), cls_name)
            comp = cls.from_pretrained(config_path)
            # Replace the tokenizer's long-sequence warning hook with a no-op.
            comp._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None
            return comp

        if cls_name in ['AutoencoderKL']:
            return _load_vae(guess, config_path, state_dict)

        if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
            return _load_clip_text_encoder(config_path, cls_name, state_dict)

        if cls_name == 'T5EncoderModel':
            return _load_t5_text_encoder(config_path, cls_name, state_dict)

        if cls_name == 'Qwen3Model':
            return _load_qwen3_text_encoder(guess, config_path, cls_name, state_dict)

        if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel', 'ChromaTransformer2DModel', 'ChromaDCTTransformer2DModel', 'ZImageTransformer2DModel']:
            return _load_diffusion_model(guess, cls_name, state_dict)

    print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
    return None
|
|
|
|
|
|
def _detect_checkpoint_family(sd):
    """Return a short tag ('sd1', 'sd2', 'xlrf', 'sdxl', 'flux', 'sd3' or '-') for the checkpoint in ``sd``."""
    flux_test_key = "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale"
    sd3_test_key = "model.diffusion_model.final_layer.adaLN_modulation.1.bias"
    legacy_test_key = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"

    if legacy_test_key in sd:
        # The second dimension of this weight tells the legacy families apart;
        # any other size leaves the type undetermined.
        by_width = {
            768: "sd1",
            1024: "sd2",
            1280: "xlrf",  # sdxl refiner model
            2048: "sdxl",
        }
        return by_width.get(sd[legacy_test_key].shape[1], "-")
    if flux_test_key in sd:
        return "flux"
    if sd3_test_key in sd:
        return "sd3"
    return "-"


def _clip_layers_to_resblocks(statedict, prefix_from, prefix_to, number):
    """Rename ``text_model.encoder.layers`` style CLIP keys to ``resblocks`` style, fusing q/k/v, in place."""
    keys_to_replace = {
        "{}text_model.embeddings.position_embedding.weight": "{}positional_embedding",
        "{}text_model.embeddings.token_embedding.weight": "{}token_embedding.weight",
        "{}text_model.final_layer_norm.weight": "{}ln_final.weight",
        "{}text_model.final_layer_norm.bias": "{}ln_final.bias",
        "text_projection.weight": "{}text_projection",
    }
    resblock_to_replace = {
        "layer_norm1": "ln_1",
        "layer_norm2": "ln_2",
        "mlp.fc1": "mlp.c_fc",
        "mlp.fc2": "mlp.c_proj",
        "self_attn.out_proj": "attn.out_proj",
    }

    # remove trailing 'transformer.' from new prefix
    outer_prefix = prefix_to[:-12]
    for source_template, target_template in keys_to_replace.items():
        tensor = statedict.pop(source_template.format(prefix_from))
        statedict[target_template.format(outer_prefix)] = tensor

    for resblock in range(number):
        layer_from = f"{prefix_from}text_model.encoder.layers.{resblock}"
        block_to = f"{prefix_to}resblocks.{resblock}"
        for part in ("weight", "bias"):
            for old_name, new_name in resblock_to_replace.items():
                statedict[f"{block_to}.{new_name}.{part}"] = statedict.pop(f"{layer_from}.{old_name}.{part}")

            # Separate q/k/v projections are concatenated (in that order)
            # into one in_proj tensor.
            weights_q = statedict.pop(f"{layer_from}.self_attn.q_proj.{part}")
            weights_k = statedict.pop(f"{layer_from}.self_attn.k_proj.{part}")
            weights_v = statedict.pop(f"{layer_from}.self_attn.v_proj.{part}")
            statedict[f"{block_to}.attn.in_proj_{part}"] = torch.cat((weights_q, weights_k, weights_v))
    return statedict


def _clip_resblocks_to_layers(statedict, prefix_from, prefix_to, number):
    """Rename ``resblocks`` style CLIP keys to ``text_model.encoder.layers`` style, splitting in_proj, in place."""
    keys_to_replace = {
        "positional_embedding": "{}text_model.embeddings.position_embedding.weight",
        "token_embedding.weight": "{}text_model.embeddings.token_embedding.weight",
        "ln_final.weight": "{}text_model.final_layer_norm.weight",
        "ln_final.bias": "{}text_model.final_layer_norm.bias",
        "text_projection": "text_projection.weight",
    }
    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }
    projections = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj")

    # These source keys are looked up without any prefix.
    for source_key, target_template in keys_to_replace.items():
        tensor = statedict.pop(source_key)
        statedict[target_template.format(prefix_to)] = tensor

    for resblock in range(number):
        block_from = f"{prefix_from}resblocks.{resblock}"
        layer_to = f"{prefix_to}text_model.encoder.layers.{resblock}"
        for part in ("weight", "bias"):
            for old_name, new_name in resblock_to_replace.items():
                statedict[f"{layer_to}.{new_name}.{part}"] = statedict.pop(f"{block_from}.{old_name}.{part}")

            # The fused in_proj tensor is cut into three equal slices along
            # dim 0, assigned to q, k and v in that order.
            weights = statedict.pop(f"{block_from}.attn.in_proj_{part}")
            shape_from = weights.shape[0] // 3
            for index, projection in enumerate(projections):
                statedict[f"{layer_to}.{projection}.{part}"] = weights[shape_from * index:shape_from * (index + 1)]
    return statedict


def _merge_with_prefix(sd, asd, old_prefix, new_prefix):
    """Copy every entry of ``asd`` into ``sd`` with ``old_prefix`` swapped for ``new_prefix``."""
    if old_prefix == "":
        # Nothing to replace: prepend (str.replace with an empty pattern
        # would insert the prefix between every character).
        for k, v in asd.items():
            sd[new_prefix + k] = v
    else:
        for k, v in asd.items():
            sd[k.replace(old_prefix, new_prefix)] = v


def replace_state_dict(sd, asd, guess):
    """Merge an additional state dict (VAE, CLIP or T5) into the main one.

    The kind of ``asd`` is recognised from marker keys, and its entries are
    written into ``sd`` under the key prefixes used by the detected
    checkpoint family, converting between the two CLIP key layouts where
    needed.

    Args:
        sd: main checkpoint state dict; modified in place.
        asd: additional state dict. It may be modified in place: it is
            cleared when it is in the ``enc.blk`` T5 key format, and
            rewritten when a CLIP layout conversion runs.
        guess: object whose ``vae_key_prefix[0]`` and
            ``text_encoder_key_prefix[0]`` give the target prefixes.

    Returns:
        ``sd``.

    Raises:
        KeyError: a CLIP layout conversion runs and an expected key is absent.
    """
    vae_key_prefix = guess.vae_key_prefix[0]
    text_encoder_key_prefix = guess.text_encoder_key_prefix[0]

    if 'enc.blk.0.attn_k.weight' in asd:
        # T5 in the "enc.blk" naming: rename to the encoder.block layout.
        # Every substitution is applied to every key, in this order.
        city96_t5_renames = {
            "enc.": "encoder.",
            ".blk.": ".block.",
            "token_embd": "shared",
            "output_norm": "final_layer_norm",
            "attn_q": "layer.0.SelfAttention.q",
            "attn_k": "layer.0.SelfAttention.k",
            "attn_v": "layer.0.SelfAttention.v",
            "attn_o": "layer.0.SelfAttention.o",
            "attn_norm": "layer.0.layer_norm",
            "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
            "ffn_up": "layer.1.DenseReluDense.wi_1",
            "ffn_down": "layer.1.DenseReluDense.wo",
            "ffn_gate": "layer.1.DenseReluDense.wi_0",
            "ffn_norm": "layer.1.layer_norm",
        }
        city96_t5_pre_quant_keys = ['shared.weight']

        renamed = {}
        for k, v in asd.items():
            for old, new in city96_t5_renames.items():
                k = k.replace(old, new)
            renamed[k] = v
        for k in city96_t5_pre_quant_keys:
            renamed[k] = renamed[k].dequantize_as_pytorch_parameter()
        asd.clear()
        asd = renamed

    if "decoder.conv_in.weight" in asd:
        # Un-prefixed VAE: drop the checkpoint's VAE keys and install these.
        for k in [k for k in sd if k.startswith(vae_key_prefix)]:
            del sd[k]
        for k, v in asd.items():
            sd[vae_key_prefix + k] = v

    ## identify model type
    model_type = _detect_checkpoint_family(sd)

    ## prefixes used by various model types for CLIP-L
    prefix_L = {
        "-": None,
        "sd1": "cond_stage_model.transformer.",
        "sd2": None,
        "xlrf": None,
        "sdxl": "conditioner.embedders.0.transformer.",
        "flux": "text_encoders.clip_l.transformer.",
        "sd3": "text_encoders.clip_l.transformer.",
    }
    ## prefixes used by various model types for CLIP-G
    prefix_G = {
        "-": None,
        "sd1": None,
        "sd2": None,
        "xlrf": "conditioner.embedders.0.model.transformer.",
        "sdxl": "conditioner.embedders.1.model.transformer.",
        "flux": None,
        "sd3": "text_encoders.clip_g.transformer.",
    }
    ## prefixes used by various model types for CLIP-H
    prefix_H = {
        "-": None,
        "sd1": None,
        "sd2": "conditioner.embedders.0.model.",
        "xlrf": None,
        "sdxl": None,
        "flux": None,
        "sd3": None,
    }

    ## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
    if "first_stage_model.decoder.conv_in.weight" in asd:
        channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1]
        # The VAE is only accepted when conv_in's second dimension matches
        # the value expected for the detected family.
        expected_channels = {"sd1": 4, "sd2": 4, "xlrf": 4, "sdxl": 4, "sd3": 16}.get(model_type)
        if expected_channels is not None and channels == expected_channels:
            sd.update(asd)

    ## CLIP-H
    CLIP_H = {  # key to identify source model                old_prefix
        'cond_stage_model.model.ln_final.weight': 'cond_stage_model.model.',
        # 'text_model.encoder.layers.0.layer_norm1.bias' : 'text_model'.   # would need converting
    }
    for CLIP_key, old_prefix in CLIP_H.items():
        if CLIP_key in asd and asd[CLIP_key].shape[0] == 1024:
            new_prefix = prefix_H[model_type]
            if new_prefix is not None:
                _merge_with_prefix(sd, asd, old_prefix, new_prefix)

    ## CLIP-G
    CLIP_G = {  # key to identify source model                old_prefix
        'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias': 'conditioner.embedders.1.model.transformer.',
        'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias': 'text_encoders.clip_g.transformer.',
        'text_model.encoder.layers.0.layer_norm1.bias': '',
        'transformer.resblocks.0.ln_1.bias': 'transformer.'
    }
    for CLIP_key, old_prefix in CLIP_G.items():
        if CLIP_key in asd and asd[CLIP_key].shape[0] == 1280:
            new_prefix = prefix_G[model_type]
            if new_prefix is None:
                continue
            if "resblocks" not in CLIP_key and model_type != "sd3":  # need to convert
                asd = _clip_layers_to_resblocks(asd, old_prefix, new_prefix, 32)
                sd.update(asd)
            else:
                _merge_with_prefix(sd, asd, old_prefix, new_prefix)

    ## CLIP-L
    CLIP_L = {  # key to identify source model                old_prefix
        'cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias': 'cond_stage_model.transformer.',
        'conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias': 'conditioner.embedders.0.transformer.',
        'text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias': 'text_encoders.clip_l.transformer.',
        'text_model.encoder.layers.0.layer_norm1.bias': '',
        'transformer.resblocks.0.ln_1.bias': 'transformer.'
    }
    for CLIP_key, old_prefix in CLIP_L.items():
        if CLIP_key in asd and asd[CLIP_key].shape[0] == 768:
            new_prefix = prefix_L[model_type]
            if new_prefix is None:
                continue
            if "resblocks" in CLIP_key:  # need to convert
                asd = _clip_resblocks_to_layers(asd, old_prefix, new_prefix, 12)
                sd.update(asd)
            else:
                _merge_with_prefix(sd, asd, old_prefix, new_prefix)

    if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
        # T5: drop the checkpoint's t5xxl keys and install the new ones.
        t5_prefix = f"{text_encoder_key_prefix}t5xxl."
        for k in [k for k in sd if k.startswith(t5_prefix)]:
            del sd[k]
        for k, v in asd.items():
            sd[f"{text_encoder_key_prefix}t5xxl.transformer.{k}"] = v

    return sd
|
|
|
|
|
|
def preprocess_state_dict(sd):
    """Return ``sd`` with every key under ``model.diffusion_model.`` if none is yet.

    When at least one key already starts with ``model.diffusion_model`` the
    input mapping is returned as is; otherwise a new dict is built in which
    every key is prefixed with ``model.diffusion_model.``.
    """
    for name in sd:
        if name.startswith("model.diffusion_model"):
            return sd

    prefixed = {}
    for name, tensor in sd.items():
        prefixed[f"model.diffusion_model.{name}"] = tensor
    return prefixed
|
|
|
|
|
|
def split_state_dict(sd, additional_state_dicts: list = None):
    """Load a checkpoint and split it into per-component state dicts.

    Args:
        sd: checkpoint source handed to ``load_torch_file``.
        additional_state_dicts: optional list of further sources (each handed
            to ``load_torch_file``) that are merged into the checkpoint via
            ``replace_state_dict`` before splitting. Anything that is not a
            list is ignored.

    Returns:
        ``(state_dict, guess)`` where ``state_dict`` maps each component
        target name to its filtered weights and ``guess`` is the
        ``huggingface_guess`` result, with ``clip_target``, ``model_type``
        and ``ztsnr`` set on it.
    """
    sd = preprocess_state_dict(load_torch_file(sd))
    guess = huggingface_guess.guess(sd)

    if isinstance(additional_state_dicts, list):
        for extra_source in additional_state_dicts:
            sd = replace_state_dict(sd, load_torch_file(extra_source), guess)

    # The guess callables are replaced by their results for this checkpoint.
    guess.clip_target = guess.clip_target(sd)
    guess.model_type = guess.model_type(sd)
    guess.ztsnr = 'ztsnr' in sd

    sd = guess.process_vae_state_dict(sd)

    # UNet is filtered before the VAE; the order is kept as is.
    state_dict = {}
    state_dict[guess.unet_target] = try_filter_state_dict(sd, guess.unet_key_prefix)
    state_dict[guess.vae_target] = try_filter_state_dict(sd, guess.vae_key_prefix)

    sd = guess.process_clip_state_dict(sd)

    for clip_prefix, target_name in guess.clip_target.items():
        state_dict[target_name] = try_filter_state_dict(sd, [clip_prefix + '.'])

    # Report key counts per component, plus whatever is left in sd under
    # 'ignore' (reported only, never returned).
    print_dict = {name: len(part) for name, part in state_dict.items()}
    print_dict['ignore'] = len(sd)
    print(f'StateDict Keys: {print_dict}')

    return state_dict, guess
|
|
|
|
# To be removed once PR merged on huggingface_guess
|
|
# True when the installed huggingface_guess already exposes a "Chroma" entry
# in its model_list.
chroma_is_in_huggingface_guess = hasattr(huggingface_guess.model_list, "Chroma")


if not chroma_is_in_huggingface_guess:
    class GuessChroma:
        # Local stand-in used only while huggingface_guess has no Chroma entry.
        huggingface_repo = 'Chroma'

        # Extra keys for the UNet config.
        unet_extra_config = {
            'guidance_out_dim': 3072,
            'guidance_hidden_dim': 5120,
            'guidance_n_layers': 5,
        }

        # Keys to drop from the UNet config.
        unet_remove_config = ['guidance_embed']
|
|
@torch.inference_mode()
|
|
def forge_loader(sd, additional_state_dicts=None, preset=None):
|
|
# Handle Z-Image preset with manual loading
|
|
if preset == 'z':
|
|
print("DEBUG: Loading Z-Image model from preset")
|
|
# Load components manually from safetensors files
|
|
state_dicts = {}
|
|
|
|
# Load transformer (main checkpoint)
|
|
transformer_sd = load_torch_file(sd)
|
|
state_dicts['transformer'] = transformer_sd
|
|
|
|
# Load additional modules (VAE and text encoder)
|
|
if additional_state_dicts:
|
|
for module_path in additional_state_dicts:
|
|
module_sd = load_torch_file(module_path)
|
|
# Determine if it's VAE or text encoder based on keys
|
|
if 'decoder.conv_in.weight' in module_sd or 'decoder.up_blocks.0.resnets.0.conv1.weight' in module_sd:
|
|
state_dicts['vae'] = module_sd
|
|
elif 'model.embed_tokens.weight' in module_sd or 'embed_tokens.weight' in module_sd:
|
|
state_dicts['text_encoder'] = module_sd
|
|
|
|
# Create minimal config for Z-Image
|
|
local_path = os.path.join(dir_path, 'huggingface', 'Z-Image')
|
|
if os.path.exists(local_path):
|
|
config = DiffusionPipeline.load_config(local_path)
|
|
else:
|
|
# Fallback config if Z-Image folder doesn't exist
|
|
config = {
|
|
"_class_name": "ZImagePipeline",
|
|
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
|
|
"text_encoder": ["transformers", "Qwen3Model"],
|
|
"tokenizer": ["transformers", "Qwen2Tokenizer"],
|
|
"transformer": ["diffusers", "ZImageTransformer2DModel"],
|
|
"vae": ["diffusers", "AutoencoderKL"]
|
|
}
|
|
|
|
# Create a minimal estimated_config for Z-Image
|
|
class ZImageConfig:
    """Hard-coded stand-in for an estimated model config describing Z-Image.

    Every value is fixed in ``__init__``; nothing is derived from a
    checkpoint. Instances expose the attributes set below plus
    ``inpaint_model()``.
    """

    def __init__(self):
        # Repo label and an explicit marker attribute for Z-Image.
        self.huggingface_repo = "Z-Image"
        self.is_zimage = True  # Flag to identify Z-Image models

        # Fixed transformer hyper-parameters, stored as a plain dict under
        # the ``unet_config`` attribute name.
        transformer_settings = dict(
            in_channels=16,
            dim=3840,
            n_heads=30,
            n_kv_heads=30,
            n_layers=30,
            n_refiner_layers=2,
            axes_dims=[32, 48, 48],
            axes_lens=[1536, 512, 512],
            cap_feat_dim=2560,
            all_patch_size=[2],
            all_f_patch_size=[1],
            rope_theta=256.0,
            t_scale=1000.0,
            norm_eps=1e-05,
            qk_norm=True,
        )
        self.unet_config = transformer_settings

        # bfloat16 is listed first, followed by float16 and float32.
        # NOTE(review): the previous comment here stated that Z-Image
        # requires bfloat16 because precision-sensitive layers (t_embedder,
        # cap_embedder) produce NaN with float16, yet float16 is still
        # listed. Presumably this relies on patch_zimage_for_fp16 (defined
        # at the top of this file), which clamps float16 values -- confirm.
        self.supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    def inpaint_model(self):
        """Return ``False`` unconditionally."""
        return False
|
|
|
|
estimated_config = ZImageConfig()
|
|
|
|
huggingface_components = {}
|
|
for component_name, v in config.items():
|
|
if isinstance(v, list) and len(v) == 2:
|
|
lib_name, cls_name = v
|
|
component_sd = state_dicts.get(component_name, None)
|
|
component = load_huggingface_component(estimated_config, component_name, lib_name, cls_name, local_path, component_sd)
|
|
if component is not None:
|
|
huggingface_components[component_name] = component
|
|
|
|
print("DEBUG: Loaded Z-Image components:", list(huggingface_components.keys()))
|
|
return ZImage(components_dict=huggingface_components, estimated_config=estimated_config)
|
|
|
|
try:
|
|
state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts)
|
|
except:
|
|
raise ValueError('Failed to recognize model type!')
|
|
|
|
# Debug: Print transformer keys to understand the structure
|
|
if "transformer" in state_dicts:
|
|
transformer_keys = list(state_dicts["transformer"].keys())
|
|
print(f"DEBUG: Found {len(transformer_keys)} transformer keys")
|
|
dct_keys = [k for k in transformer_keys if k.startswith("img_in_patch.") or k.startswith("nerf_blocks.")]
|
|
print(f"DEBUG: DCT keys found: {dct_keys[:5]}..." if len(dct_keys) > 5 else f"DEBUG: DCT keys found: {dct_keys}")
|
|
|
|
# Detect ChromaDCT models FIRST by checking for DCT-specific layers
|
|
if "transformer" in state_dicts and any(key.startswith("img_in_patch.") or key.startswith("nerf_blocks.") for key in state_dicts["transformer"]):
|
|
estimated_config.huggingface_repo = "ChromaDCT"
|
|
# Configure DCT-specific parameters
|
|
estimated_config.unet_config.update({
|
|
'in_channels': 3,
|
|
'context_in_dim': 4096,
|
|
'hidden_size': 3072,
|
|
'mlp_ratio': 4.0,
|
|
'num_heads': 24,
|
|
'depth': 19,
|
|
'depth_single_blocks': 38,
|
|
'axes_dim': [16, 56, 56],
|
|
'theta': 10000,
|
|
'qkv_bias': True,
|
|
'guidance_embed': True,
|
|
'approximator_in_dim': 64,
|
|
'approximator_depth': 5,
|
|
'approximator_hidden_size': 5120,
|
|
'patch_size': 16,
|
|
'nerf_hidden_size': 64,
|
|
'nerf_mlp_ratio': 4,
|
|
'nerf_depth': 4,
|
|
'nerf_max_freqs': 8,
|
|
'_use_compiled': False
|
|
})
|
|
# ChromaDCT uses same text encoder setup as regular Chroma
|
|
if 'text_encoder_2' in state_dicts:
|
|
state_dicts['text_encoder'] = state_dicts['text_encoder_2']
|
|
del state_dicts['text_encoder_2']
|
|
elif not chroma_is_in_huggingface_guess \
|
|
and estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \
|
|
and "transformer" in state_dicts \
|
|
and "distilled_guidance_layer.layers.0.in_layer.bias" in state_dicts["transformer"]:
|
|
estimated_config.huggingface_repo = GuessChroma.huggingface_repo
|
|
for x in GuessChroma.unet_extra_config:
|
|
estimated_config.unet_config[x] = GuessChroma.unet_extra_config[x]
|
|
for x in GuessChroma.unet_remove_config:
|
|
del estimated_config.unet_config[x]
|
|
state_dicts['text_encoder'] = state_dicts['text_encoder_2']
|
|
del state_dicts['text_encoder_2']
|
|
|
|
repo_name = estimated_config.huggingface_repo
|
|
|
|
# Handle ChromaDCT with direct config to avoid HuggingFace directory structure requirements
|
|
if repo_name == "ChromaDCT":
|
|
config = {
|
|
"_class_name": "FluxPipeline",
|
|
"_diffusers_version": "0.30.0.dev0",
|
|
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
|
|
"text_encoder": ["transformers", "T5EncoderModel"],
|
|
"tokenizer": ["transformers", "T5TokenizerFast"],
|
|
"transformer": ["diffusers", "ChromaDCTTransformer2DModel"]
|
|
}
|
|
local_path = os.path.join(dir_path, 'huggingface', 'Chroma') # Use Chroma configs for components
|
|
else:
|
|
local_path = os.path.join(dir_path, 'huggingface', repo_name)
|
|
config: dict = DiffusionPipeline.load_config(local_path)
|
|
huggingface_components = {}
|
|
for component_name, v in config.items():
|
|
if isinstance(v, list) and len(v) == 2:
|
|
lib_name, cls_name = v
|
|
component_sd = state_dicts.get(component_name, None)
|
|
component = load_huggingface_component(estimated_config, component_name, lib_name, cls_name, local_path, component_sd)
|
|
if component_sd is not None:
|
|
del state_dicts[component_name]
|
|
if component is not None:
|
|
huggingface_components[component_name] = component
|
|
|
|
yaml_config = None
|
|
yaml_config_prediction_type = None
|
|
|
|
try:
|
|
import yaml
|
|
from pathlib import Path
|
|
config_filename = os.path.splitext(sd)[0] + '.yaml'
|
|
if Path(config_filename).is_file():
|
|
with open(config_filename, 'r') as stream:
|
|
yaml_config = yaml.safe_load(stream)
|
|
except ImportError:
|
|
pass
|
|
|
|
# Fix Huggingface prediction type using .yaml config or estimated config detection
|
|
prediction_types = {
|
|
'EPS': 'epsilon',
|
|
'V_PREDICTION': 'v_prediction',
|
|
'EDM': 'edm',
|
|
}
|
|
|
|
has_prediction_type = 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config
|
|
|
|
if yaml_config is not None:
|
|
yaml_config_prediction_type: str = (
|
|
yaml_config.get('model', {}).get('params', {}).get('parameterization', '')
|
|
or yaml_config.get('model', {}).get('params', {}).get('denoiser_config', {}).get('params', {}).get('scaling_config', {}).get('target', '')
|
|
)
|
|
if yaml_config_prediction_type == 'v' or yaml_config_prediction_type.endswith(".VScaling"):
|
|
yaml_config_prediction_type = 'v_prediction'
|
|
else:
|
|
# Use estimated prediction config if no suitable prediction type found
|
|
yaml_config_prediction_type = ''
|
|
|
|
if has_prediction_type:
|
|
if yaml_config_prediction_type:
|
|
huggingface_components['scheduler'].config.prediction_type = yaml_config_prediction_type
|
|
else:
|
|
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
|
|
|
|
# Debug: Print final repo detection
|
|
print(f"DEBUG: Final repo name: {estimated_config.huggingface_repo}")
|
|
|
|
if not chroma_is_in_huggingface_guess and estimated_config.huggingface_repo == "Chroma":
|
|
print("DEBUG: Loading regular Chroma engine")
|
|
return Chroma(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
|
if estimated_config.huggingface_repo == "ChromaDCT":
|
|
print("DEBUG: Loading ChromaDCT engine")
|
|
return ChromaDCT(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
|
|
|
# Load Z-Image when 'z' preset is selected
|
|
if preset == 'z':
|
|
print("DEBUG: Loading Z-Image engine (preset selected)")
|
|
return ZImage(components_dict=huggingface_components)
|
|
|
|
for M in possible_models:
|
|
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
|
|
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
|
|
|
print('Failed to recognize model type!')
|
|
return None |