mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-19 21:01:02 +08:00
400 lines
13 KiB
Python
400 lines
13 KiB
Python
import math
|
|
from typing import Callable
|
|
|
|
import torch
|
|
from einops import rearrange, repeat
|
|
from torch import Tensor
|
|
|
|
from .model import Chroma
|
|
|
|
|
|
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
    latent_depth: int = 16,
    spatial_compression: int = 8,
):
    """Draw a reproducible standard-normal latent tensor.

    Args:
        num_samples: Size of the leading (batch) dimension.
        height: Image height; divided by ``spatial_compression`` and rounded up.
        width: Image width; divided by ``spatial_compression`` and rounded up.
        device: Device on which both the generator and the tensor are created.
        dtype: Element type of the returned tensor.
        seed: Seed for the dedicated ``torch.Generator``.
        latent_depth: Size of the second (channel) dimension.
        spatial_compression: Divisor applied to ``height`` and ``width``.

    Returns:
        Tensor of shape ``(num_samples, latent_depth,
        ceil(height / spatial_compression), ceil(width / spatial_compression))``.
    """
    # A private generator keeps the draw independent of the global RNG state.
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    # Rounded up so partially covered cells still get a latent position.
    # NOTE(review): the original comment says this is to "allow for packing",
    # but nothing here forces the sizes to be even -- confirm with the packer.
    latent_h = math.ceil(height / spatial_compression)
    latent_w = math.ceil(width / spatial_compression)

    return torch.randn(
        num_samples,
        latent_depth,
        latent_h,
        latent_w,
        device=device,
        dtype=dtype,
        generator=rng,
    )
|
|
|
|
|
|
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps with ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    Args:
        mu: Shift amount; enters the formula as ``exp(mu)``.
        sigma: Exponent applied to ``1 / t - 1``.
        t: Timestep value(s); the arithmetic is applied elementwise.

    Returns:
        The warped timestep(s), with the same shape as ``t``.
    """
    scale = math.exp(mu)
    inverse_odds = (1 / t - 1) ** sigma
    return scale / (scale + inverse_odds)
|
|
|
|
|
|
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Build the straight line passing through ``(x1, y1)`` and ``(x2, y2)``.

    Args:
        x1: Abscissa of the first anchor point.
        y1: Ordinate of the first anchor point.
        x2: Abscissa of the second anchor point.
        y2: Ordinate of the second anchor point.

    Returns:
        A function mapping ``x`` to ``slope * x + intercept``.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
|
|
|
|
|
|
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` sampling timesteps running from 1 down to 0.

    Args:
        num_steps: Number of integration intervals.
        image_seq_len: Sequence length used to pick the shift strength ``mu``.
        base_shift: ``mu`` at the first anchor of ``get_lin_function`` (x1=256).
        max_shift: ``mu`` at the second anchor of ``get_lin_function`` (x2=4096).
        shift: When false, the plain linear grid is returned unchanged.

    Returns:
        The timesteps as a list of Python floats.
    """
    # One more point than steps so the final interval ends exactly at zero.
    grid = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return grid.tolist()

    # mu is interpolated linearly from the sequence length, then the grid is
    # warped with sigma fixed at 1.0 (original note: this favours high
    # timesteps for higher-signal images).
    mu_for_length = get_lin_function(y1=base_shift, y2=max_shift)
    warped = time_shift(mu_for_length(image_seq_len), 1.0, grid)
    return warped.tolist()
|
|
|
|
|
|
def denoise(
    model: Chroma,
    # model input
    img: Tensor,
    img_ids: Tensor,
    # text conditioning
    txt: Tensor,
    # text position IDs
    txt_ids: Tensor,
    # text mask
    txt_mask: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 0.0,
):
    """Integrate the model's prediction along ``timesteps`` with Euler steps.

    Args:
        model: Callable invoked once per step with keyword arguments ``img``,
            ``img_ids``, ``txt``, ``txt_ids``, ``txt_mask``, ``timesteps`` and
            ``guidance``.
        img: Starting state; its first dimension is used as the batch size.
        img_ids: Forwarded to the model unchanged.
        txt: Forwarded to the model unchanged.
        txt_ids: Forwarded to the model unchanged.
        txt_mask: Forwarded to the model unchanged.
        timesteps: Time points; each consecutive pair forms one step.
        guidance: Value broadcast to one entry per sample and passed to the
            model (the original author notes it is ignored for schnell).

    Returns:
        The state after the last step.
    """
    batch = img.shape[0]
    like_img = {"device": img.device, "dtype": img.dtype}
    guidance_vec = torch.full((batch,), guidance, **like_img)

    # zip stops at the shorter sequence, so this pairs every point with its
    # successor.
    for t_now, t_after in zip(timesteps, timesteps[1:]):
        velocity = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=torch.full((batch,), t_now, **like_img),
            guidance=guidance_vec,
        )
        # Euler update; the step is negative when time runs from 1 to 0.
        img = img + (t_after - t_now) * velocity

    return img
|
|
|
|
|
|
def denoise_batched_timesteps(
    model: Chroma,
    # model input
    img: Tensor,
    img_ids: Tensor,
    # text conditioning
    txt: Tensor,
    # text position IDs
    txt_ids: Tensor,
    # text mask
    txt_mask: Tensor,
    # sampling parameters
    timesteps: Tensor,  # Shape: (B, N), where N is the number of time points
    guidance: float = 4.0,
):
    """
    Performs ODE solving using the Euler method with potentially different
    timestep sequences for each sample in the batch.

    Args:
        model: The flow matching model. Called once per step with keyword
            arguments ``img``, ``img_ids``, ``txt``, ``txt_ids``, ``txt_mask``,
            ``timesteps`` (shape (B,)) and ``guidance`` (shape (B,)).
        img: Input tensor (e.g., noise) with the batch as its first
            dimension. Any rank is accepted; the per-sample step size is
            broadcast over every trailing dimension of the prediction.
        img_ids: Image IDs tensor, forwarded to the model unchanged.
        txt: Text conditioning tensor, forwarded to the model unchanged.
        txt_ids: Text IDs tensor, forwarded to the model unchanged.
        txt_mask: Text mask tensor, forwarded to the model unchanged.
        timesteps: Tensor containing the time points for each batch sample.
            Shape (B, N), where B is the batch size and N is the
            number of time points (e.g., [t_start, ..., t_end]).
            Time should generally decrease (e.g., [1.0, 0.8, ..., 0.0]).
        guidance: Value broadcast to one entry per sample and passed to the
            model as ``guidance``.

    Returns:
        Denoised image tensor with the same shape as ``img``.

    Raises:
        ValueError: If ``timesteps`` is not 2D or its first dimension does
            not match the batch size of ``img``.
    """
    batch_size = img.shape[0]

    # Check the rank first: reading shape[1] of a 1D tensor would raise an
    # IndexError before the intended ValueError could be reported.
    if timesteps.ndim != 2:
        raise ValueError(
            f"timesteps tensor must be 2D (B, N), but got shape {timesteps.shape}"
        )
    if timesteps.shape[0] != batch_size:
        raise ValueError(
            f"Batch size mismatch: img has {batch_size}, "
            f"but timesteps has {timesteps.shape[0]}"
        )

    # Guidance vector remains the same for all elements in this specific call
    guidance_vec = torch.full(
        (batch_size,), guidance, device=img.device, dtype=img.dtype
    )

    # Ensure timesteps tensor is on the same device and dtype as img
    timesteps = timesteps.to(device=img.device, dtype=img.dtype)

    # One (B,) column per time point; consecutive columns form one step.
    time_points = timesteps.unbind(dim=1)
    for t_curr_batch, t_next_batch in zip(time_points, time_points[1:]):
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=t_curr_batch,  # Pass the (B,) tensor of current times
            guidance=guidance_vec,
        )

        # dt = t_next - t_curr (negative when time runs from 1 to 0)
        dt_batch = t_next_batch - t_curr_batch  # Shape: (B,)

        # Reshape to (B, 1, ..., 1) so each sample's dt scales only that
        # sample, whatever the rank of the prediction is.
        dt_batch_reshaped = dt_batch.reshape(batch_size, *([1] * (pred.ndim - 1)))

        # Euler step update: x_{t+1} = x_t + dt * v(x_t, t)
        img = img + dt_batch_reshaped * pred

    return img
|
|
|
|
|
|
def denoise_cfg(
    model: Chroma,
    # model input
    img: Tensor,
    img_ids: Tensor,
    # text conditioning (positive / negative)
    txt: Tensor,
    neg_txt: Tensor,
    # text position IDs
    txt_ids: Tensor,
    neg_txt_ids: Tensor,
    # text masks
    txt_mask: Tensor,
    neg_txt_mask: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    cfg: float = 2.0,
    first_n_steps_without_cfg: int = 4,
):
    """Euler sampling with classifier-free guidance after a warm-up phase.

    Args:
        model: Callable invoked with keyword arguments ``img``, ``img_ids``,
            ``txt``, ``txt_ids``, ``txt_mask``, ``timesteps`` and ``guidance``.
        img: Starting state; its first dimension is used as the batch size.
        img_ids: Forwarded to the model unchanged.
        txt: Positive conditioning, forwarded to the model.
        neg_txt: Negative conditioning, used only on CFG steps.
        txt_ids: IDs accompanying ``txt``.
        neg_txt_ids: IDs accompanying ``neg_txt``.
        txt_mask: Mask accompanying ``txt``.
        neg_txt_mask: Mask accompanying ``neg_txt``.
        timesteps: Time points; each consecutive pair forms one step.
        guidance: Value broadcast to one entry per sample and passed to the
            model (the original author notes it is ignored for schnell).
        cfg: Scale in ``pred_neg + (pred - pred_neg) * cfg``.
        first_n_steps_without_cfg: Number of leading steps that use only the
            positive prediction. ``-1`` disables CFG for every step.

    Returns:
        The state after the last step.
    """
    batch = img.shape[0]
    guidance_vec = torch.full(
        (batch,), guidance, device=img.device, dtype=img.dtype
    )

    # -1 is a sentinel meaning "never run the negative pass"; it does not
    # change during sampling, so evaluate it once.
    cfg_disabled = first_n_steps_without_cfg == -1

    steps = zip(timesteps[:-1], timesteps[1:])
    for step_index, (t_curr, t_prev) in enumerate(steps):
        t_vec = torch.full((batch,), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        # Warm-up (or CFG disabled): step with the positive prediction only.
        if cfg_disabled or step_index < first_n_steps_without_cfg:
            # img is moved to pred's device/dtype before the update.
            img = img.to(pred) + (t_prev - t_curr) * pred
            continue

        pred_neg = model(
            img=img,
            img_ids=img_ids,
            txt=neg_txt,
            txt_ids=neg_txt_ids,
            txt_mask=neg_txt_mask,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        # Extrapolate from the negative prediction towards the positive one.
        pred_cfg = pred_neg + (pred - pred_neg) * cfg
        img = img + (t_prev - t_curr) * pred_cfg

    return img
|
|
|
|
|
|
def denoise_cfg_batched_timesteps(
    model: Chroma,
    # model input
    img: Tensor,
    img_ids: Tensor,
    # text conditioning (positive / negative)
    txt: Tensor,
    neg_txt: Tensor,
    # text position IDs
    txt_ids: Tensor,
    neg_txt_ids: Tensor,
    # text masks
    txt_mask: Tensor,
    neg_txt_mask: Tensor,
    # sampling parameters
    timesteps: Tensor,  # Shape: (B, N), where N is the number of time points
    guidance: float = 0.0,
    cfg: float = 2.0,
    first_n_steps_without_cfg: int = 4,
):
    """
    Performs ODE solving using the Euler method with Classifier-Free Guidance (CFG)
    and potentially different timestep sequences for each sample in the batch.

    Args:
        model: The flow matching model. Called with keyword arguments ``img``,
            ``img_ids``, ``txt``, ``txt_ids``, ``txt_mask``, ``timesteps``
            (shape (B,)) and ``guidance`` (shape (B,)).
        img: Input tensor (e.g., noise) with the batch as its first
            dimension. Any rank is accepted; the per-sample step size is
            broadcast over every trailing dimension of the prediction.
        img_ids: Image IDs tensor, forwarded to the model unchanged.
        txt: Positive text conditioning tensor, first dimension B.
        neg_txt: Negative text conditioning tensor, first dimension B.
        txt_ids: Positive text IDs tensor, first dimension B.
        neg_txt_ids: Negative text IDs tensor, first dimension B.
        txt_mask: Positive text mask tensor, first dimension B.
        neg_txt_mask: Negative text mask tensor, first dimension B.
        timesteps: Tensor containing the time points for each batch sample.
            Shape (B, N), where B is the batch size and N is the
            number of time points (e.g., [t_start, ..., t_end]).
            Time should generally decrease (e.g., [1.0, 0.8, ..., 0.0]).
        guidance: Guidance strength passed to the model (potentially ignored).
        cfg: Classifier-Free Guidance scale. CFG is applied only when this
            is strictly greater than 1.0.
        first_n_steps_without_cfg: The number of initial integration steps
            (intervals) for which CFG will *not* be applied, even if
            cfg > 1.0. Set to 0 to apply CFG from the start; -1 (or any
            negative value) likewise applies CFG on every step.
            NOTE: this differs from ``denoise_cfg`` in this module, where
            -1 disables CFG entirely.

    Returns:
        Denoised image tensor with the same shape as ``img``.

    Raises:
        ValueError: If ``timesteps`` is not 2D, or if ``timesteps`` or any
            conditioning tensor has a first dimension different from the
            batch size of ``img``.
    """
    batch_size = img.shape[0]

    # --- Input Validation ---
    # Check the rank first: reading shape[1] of a 1D tensor would raise an
    # IndexError before the intended ValueError could be reported.
    if timesteps.ndim != 2:
        raise ValueError(
            f"timesteps tensor must be 2D (B, N), but got shape {timesteps.shape}"
        )
    if timesteps.shape[0] != batch_size:
        raise ValueError(
            f"Batch size mismatch: img has {batch_size}, "
            f"but timesteps has {timesteps.shape[0]}"
        )
    # Check consistency of conditioning tensors
    for name, tensor in [
        ("txt", txt),
        ("neg_txt", neg_txt),
        ("txt_ids", txt_ids),
        ("neg_txt_ids", neg_txt_ids),
        ("txt_mask", txt_mask),
        ("neg_txt_mask", neg_txt_mask),
    ]:
        if tensor.shape[0] != batch_size:
            raise ValueError(
                f"Batch size mismatch: img has {batch_size}, "
                f"but {name} has {tensor.shape[0]}"
            )
    # --- End Validation ---

    # Guidance vector (its effect depends on the model)
    guidance_vec = torch.full(
        (batch_size,), guidance, device=img.device, dtype=img.dtype
    )

    # Ensure timesteps tensor is on the same device and dtype as img
    timesteps = timesteps.to(device=img.device, dtype=img.dtype)

    # Whether CFG is enabled at all does not change between steps.
    cfg_enabled = cfg > 1.0

    # One (B,) column per time point; consecutive columns form one step.
    time_points = timesteps.unbind(dim=1)
    for i, (t_curr_batch, t_next_batch) in enumerate(
        zip(time_points, time_points[1:])
    ):
        # --- Positive Prediction ---
        pred_pos = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            txt_mask=txt_mask,
            timesteps=t_curr_batch,  # Batched timesteps
            guidance=guidance_vec,
        )

        # --- CFG Logic ---
        # Apply CFG once the warm-up steps are over. A negative
        # first_n_steps_without_cfg (e.g. -1) satisfies this on every step,
        # so no separate sentinel test is needed.
        if cfg_enabled and i >= first_n_steps_without_cfg:
            # --- Negative Prediction ---
            pred_neg = model(
                img=img,  # Use the *same* input image state as for positive pred
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                txt_mask=neg_txt_mask,
                timesteps=t_curr_batch,  # Use the same batched timesteps
                guidance=guidance_vec,  # Pass guidance here too
            )
            # pred = uncond + cfg * (cond - uncond)
            pred_final = pred_neg + cfg * (pred_pos - pred_neg)
        else:
            # If not applying CFG, use the positive prediction directly
            pred_final = pred_pos
        # --- End CFG Logic ---

        # dt = t_next - t_curr (negative when time runs from 1 to 0)
        dt_batch = t_next_batch - t_curr_batch  # Shape: (B,)

        # Reshape to (B, 1, ..., 1) so each sample's dt scales only that
        # sample, whatever the rank of the prediction is.
        dt_batch_reshaped = dt_batch.reshape(
            batch_size, *([1] * (pred_final.ndim - 1))
        )

        # Euler step update: x_{t+1} = x_t + dt * v(x_t, t)
        # img is moved to pred_final's device/dtype before the update.
        img = img.to(pred_final) + dt_batch_reshaped * pred_final

    return img
|
|
|
|
|
|
def unpack(
    x: Tensor,
    height: int,
    width: int,
    spatial_compression: int = 8,
    patch_size: int = 2,
) -> Tensor:
    """Rearrange a packed token sequence back into an image-shaped latent.

    Applies the einops pattern ``"b (h w) (c ph pw) -> b c (h ph) (w pw)"``,
    so ``x`` must be 3D with its second axis of length ``h * w`` and its third
    axis divisible by ``patch_size ** 2``.

    Args:
        x: Packed tensor to rearrange.
        height: Image height used to derive the patch-grid height ``h``.
        width: Image width used to derive the patch-grid width ``w``.
        spatial_compression: Image-to-latent downscale factor. Together with
            ``patch_size`` it replaces the previously hard-coded divisor 16.
        patch_size: Side length of each packed patch (``ph`` and ``pw``).

    Returns:
        Tensor of shape ``(b, c, h * patch_size, w * patch_size)`` where
        ``h = ceil(height / (spatial_compression * patch_size))`` and ``w`` is
        computed likewise from ``width``.
    """
    # Image units covered by one packed patch (8 * 2 = 16 by default).
    units_per_patch = spatial_compression * patch_size
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / units_per_patch),
        w=math.ceil(width / units_per_patch),
        ph=patch_size,
        pw=patch_size,
    )
|