stable-diffusion-webui-forge/backend/diffusion_engine/chroma_dct_sampling.py
2025-08-28 08:50:58 -07:00

400 lines
13 KiB
Python

import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from .model import Chroma
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
    latent_depth: int = 16,
    spatial_compression: int = 8,
):
    """Draw seeded standard-normal noise for a batch of latents.

    Args:
        num_samples: Number of samples in the batch (first dimension).
        height: Image height; divided by ``spatial_compression`` and rounded up.
        width: Image width; divided by ``spatial_compression`` and rounded up.
        device: Device on which both the generator and the tensor are created.
        dtype: Dtype of the returned tensor.
        seed: Seed applied to a fresh ``torch.Generator`` on every call.
        latent_depth: Size of the channel dimension.
        spatial_compression: Divisor mapping image size to latent size.

    Returns:
        Tensor of shape ``(num_samples, latent_depth,
        ceil(height / spatial_compression), ceil(width / spatial_compression))``.
    """
    # A dedicated generator keeps the draw independent of the global RNG state.
    rng = torch.Generator(device=device).manual_seed(seed)

    # Round up so sizes that are not a multiple of the compression factor
    # are still fully covered by the latent grid.
    latent_h = math.ceil(height / spatial_compression)
    latent_w = math.ceil(width / spatial_compression)
    shape = (num_samples, latent_depth, latent_h, latent_w)

    return torch.randn(shape, device=device, dtype=dtype, generator=rng)
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp time values ``t`` with the shift ``exp(mu) / (exp(mu) + (1/t - 1)**sigma)``.

    Args:
        mu: Shift strength; enters the formula as ``exp(mu)``.
        sigma: Exponent applied to ``(1 / t - 1)``.
        t: Time values to warp.

    Returns:
        The warped time values, computed element-wise from ``t``.
    """
    scale = math.exp(mu)
    return scale / (scale + (1 / t - 1) ** sigma)
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Build the straight line passing through ``(x1, y1)`` and ``(x2, y2)``.

    Returns:
        A callable mapping ``x`` to ``slope * x + intercept`` for that line.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` time points running from 1 down to 0.

    Args:
        num_steps: Number of integration steps; one extra point is produced
            so the schedule ends at zero.
        image_seq_len: Sequence length used to pick the shift strength.
        base_shift: Line value at the helper's default lower anchor (x1=256).
        max_shift: Line value at the helper's default upper anchor (x2=4096).
        shift: When False, the plain linear schedule is returned unchanged.

    Returns:
        The schedule as a list of Python floats.
    """
    # num_steps intervals need num_steps + 1 boundaries (the extra one is t = 0).
    schedule = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return schedule.tolist()

    # Shift the schedule to favor high timesteps for higher signal images.
    # mu is estimated by linear interpolation between two anchor points.
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, schedule).tolist()
def denoise(
    model: Chroma,
    # model input
    img: Tensor,
    img_ids: Tensor,
    # guidance
    txt: Tensor,
    # guidance ID
    txt_ids: Tensor,
    # mask
    txt_mask: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 0.0,
):
    """Integrate the model's prediction along ``timesteps`` with Euler steps.

    Each consecutive pair ``(t_now, t_next)`` in ``timesteps`` produces one
    model call at ``t_now`` followed by
    ``img = img + (t_next - t_now) * prediction``.

    Args:
        model: Callable invoked with keyword arguments ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``txt_mask``, ``timesteps`` and ``guidance``.
        img: Current sample; its first dimension is used as the batch size.
        img_ids: Forwarded to the model unchanged.
        txt: Forwarded to the model unchanged.
        txt_ids: Forwarded to the model unchanged.
        txt_mask: Forwarded to the model unchanged.
        timesteps: Time points shared by every sample in the batch.
        guidance: Scalar broadcast into a per-sample vector for the model.

    Returns:
        The sample after the last step.
    """
    # Per the original note, this value is ignored for schnell.
    guidance_vec = torch.full(
        (img.shape[0],), guidance, device=img.device, dtype=img.dtype
    )

    # zip stops at the shorter sequence, so this walks adjacent pairs.
    for t_now, t_next in zip(timesteps, timesteps[1:]):
        t_vec = torch.full(
            (img.shape[0],), t_now, dtype=img.dtype, device=img.device
        )
        velocity = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        img = img + (t_next - t_now) * velocity

    return img
def denoise_batched_timesteps(
    model: Chroma,
    # model input
    img: Tensor,
    img_ids: Tensor,
    # guidance
    txt: Tensor,
    # guidance ID
    txt_ids: Tensor,
    # mask
    txt_mask: Tensor,
    # sampling parameters
    timesteps: Tensor,  # Shape: (B, N), where N is the number of time points
    guidance: float = 4.0,
):
    """
    Performs ODE solving using the Euler method with potentially different
    timestep sequences for each sample in the batch.

    Args:
        model: Callable invoked with keyword arguments ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``txt_mask``, ``timesteps`` and ``guidance``.
        img: Input tensor (e.g., noise) with the batch as its first
            dimension. Any rank is accepted; the per-sample step size is
            broadcast over all remaining dimensions of the prediction.
        img_ids: Forwarded to the model unchanged.
        txt: Forwarded to the model unchanged.
        txt_ids: Forwarded to the model unchanged.
        txt_mask: Forwarded to the model unchanged.
        timesteps: Tensor containing the time points for each batch sample.
            Shape (B, N), where B is the batch size and N is the number of
            time points (e.g., [t_start, ..., t_end]). Time should generally
            decrease (e.g., [1.0, 0.8, ..., 0.0]).
        guidance: Scalar broadcast into a per-sample vector for the model.

    Returns:
        The sample after the last of the N - 1 Euler steps.

    Raises:
        ValueError: If ``timesteps`` is not 2D or its first dimension does
            not match the batch size of ``img``.
    """
    # Check the rank before touching shape[1]; otherwise a 1D tensor would
    # fail with an IndexError instead of the intended error message.
    if timesteps.ndim != 2:
        raise ValueError(
            f"timesteps tensor must be 2D (B, N), but got shape {timesteps.shape}"
        )
    batch_size = img.shape[0]
    if timesteps.shape[0] != batch_size:
        raise ValueError(
            f"Batch size mismatch: img has {batch_size}, "
            f"but timesteps has {timesteps.shape[0]}"
        )
    num_steps = timesteps.shape[1] - 1  # Number of integration steps

    # Guidance vector remains the same for all elements in this specific call
    guidance_vec = torch.full(
        (batch_size,), guidance, device=img.device, dtype=img.dtype
    )
    # Ensure timesteps tensor is on the same device and dtype as img
    timesteps = timesteps.to(device=img.device, dtype=img.dtype)

    for i in range(num_steps):
        t_curr_batch = timesteps[:, i]  # Shape: (B,)
        t_next_batch = timesteps[:, i + 1]  # Shape: (B,)

        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=t_curr_batch,  # Pass the (B,) tensor of current times
            guidance=guidance_vec,
        )

        # dt = t_next - t_curr (negative when time runs from 1 towards 0)
        dt_batch = t_next_batch - t_curr_batch  # Shape: (B,)
        # Reshape to (B, 1, ..., 1) matching the prediction's rank so the
        # step size lines up with the batch axis for 3D and 4D inputs alike.
        dt_batch_reshaped = dt_batch.view(batch_size, *([1] * (pred.ndim - 1)))

        # Euler step update: x_{t+1} = x_t + dt * v(x_t, t)
        img = img + dt_batch_reshaped * pred

    return img
def denoise_cfg(
    model: Chroma,
    # model input
    img: Tensor,
    img_ids: Tensor,
    # guidance
    txt: Tensor,
    neg_txt: Tensor,
    # guidance ID
    txt_ids: Tensor,
    neg_txt_ids: Tensor,
    # mask
    txt_mask: Tensor,
    neg_txt_mask: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    cfg: float = 2.0,
    first_n_steps_without_cfg: int = 4,
):
    """Euler sampler with classifier-free guidance after a warm-up phase.

    Every step calls the model with the positive conditioning. Once the
    warm-up is over, the model is called a second time with the negative
    conditioning and the two predictions are combined as
    ``neg + (pos - neg) * cfg`` before the Euler update.

    Args:
        model: Callable invoked with keyword arguments ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``txt_mask``, ``timesteps`` and ``guidance``.
        img: Current sample; its first dimension is used as the batch size.
        img_ids: Forwarded to the model unchanged.
        txt: Positive conditioning, forwarded to the model.
        neg_txt: Negative conditioning, forwarded on CFG steps.
        txt_ids: Forwarded with the positive conditioning.
        neg_txt_ids: Forwarded with the negative conditioning.
        txt_mask: Forwarded with the positive conditioning.
        neg_txt_mask: Forwarded with the negative conditioning.
        timesteps: Time points shared by every sample in the batch.
        guidance: Scalar broadcast into a per-sample vector for the model.
        cfg: Scale applied to ``(pos - neg)`` on CFG steps.
        first_n_steps_without_cfg: Number of leading steps that use only the
            positive prediction. ``-1`` turns CFG off for every step.

    Returns:
        The sample after the last step, in the dtype/device of the model's
        prediction (``img`` is aligned to the prediction on every step).
    """
    # Per the original note, this value is ignored for schnell.
    guidance_vec = torch.full(
        (img.shape[0],), guidance, device=img.device, dtype=img.dtype
    )
    cfg_disabled = first_n_steps_without_cfg == -1

    for step, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        t_vec = torch.full(
            (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
        )
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        # Skip the negative pass during warm-up, or entirely when disabled.
        if not (cfg_disabled or step < first_n_steps_without_cfg):
            pred_neg = model(
                img=img,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                txt_mask=neg_txt_mask,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            pred = pred_neg + (pred - pred_neg) * cfg

        # Align img with the prediction on every step, so the result's
        # dtype/device does not depend on which branch produced the update.
        img = img.to(pred) + (t_prev - t_curr) * pred

    return img
def denoise_cfg_batched_timesteps(
    model: Chroma,
    # model input
    img: Tensor,
    img_ids: Tensor,
    # guidance
    txt: Tensor,
    neg_txt: Tensor,
    # guidance ID
    txt_ids: Tensor,
    neg_txt_ids: Tensor,
    # mask
    txt_mask: Tensor,
    neg_txt_mask: Tensor,
    # sampling parameters
    timesteps: Tensor,  # Shape: (B, N), where N is the number of time points
    guidance: float = 0.0,
    cfg: float = 2.0,
    first_n_steps_without_cfg: int = 4,
):
    """
    Performs ODE solving using the Euler method with Classifier-Free Guidance (CFG)
    and potentially different timestep sequences for each sample in the batch.

    Args:
        model: Callable invoked with keyword arguments ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``txt_mask``, ``timesteps`` and ``guidance``.
        img: Input tensor (e.g., noise) with the batch as its first
            dimension. Any rank is accepted; the per-sample step size is
            broadcast over all remaining dimensions of the prediction.
        img_ids: Forwarded to the model unchanged.
        txt: Positive text conditioning; first dimension must equal B.
        neg_txt: Negative text conditioning; first dimension must equal B.
        txt_ids: Positive text IDs; first dimension must equal B.
        neg_txt_ids: Negative text IDs; first dimension must equal B.
        txt_mask: Positive text mask; first dimension must equal B.
        neg_txt_mask: Negative text mask; first dimension must equal B.
        timesteps: Tensor containing the time points for each batch sample.
            Shape (B, N), where B is the batch size and N is the number of
            time points (e.g., [t_start, ..., t_end]). Time should generally
            decrease (e.g., [1.0, 0.8, ..., 0.0]).
        guidance: Scalar broadcast into a per-sample vector for the model.
        cfg: Classifier-Free Guidance scale. The negative pass is only run
            when ``cfg > 1.0``; any value at or below 1.0 disables CFG.
        first_n_steps_without_cfg: The number of initial integration steps
            (intervals) for which CFG will *not* be applied, even if
            ``cfg > 1.0``. Both 0 and -1 apply CFG from the first step.

    Returns:
        The sample after the last of the N - 1 Euler steps.

    Raises:
        ValueError: If ``timesteps`` is not 2D, or if ``timesteps`` or any
            conditioning tensor disagrees with the batch size of ``img``.
    """
    # --- Input Validation ---
    # Check the rank before touching shape[1]; otherwise a 1D tensor would
    # fail with an IndexError instead of the intended error message.
    if timesteps.ndim != 2:
        raise ValueError(
            f"timesteps tensor must be 2D (B, N), but got shape {timesteps.shape}"
        )
    batch_size = img.shape[0]
    if timesteps.shape[0] != batch_size:
        raise ValueError(
            f"Batch size mismatch: img has {batch_size}, "
            f"but timesteps has {timesteps.shape[0]}"
        )
    # Check consistency of conditioning tensors
    for name, tensor in [
        ("txt", txt),
        ("neg_txt", neg_txt),
        ("txt_ids", txt_ids),
        ("neg_txt_ids", neg_txt_ids),
        ("txt_mask", txt_mask),
        ("neg_txt_mask", neg_txt_mask),
    ]:
        if tensor.shape[0] != batch_size:
            raise ValueError(
                f"Batch size mismatch: img has {batch_size}, "
                f"but {name} has {tensor.shape[0]}"
            )
    # --- End Validation ---
    num_steps = timesteps.shape[1] - 1  # Number of integration steps

    # Guidance vector (its effect depends on the model)
    guidance_vec = torch.full(
        (batch_size,), guidance, device=img.device, dtype=img.dtype
    )
    # Ensure timesteps tensor is on the same device and dtype as img
    timesteps = timesteps.to(device=img.device, dtype=img.dtype)

    # Iterate through the integration steps (intervals)
    for i in range(num_steps):
        t_curr_batch = timesteps[:, i]  # Shape: (B,)
        t_next_batch = timesteps[:, i + 1]  # Shape: (B,)

        # --- Positive Prediction ---
        pred_pos = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=t_curr_batch,
            guidance=guidance_vec,
        )

        # --- CFG Logic ---
        # Apply CFG if cfg > 1.0 AND (we are past the initial steps OR
        # first_n_steps_without_cfg is -1).
        apply_cfg = cfg > 1.0 and (
            i >= first_n_steps_without_cfg or first_n_steps_without_cfg == -1
        )
        if apply_cfg:
            # Negative pass sees the same sample state and times as the
            # positive pass; only the conditioning differs.
            pred_neg = model(
                img=img,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                txt_mask=neg_txt_mask,
                timesteps=t_curr_batch,
                guidance=guidance_vec,
            )
            # pred = uncond + cfg * (cond - uncond)
            pred_final = pred_neg + cfg * (pred_pos - pred_neg)
        else:
            pred_final = pred_pos
        # --- End CFG Logic ---

        dt_batch = t_next_batch - t_curr_batch  # Shape: (B,)
        # Reshape to (B, 1, ..., 1) matching the prediction's rank so the
        # step size lines up with the batch axis for 3D and 4D inputs alike.
        dt_batch_reshaped = dt_batch.view(
            batch_size, *([1] * (pred_final.ndim - 1))
        )

        # Euler step update: x_{t+1} = x_t + dt * v(x_t, t)
        # img is aligned to the prediction's dtype/device before the update.
        img = img.to(pred_final) + dt_batch_reshaped * pred_final

    return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Rearrange a packed token sequence back into a 2D channel-first grid.

    The input is read as ``(b, h * w, c * ph * pw)`` with ``ph = pw = 2``,
    ``h = ceil(height / 16)`` and ``w = ceil(width / 16)``, and is returned
    as ``(b, c, h * ph, w * pw)``.

    Args:
        x: Packed tensor to rearrange.
        height: Image height used to derive the number of patch rows.
        width: Image width used to derive the number of patch columns.

    Returns:
        The rearranged tensor.
    """
    patch = 2
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=grid_h,
        w=grid_w,
        ph=patch,
        pw=patch,
    )