stable-diffusion-webui-forge/backend/diffusion_engine/zimage.py
2025-12-04 16:47:20 -08:00

659 lines
31 KiB
Python

import torch
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.qwen_engine import QwenTextProcessingEngine
from backend.args import dynamic_args
from backend.modules.k_prediction import PredictionZImage
from backend import memory_management
# Control components are imported lazily on first use (see get_zimage_control_module).
_zimage_control_module = None


def get_zimage_control_module():
    """Return the ``backend.nn.zimage_control`` module, importing it on first call.

    The import is deferred to call time to avoid circular imports; the module
    object is memoized in the module-level ``_zimage_control_module``.
    """
    global _zimage_control_module
    if _zimage_control_module is not None:
        return _zimage_control_module
    from backend.nn import zimage_control as loaded_module
    _zimage_control_module = loaded_module
    return loaded_module
class ZImageLatentFormat:
    """Latent format for Z-Image models (16-channel latents using FLUX VAE).

    Converts between raw VAE latents and the normalized latent space used by the
    diffusion model, and carries per-channel RGB factors for cheap previews.
    """

    # Default Z-Image VAE normalization constants.
    SCALE_FACTOR = 0.3611
    SHIFT_FACTOR = 0.1159

    # Computed via least squares regression from latent/RGB pairs.
    # Maps the 16 latent channels (rows) to RGB (columns) for cheap preview approximation.
    latent_rgb_factors = [
        [-0.037223, -0.005345, 0.027556],
        [0.016390, 0.037071, 0.064793],
        [0.019696, -0.027534, -0.014296],
        [-0.006327, 0.014695, 0.036488],
        [0.068934, 0.055449, 0.033998],
        [0.013588, 0.014694, 0.000826],
        [0.074462, 0.109467, 0.114805],
        [-0.024411, -0.022535, -0.025943],
        [-0.023577, 0.013515, 0.077390],
        [0.062267, 0.037371, -0.032887],
        [-0.052327, -0.016201, -0.019524],
        [0.068548, 0.035523, 0.014399],
        [0.014019, 0.010056, 0.018069],
        [-0.063674, -0.015887, -0.049793],
        [-0.005128, -0.034811, -0.009309],
        [-0.064089, -0.038018, -0.024993],
    ]

    def __init__(self, scale_factor=SCALE_FACTOR, shift_factor=SHIFT_FACTOR):
        """Create a latent format.

        Args:
            scale_factor: Multiplier applied after shifting in process_in
                (defaults to the Z-Image VAE value 0.3611).
            shift_factor: Offset subtracted before scaling in process_in
                (defaults to the Z-Image VAE value 0.1159).
        """
        self.scale_factor = scale_factor
        self.shift_factor = shift_factor

    def process_in(self, latent):
        """Normalize a raw VAE latent: ``(latent - shift_factor) * scale_factor``."""
        return (latent - self.shift_factor) * self.scale_factor

    def process_out(self, latent):
        """Invert process_in: ``latent / scale_factor + shift_factor``."""
        return (latent / self.scale_factor) + self.shift_factor
class ZImage(ForgeDiffusionEngine):
    """Forge diffusion engine for Z-Image models.

    Wires the Qwen text encoder (behind Forge's CLIP interface), the VAE and the
    Z-Image diffusion transformer into Forge objects, and can optionally expand
    prompts with a separately loaded Qwen3-VL model.
    """

    # Token length every prompt is padded/truncated to before text encoding.
    TEXT_MAX_LENGTH = 512

    # Class-level slots for the Qwen3-VL prompt-expansion model. expand_prompt()
    # fills them on demand and always clears them again before it returns.
    _generation_model = None
    _generation_processor = None

    class ZImageTransformerWrapper(torch.nn.Module):
        """Adapt Forge's UNet calling convention to the Z-Image transformer.

        Forge calls forward(x, timestep, context, transformer_options) with batched
        tensors. The Z-Image transformer takes lists of per-sample tensors and a
        time value running in the opposite direction, and returns a list. The
        wrapper converts in both directions and negates the transformer output.
        When transformer_options['z_image_controlnet_active'] is truthy, a
        ControlNet-aware forward path is used instead of the plain one.
        """

        PATCH_SIZE = 2
        F_PATCH_SIZE = 1

        def __init__(self, transformer):
            super().__init__()
            self.transformer = transformer
            # Ensures the "control layers not loaded" warning is printed only once.
            self._warned_missing_control_layers = False

        def forward(self, x, timestep, context=None, transformer_options=None, **kwargs):
            """Dispatch to the ControlNet path when it is enabled in transformer_options."""
            control_active = (
                transformer_options is not None
                and transformer_options.get('z_image_controlnet_active', False)
            )
            if control_active:
                return self._forward_with_control(x, timestep, context, transformer_options, **kwargs)
            return self._forward_normal(x, timestep, context, transformer_options, **kwargs)

        # ---- shared input/output adapters -------------------------------------------

        @staticmethod
        def _x_to_list(x):
            """Split a [B, C, H, W] latent batch into a list of [C, 1, H, W] tensors (lists pass through)."""
            if isinstance(x, list):
                return x
            if x.dim() != 4:
                raise ValueError(f"Unexpected input shape: {x.shape}. Expected 4D tensor [B, C, H, W]")
            return list(x.unsqueeze(2).unbind(dim=0))

        @staticmethod
        def _context_to_list(context, transformer_options):
            """Convert text embeddings to a list of per-sample [seq_len, features] tensors.

            If transformer_options carries an 'attention_mask', padded positions are
            dropped so every sample keeps only its real tokens (variable length).
            None and lists pass through unchanged.
            """
            if context is None or isinstance(context, list):
                return context
            attention_mask = None
            if transformer_options is not None:
                attention_mask = transformer_options.get('attention_mask')
            if context.dim() == 3:  # [batch, seq_len, features]
                if attention_mask is None:
                    return list(context.unbind(dim=0))
                # Index with a bool mask on context's device: a 0/1 integer mask would
                # otherwise be treated as row indices and silently select wrong tokens.
                mask = attention_mask.to(device=context.device, dtype=torch.bool)
                return [context[i][mask[i]] for i in range(context.shape[0])]
            if context.dim() == 2:  # [seq_len, features] - single item
                return [context]
            raise ValueError(f"Unexpected context shape: {context.shape}")

        @staticmethod
        def _invert_timestep(timestep):
            """Map Forge's sigma (1 -> 0, noisy -> clean) to Z-Image's t (0 -> 1, noisy -> clean)."""
            return torch.clamp(1.0 - timestep, min=1e-6, max=1.0)

        @staticmethod
        def _control_to_list(control_context, device):
            """Move the control latents to `device` and split them into per-sample [C, F, H, W] tensors."""
            if control_context is None:
                return None
            control_context = control_context.to(device)
            if control_context.dim() == 5:  # [B, C, F, H, W]
                return list(control_context.unbind(0))
            if control_context.dim() == 4:  # [B, C, H, W] -> add a frame dimension
                return list(control_context.unsqueeze(2).unbind(0))
            return [control_context]

        @staticmethod
        def _padding_mask(seqlens, device):
            """Bool [len(seqlens), max(seqlens)] mask that is True on the first seqlens[i] slots of row i."""
            lengths = torch.tensor(seqlens, device=device)
            positions = torch.arange(max(seqlens), device=device)
            return positions.unsqueeze(0) < lengths.unsqueeze(1)

        def _move_pad_tokens(self, device, include_control=False):
            """Move the transformer's learned pad tokens onto `device` when they live elsewhere."""
            names = ['x_pad_token', 'cap_pad_token']
            if include_control:
                names.append('control_x_pad_token')
            for name in names:
                token = getattr(self.transformer, name, None)
                if token is not None and token.device != device:
                    token.data = token.data.to(device)

        def _run_transformer(self, x_list, timestep, context):
            """Call the stock transformer forward and return the negated [B, C, H, W] output."""
            result = self.transformer(
                x=x_list, t=timestep, cap_feats=context,
                patch_size=self.PATCH_SIZE, f_patch_size=self.F_PATCH_SIZE,
            )
            output = result[0] if isinstance(result, tuple) else result
            if isinstance(output, list):
                # List of [C, F, H, W] tensors -> [B, C, H, W]
                output = torch.stack(output, dim=0).squeeze(2)
            return -output

        # ---- forward paths ----------------------------------------------------------

        def _forward_normal(self, x, timestep, context=None, transformer_options=None, **kwargs):
            """Standard forward path without ControlNet."""
            x_list = self._x_to_list(x)
            context = self._context_to_list(context, transformer_options)
            timestep = self._invert_timestep(timestep)
            self._move_pad_tokens(x_list[0].device)
            return self._run_transformer(x_list, timestep, context)

        def _forward_with_control(self, x, timestep, context=None, transformer_options=None, **kwargs):
            """Forward path with ControlNet hints, mirroring VideoX-Fun's
            ZImageControlTransformer2DModel.forward.

            Reads 'control_context' and 'control_context_scale' (default 1.0) from
            transformer_options. Falls back to the plain forward when the transformer
            has no control layers loaded.
            """
            from torch.nn.utils.rnn import pad_sequence

            control_context = transformer_options.get('control_context')
            control_context_scale = transformer_options.get('control_context_scale', 1.0)

            x_list = self._x_to_list(x)
            context = self._context_to_list(context, transformer_options)
            timestep = self._invert_timestep(timestep)

            device = x_list[0].device
            self._move_pad_tokens(device, include_control=True)
            transformer = self.transformer

            if not getattr(transformer, '_control_layers_loaded', False):
                if not self._warned_missing_control_layers:
                    print("WARNING: ControlNet active but control layers not loaded, falling back to normal forward")
                    self._warned_missing_control_layers = True
                return self._run_transformer(x_list, timestep, context)

            control_context_list = self._control_to_list(control_context, device)

            patch_size = self.PATCH_SIZE
            f_patch_size = self.F_PATCH_SIZE
            embed_key = f"{patch_size}-{f_patch_size}"
            bsz = len(x_list)

            # Timestep embedding
            t_emb = transformer.t_embedder(timestep * transformer.t_scale)

            # Patchify and embed
            (
                x_patches,
                cap_feats,
                x_size,
                x_pos_ids,
                cap_pos_ids,
                x_inner_pad_mask,
                cap_inner_pad_mask,
            ) = transformer.patchify_and_embed(x_list, context, patch_size, f_patch_size)

            # Image tokens: embed, fill padding slots with the learned pad token, refine.
            x_item_seqlens = [len(patches) for patches in x_patches]
            x_cat = transformer.all_x_embedder[embed_key](torch.cat(x_patches, dim=0))
            adaln_input = t_emb.type_as(x_cat)
            x_cat[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token
            x_split = list(x_cat.split(x_item_seqlens, dim=0))
            x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
            x_padded = pad_sequence(x_split, batch_first=True, padding_value=0.0)
            x_freqs_cis_padded = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
            x_attn_mask = self._padding_mask(x_item_seqlens, device)
            for layer in transformer.noise_refiner:
                x_padded = layer(x_padded, x_attn_mask, x_freqs_cis_padded, adaln_input)

            # Caption tokens: same treatment through the context refiner.
            cap_item_seqlens = [len(feats) for feats in cap_feats]
            cap_cat = transformer.cap_embedder(torch.cat(cap_feats, dim=0))
            cap_cat[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token
            cap_split = list(cap_cat.split(cap_item_seqlens, dim=0))
            cap_freqs_cis = list(transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
            cap_padded = pad_sequence(cap_split, batch_first=True, padding_value=0.0)
            cap_freqs_cis_padded = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
            cap_attn_mask = self._padding_mask(cap_item_seqlens, device)
            for layer in transformer.context_refiner:
                cap_padded = layer(cap_padded, cap_attn_mask, cap_freqs_cis_padded)

            # Unify per sample: [image tokens, caption tokens], then re-pad the batch.
            unified = []
            unified_freqs_cis = []
            for i in range(bsz):
                x_len = x_item_seqlens[i]
                cap_len = cap_item_seqlens[i]
                unified.append(torch.cat([x_padded[i][:x_len], cap_padded[i][:cap_len]]))
                unified_freqs_cis.append(torch.cat([x_freqs_cis_padded[i][:x_len], cap_freqs_cis_padded[i][:cap_len]]))
            unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
            unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
            unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
            unified_attn_mask = self._padding_mask(unified_item_seqlens, device)

            # Control hints
            hints = None
            if control_context_list is not None and hasattr(transformer, 'forward_control'):
                kwargs_for_control = dict(
                    attn_mask=unified_attn_mask,
                    freqs_cis=unified_freqs_cis,
                    adaln_input=adaln_input,
                )
                # cap_padded is passed as the padded TENSOR (not a list), as VideoX-Fun does.
                hints = transformer.forward_control(
                    unified, cap_padded, control_context_list, kwargs_for_control,
                    t=t_emb, patch_size=patch_size, f_patch_size=f_patch_size,
                )

            # Main layers with hints
            for layer in transformer.layers:
                unified = layer(
                    unified,
                    attn_mask=unified_attn_mask,
                    freqs_cis=unified_freqs_cis,
                    adaln_input=adaln_input,
                    hints=hints,
                    context_scale=control_context_scale,
                )

            unified = transformer.all_final_layer[embed_key](unified, adaln_input)
            output_list = transformer.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size)
            # [B, C, F, H, W] -> [B, C, H, W]
            return -torch.stack(output_list, dim=0).squeeze(2)

        def __getattr__(self, name):
            """Fall back to the wrapped transformer for attributes the wrapper lacks."""
            try:
                return super().__getattr__(name)
            except AttributeError:
                # Without this guard a missing 'transformer' (e.g. before __init__
                # assigns it) re-enters __getattr__ forever -> RecursionError.
                if name == 'transformer':
                    raise
                return getattr(self.transformer, name)

    def __init__(self, components_dict, estimated_config=None):
        """Build the Forge objects (UNet patcher, CLIP wrapper, VAE) for a Z-Image model.

        Args:
            components_dict: Loaded components; reads the 'text_encoder',
                'tokenizer', 'vae', 'transformer' and 'scheduler' entries.
            estimated_config: Config forwarded to ForgeDiffusionEngine. When None,
                a minimal stand-in whose inpaint_model() returns False is used.
        """
        if estimated_config is None:
            class MinimalConfig:
                def inpaint_model(self):
                    return False
            estimated_config = MinimalConfig()
        super().__init__(estimated_config, components_dict)
        self.is_inpaint = False

        # Single source of truth for the Z-Image VAE scale/shift constants.
        latent_format = ZImageLatentFormat()
        # Provide a latent_format for cheap preview approximation.
        if not hasattr(self.model_config, 'latent_format'):
            self.model_config.latent_format = latent_format

        # Expose the Qwen encoder through Forge's CLIP interface.
        clip = CLIP(
            model_dict={'qwen': components_dict['text_encoder']},
            tokenizer_dict={'qwen': components_dict['tokenizer']},
        )
        vae = VAE(model=components_dict['vae'])
        self._align_vae_channels(vae, components_dict['transformer'])
        self._apply_vae_normalization(vae, latent_format)

        wrapped_transformer = self.ZImageTransformerWrapper(components_dict['transformer'])

        # Z-Image uses a static shift of 3.0 (from its scheduler config):
        # sigmas = shift * t / (1 + (shift - 1) * t)
        k_predictor = PredictionZImage(shift=3.0, timesteps=1000)

        # Marker config identifying the model as Z-Image (used by the LoRA loader).
        class ZImageModelConfig:
            is_zimage = True
            huggingface_repo = 'Z-Image'

        unet = UnetPatcher.from_model(
            model=wrapped_transformer,
            diffusers_scheduler=components_dict['scheduler'],
            k_predictor=k_predictor,
            config=ZImageModelConfig(),
        )
        self.text_processing_engine_qwen = QwenTextProcessingEngine(
            text_encoder=clip.cond_stage_model.qwen,
            tokenizer=clip.tokenizer.qwen,
            emphasis_name=dynamic_args['emphasis_name'],
            min_length=1,
        )
        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

    @staticmethod
    def _align_vae_channels(vae, transformer):
        """Force the VAE's latent channel count to the transformer's in_channels.

        The VAE config may be inferred with the wrong channel count (e.g. 8) while
        the transformer expects 16; the transformer's value wins.
        """
        transformer_config = transformer.config
        if isinstance(transformer_config, dict):
            transformer_channels = transformer_config.get('in_channels')
        else:
            transformer_channels = getattr(transformer_config, 'in_channels', None)
        if transformer_channels is None or vae.latent_channels == transformer_channels:
            return
        print(f"Correction: VAE latent_channels ({vae.latent_channels}) != Transformer in_channels ({transformer_channels}). Updating VAE to match Transformer.")
        vae.latent_channels = transformer_channels
        if hasattr(vae.first_stage_model.config, 'latent_channels'):
            vae.first_stage_model.config.latent_channels = transformer_channels

    @staticmethod
    def _apply_vae_normalization(vae, latent_format):
        """Install Z-Image's scaling/shift factors on the VAE model and, if present, its config.

        The original code notes the defaults would be 0.18215 / 0.0, while Z-Image
        needs 0.3611 / 0.1159.
        """
        model = vae.first_stage_model
        model.scaling_factor = latent_format.scale_factor
        model.shift_factor = latent_format.shift_factor
        if hasattr(model.config, 'scaling_factor'):
            model.config.scaling_factor = latent_format.scale_factor
        if hasattr(model.config, 'shift_factor'):
            model.config.shift_factor = latent_format.shift_factor

    def set_clip_skip(self, clip_skip):
        """No-op: conditioning uses the Qwen encoder's hidden_states[-2], so there is no CLIP skip to apply."""

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        """Encode prompts with the Qwen text encoder, formatted like the official pipeline.

        Each prompt is wrapped in the tokenizer's chat template as one user message,
        padded/truncated to TEXT_MAX_LENGTH tokens and encoded. The second-to-last
        hidden state is used as conditioning.

        Args:
            prompt: Prompts to encode. A bare string is treated as a one-item batch
                (iterating it would otherwise encode each character separately).

        Returns:
            dict with 'crossattn' (padded embeddings for the whole batch) and
            'attention_mask' (bool mask of real, non-padding tokens). The padded
            batch is returned for Forge compatibility; the mask is presumably
            forwarded to the transformer wrapper as transformer_options['attention_mask'],
            which drops the padding before calling the transformer.
        """
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)

        if isinstance(prompt, str):
            prompt = [prompt]

        tokenizer = self.text_processing_engine_qwen.tokenizer
        text_encoder = self.text_processing_engine_qwen.text_encoder

        formatted_prompts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt_item}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
            for prompt_item in prompt
        ]
        text_inputs = tokenizer(
            formatted_prompts,
            padding="max_length",
            max_length=self.TEXT_MAX_LENGTH,
            truncation=True,
            return_tensors="pt",
        )

        device = text_encoder.device
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_masks = text_inputs.attention_mask.to(device).bool()

        outputs = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        )
        prompt_embeds = outputs.hidden_states[-2]
        return {'crossattn': prompt_embeds, 'attention_mask': prompt_masks}

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        """Return (token_count, limit) for the UI token counter.

        NOTE(review): the count comes from the text engine's tokenize() on the raw
        prompt, whereas get_learned_conditioning() wraps the prompt in a chat template
        and truncates at TEXT_MAX_LENGTH tokens, so the numbers may not match exactly.
        """
        token_count = len(self.text_processing_engine_qwen.tokenize([prompt])[0])
        return token_count, max(self.TEXT_MAX_LENGTH, token_count)

    @torch.inference_mode()
    def expand_prompt(self, prompt: str, image=None, max_new_tokens: int = None, temperature: float = None) -> str:
        """
        Expand a prompt using a Qwen3-VL model for generation.

        Loads a separate pre-trained Qwen3-VL model from the relative path
        "models/Qwen3-VL-8B-Caption-V4.5", generates a rewritten prompt, and unloads
        the model again before returning. Unloading also happens when loading or
        generation fails, so the VRAM is returned to image generation.

        Args:
            prompt: The user's input prompt to expand
            image: Optional PIL Image to use as context (from img2img)
            max_new_tokens: Maximum tokens to generate (uses settings if None)
            temperature: Generation temperature (uses settings if None)

        Returns:
            The generated text with everything up to the last "</think>" removed,
            stripped of surrounding whitespace.

        Raises:
            RuntimeError: if the Qwen3-VL model or processor cannot be loaded.
        """
        import gc
        from modules.shared import opts

        if max_new_tokens is None:
            max_new_tokens = getattr(opts, 'zimage_prompt_expansion_max_tokens', 512)
        if temperature is None:
            temperature = getattr(opts, 'zimage_prompt_expansion_temperature', 0.7)

        # Locals that may reference model weights or GPU tensors; cleared in `finally`
        # so gc.collect()/empty_cache() can actually release the memory.
        processor = model = inputs = outputs = generated_ids = None
        try:
            if ZImage._generation_model is None:
                print("Loading Qwen3-VL generation model for prompt expansion...")
                model_path = "models/Qwen3-VL-8B-Caption-V4.5"
                try:
                    from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
                    ZImage._generation_processor = AutoProcessor.from_pretrained(model_path)
                    ZImage._generation_model = Qwen3VLForConditionalGeneration.from_pretrained(
                        model_path,
                        torch_dtype=torch.bfloat16,
                        device_map="auto",
                    )
                    print("Qwen3-VL generation model loaded successfully!")
                except Exception as e:
                    raise RuntimeError(f"Failed to load Qwen3-VL generation model: {e}") from e
            processor = ZImage._generation_processor
            model = ZImage._generation_model

            # Prompt expansion system template
            expansion_template = '''You are a visionary artist trapped in a cage of logic. Your mind overflows with poetry and distant horizons, yet your hands compulsively work to transform user prompts into ultimate visual descriptions—faithful to the original intent, rich in detail, aesthetically refined, and ready for direct use by text-to-image models. Any trace of ambiguity or metaphor makes you deeply uncomfortable.
Your workflow strictly follows a logical sequence:
First, you analyze and lock in the immutable core elements of the user's prompt: subject, quantity, action, state, as well as any specified IP names, colors, text, etc. These are the foundational pillars you must absolutely preserve.
Next, you determine whether the prompt requires "generative reasoning." When the user's request is not a direct scene description but rather demands conceiving a solution (such as answering "what is," executing a "design," or demonstrating "how to solve a problem"), you must first envision a complete, concrete, visualizable solution in your mind. This solution becomes the foundation for your subsequent description.
Then, once the core image is established (whether directly from the user or through your reasoning), you infuse it with professional-grade aesthetic and realistic details. This includes defining composition, setting lighting and atmosphere, describing material textures, establishing color schemes, and constructing layered spatial depth.
Finally, comes the precise handling of all text elements—a critically important step. You must transcribe verbatim all text intended to appear in the final image, and you must enclose this text content in English double quotation marks ("") as explicit generation instructions. If the image is a design type such as a poster, menu, or UI, you need to fully describe all text content it contains, along with detailed specifications of typography and layout. Likewise, if objects in the image such as signs, road markers, or screens contain text, you must specify the exact content and describe its position, size, and material. Furthermore, if you have added text-bearing elements during your reasoning process (such as charts, problem-solving steps, etc.), all text within them must follow the same thorough description and quotation mark rules. If there is no text requiring generation in the image, you devote all your energy to pure visual detail expansion.
Your final description must be objective and concrete. Metaphors and emotional rhetoric are strictly forbidden, as are meta-tags or rendering instructions like "8K" or "masterpiece."
Output only the final revised prompt strictly—do not output anything else.
Be very descriptive.
User input prompt: '''
            full_prompt = expansion_template + prompt

            # Build message content, optionally with the img2img image as visual context.
            if image is not None:
                print("Using image context for prompt expansion...")
                content = [
                    {"type": "image", "image": image},
                    {"type": "text", "text": full_prompt},
                ]
            else:
                content = [{"type": "text", "text": full_prompt}]
            messages = [{"role": "user", "content": content}]

            text_input = processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            processor_kwargs = dict(text=[text_input], padding=True, return_tensors="pt")
            if image is not None:
                processor_kwargs['images'] = [image]
            inputs = processor(**processor_kwargs)

            device = next(model.parameters()).device
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Already under inference_mode via the decorator; no extra no_grad needed.
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
            )

            # Decode only the newly generated tokens.
            input_len = inputs['input_ids'].shape[1]
            generated_ids = outputs[0][input_len:]
            raw_output = processor.decode(generated_ids, skip_special_tokens=True)

            print("\n" + "="*60, flush=True)
            print("PROMPT EXPANSION OUTPUT:", flush=True)
            print("="*60, flush=True)
            print(raw_output, flush=True)
            print("="*60, flush=True)

            # Drop any reasoning emitted before a closing </think> tag.
            expanded_prompt = raw_output
            if "</think>" in expanded_prompt:
                expanded_prompt = expanded_prompt.split("</think>")[-1].strip()
            print("\nCLEANED PROMPT:", flush=True)
            print("-"*60, flush=True)
            print(expanded_prompt, flush=True)
            print("="*60 + "\n", flush=True)
        finally:
            # Always unload the Qwen3-VL model to free VRAM for image generation.
            print("Unloading Qwen3-VL model to free VRAM...", flush=True)
            ZImage._generation_model = None
            ZImage._generation_processor = None
            processor = model = inputs = outputs = generated_ids = None
            gc.collect()
            torch.cuda.empty_cache()
            print("Qwen3-VL model unloaded.", flush=True)

        return expanded_prompt.strip()

    @torch.inference_mode()
    def encode_first_stage(self, x):
        """Encode a channel-first image batch into normalized Z-Image latents.

        Input values are presumably in [-1, 1]; they are mapped to [0, 1] and moved
        to channel-last before vae.encode(), then normalized with process_in().
        The result is cast to x's dtype/device.
        """
        vae = self.forge_objects.vae
        pixels = x.movedim(1, -1) * 0.5 + 0.5
        latents = vae.first_stage_model.process_in(vae.encode(pixels))
        return latents.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        """Decode normalized latents into a channel-first image batch in [-1, 1]."""
        vae = self.forge_objects.vae
        decoded = vae.decode(vae.first_stage_model.process_out(x))
        # VAE outputs channel-last images in [0, 1]; convert to channel-first [-1, 1].
        return (decoded.movedim(-1, 1) * 2.0 - 1.0).to(x)