Merge pull request #6 from maybleMyers/merge

merge stuff
This commit is contained in:
benjimon 2025-06-29 07:51:39 -07:00 committed by GitHub
commit a5ce2b61a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 153 additions and 282 deletions

View File

@ -39,8 +39,7 @@ class Chroma(ForgeDiffusionEngine):
text_encoder=clip.cond_stage_model.t5xxl,
tokenizer=clip.tokenizer.t5xxl,
emphasis_name=dynamic_args['emphasis_name'],
min_length=1,
end_with_pad=True
min_length=1
)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
@ -53,10 +52,7 @@ class Chroma(ForgeDiffusionEngine):
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
cond_t5 = self.text_processing_engine_t5(prompt)
cond = dict(crossattn=cond_t5)
cond['guidance'] = torch.FloatTensor([0] * len(prompt))
return cond
return self.text_processing_engine_t5(prompt)
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):

View File

@ -25,7 +25,7 @@ from backend.diffusion_engine.flux import Flux
from backend.diffusion_engine.chroma import Chroma
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Flux]
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Chroma, Flux]
logging.getLogger("diffusers").setLevel(logging.ERROR)
@ -481,15 +481,18 @@ def split_state_dict(sd, additional_state_dicts: list = None):
return state_dict, guess
class GuessChroma:
huggingface_repo = 'Chroma'
unet_extra_config = {
'guidance_out_dim': 3072,
'guidance_hidden_dim': 5120,
'guidance_n_layers': 5
}
unet_remove_config = ['guidance_embed']
# To be removed once PR merged on huggingface_guess
chroma_is_in_huggingface_guess = hasattr(huggingface_guess.model_list, "Chroma")
if not chroma_is_in_huggingface_guess:
class GuessChroma:
huggingface_repo = 'Chroma'
unet_extra_config = {
'guidance_out_dim': 3072,
'guidance_hidden_dim': 5120,
'guidance_n_layers': 5
}
unet_remove_config = ['guidance_embed']
@torch.inference_mode()
def forge_loader(sd, additional_state_dicts=None):
try:
@ -497,7 +500,8 @@ def forge_loader(sd, additional_state_dicts=None):
except:
raise ValueError('Failed to recognize model type!')
if estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \
if not chroma_is_in_huggingface_guess \
and estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \
and "transformer" in state_dicts \
and "distilled_guidance_layer.layers.0.in_layer.bias" in state_dicts["transformer"]:
estimated_config.huggingface_repo = GuessChroma.huggingface_repo
@ -506,8 +510,7 @@ def forge_loader(sd, additional_state_dicts=None):
for x in GuessChroma.unet_remove_config:
del estimated_config.unet_config[x]
state_dicts['text_encoder'] = state_dicts['text_encoder_2']
del state_dicts['text_encoder_2']
del state_dicts['text_encoder_2']
repo_name = estimated_config.huggingface_repo
local_path = os.path.join(dir_path, 'huggingface', repo_name)
@ -562,10 +565,8 @@ def forge_loader(sd, additional_state_dicts=None):
else:
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
if estimated_config.huggingface_repo == "Chroma":
print("load Chroma model")
if not chroma_is_in_huggingface_guess and estimated_config.huggingface_repo == "Chroma":
return Chroma(estimated_config=estimated_config, huggingface_components=huggingface_components)
for M in possible_models:
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)

View File

@ -9,160 +9,7 @@ from torch import nn
from einops import rearrange, repeat
from backend.attention import attention_function
from backend.utils import fp16_fix, tensor2parameter
def attention(q, k, v, pe):
q, k = apply_rope(q, k, pe)
x = attention_function(q, k, v, q.shape[1], skip_reshape=True)
return x
def rope(pos, dim, theta):
if pos.device.type == "mps" or pos.device.type == "xpu":
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
else:
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta ** scale)
# out = torch.einsum("...n,d->...nd", pos, omega)
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
del cos_out, sin_out
# out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
b, n, d, _ = out.shape
out = out.view(b, n, d, 2, 2)
return out.float()
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
del xq_, xk_
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
t = time_factor * t
half = dim // 2
# TODO: Once A trainer for flux get popular, make timestep_embedding consistent to that trainer
# Do not block CUDA steam, but having about 1e-4 differences with Flux official codes:
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
# Block CUDA steam, but consistent with official codes:
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
del freqs
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
del args
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class EmbedND(nn.Module):
def __init__(self, dim, theta, axes_dim):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
del ids, n_axes
return emb.unsqueeze(1)
class MLPEmbedder(nn.Module):
def __init__(self, in_dim, hidden_dim):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x):
x = self.silu(self.in_layer(x))
return self.out_layer(x)
if hasattr(torch, 'rms_norm'):
functional_rms_norm = torch.rms_norm
else:
def functional_rms_norm(x, normalized_shape, weight, eps):
if x.dtype in [torch.bfloat16, torch.float32]:
n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
else:
n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
return x * n
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.weight = None # to trigger module_profile
self.scale = nn.Parameter(torch.ones(dim))
self.eps = 1e-6
self.normalized_shape = [dim]
def forward(self, x):
if self.scale.dtype != x.dtype:
self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)
class QKNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q, k, v):
del v
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(k), k.to(q)
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x, pe):
qkv = self.qkv(x)
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
B, L, _ = qkv.shape
qkv = qkv.view(B, L, 3, self.num_heads, -1)
q, k, v = qkv.permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
del q, k, v
x = self.proj(x)
return x
from backend.nn.flux import attention, rope, timestep_embedding, EmbedND, MLPEmbedder, RMSNorm, QKNorm, SelfAttention
class Approximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 4):
@ -174,12 +21,9 @@ class Approximator(nn.Module):
def forward(self, x):
x = self.in_proj(x)
for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))
x = self.out_proj(x)
return x
@dataclass
@ -188,25 +32,12 @@ class ModulationOut:
scale: torch.Tensor
gate: torch.Tensor
class Modulation(nn.Module):
def __init__(self, dim, double):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec):
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return out
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
@ -215,7 +46,6 @@ class DoubleStreamBlock(nn.Module):
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
@ -268,7 +98,6 @@ class SingleStreamBlock(nn.Module):
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x, mod, pe):
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
@ -296,7 +125,6 @@ class LastLayer(nn.Module):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x, mod):
shift, scale = mod
@ -415,18 +243,18 @@ class IntegratedChromaTransformer2DModel(nn.Module):
idx += 2 # Advance by 2 vectors
return block_dict
def inner_forward(self, img, img_ids, txt, txt_ids, timesteps, guidance=None):
def inner_forward(self, img, img_ids, txt, txt_ids, timesteps):
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
device = img.device
dtype = img.dtype
dtype = img.dtype # torch.bfloat16
nb_double_block = len(self.double_blocks)
nb_single_block = len(self.single_blocks)
mod_index_length = nb_double_block*12 + nb_single_block*3 + 2
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(device=device, dtype=dtype)
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(device=device, dtype=dtype)
distil_guidance = timestep_embedding(torch.zeros_like(timesteps), 16).to(device=device, dtype=dtype)
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(device=device, dtype=dtype)
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
@ -435,7 +263,6 @@ class IntegratedChromaTransformer2DModel(nn.Module):
mod_vectors_dict = self.distribute_modulations(mod_vectors, nb_single_block, nb_double_block)
txt = self.txt_in(txt)
del guidance
ids = torch.cat((txt_ids, img_ids), dim=1)
del txt_ids, img_ids
pe = self.pe_embedder(ids)
@ -455,7 +282,7 @@ class IntegratedChromaTransformer2DModel(nn.Module):
img = self.final_layer(img, final_mod)
return img
def forward(self, x, timestep, context, guidance=None, **kwargs):
def forward(self, x, timestep, context, **kwargs):
bs, c, h, w = x.shape
input_device = x.device
input_dtype = x.dtype
@ -473,7 +300,7 @@ class IntegratedChromaTransformer2DModel(nn.Module):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype)
del input_device, input_dtype
out = self.inner_forward(img, img_ids, context, txt_ids, timestep, guidance)
out = self.inner_forward(img, img_ids, context, txt_ids, timestep)
del img, img_ids, txt_ids, timestep, context
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
del h_len, w_len, bs

View File

@ -139,7 +139,7 @@ class ForgeOperations:
if prefix + 'weight' in state_dict:
self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
if prefix + 'scale_weight' in state_dict:
self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
if prefix + 'bias' in state_dict:
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
del self.dummy

View File

@ -5,6 +5,9 @@ from modules.ui_components import InputAccordion
import modules.scripts as scripts
from modules.torch_utils import float64
from concurrent.futures import ThreadPoolExecutor
from scipy.ndimage import convolve
from joblib import Parallel, delayed, cpu_count
class SoftInpaintingSettings:
def __init__(self,
@ -244,7 +247,76 @@ def apply_masks(
return masks_for_overlay
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
def weighted_histogram_filter_single_pixel(idx, img, kernel, kernel_center, percentile_min, percentile_max, min_width):
"""
Apply the weighted histogram filter to a single pixel.
This function is now refactored to be accessible for parallelization.
"""
idx = np.array(idx)
kernel_min = -kernel_center
kernel_max = np.array(kernel.shape) - kernel_center
# Precompute the minimum and maximum valid indices for the kernel
min_index = np.maximum(0, idx + kernel_min)
max_index = np.minimum(np.array(img.shape), idx + kernel_max)
window_shape = max_index - min_index
# Initialize values and weights arrays
values = []
weights = []
for window_tup in np.ndindex(*window_shape):
window_index = np.array(window_tup)
image_index = window_index + min_index
centered_kernel_index = image_index - idx
kernel_index = centered_kernel_index + kernel_center
values.append(img[tuple(image_index)])
weights.append(kernel[tuple(kernel_index)])
# Convert to NumPy arrays
values = np.array(values)
weights = np.array(weights)
# Sort values and weights by values
sorted_indices = np.argsort(values)
values = values[sorted_indices]
weights = weights[sorted_indices]
# Calculate cumulative weights
cumulative_weights = np.cumsum(weights)
# Define window boundaries
sum_weights = cumulative_weights[-1]
window_min = sum_weights * percentile_min
window_max = sum_weights * percentile_max
window_width = window_max - window_min
# Ensure window is at least `min_width` wide
if window_width < min_width:
window_center = (window_min + window_max) / 2
window_min = window_center - min_width / 2
window_max = window_center + min_width / 2
if window_max > sum_weights:
window_max = sum_weights
window_min = sum_weights - min_width
if window_min < 0:
window_min = 0
window_max = min_width
# Calculate overlap for each value
overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
overlap_end = np.minimum(window_max, cumulative_weights)
overlap = np.maximum(0, overlap_end - overlap_start)
# Weighted average calculation
result = np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
return result
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0, n_jobs=-1):
"""
Generalization convolution filter capable of applying
weighted mean, median, maximum, and minimum filters
@ -271,101 +343,74 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, pe
(nparray): A filtered copy of the input image "img", a 2-D array of floats.
"""
# Converts an index tuple into a vector.
def vec(x):
return np.array(x)
kernel_min = -kernel_center
kernel_max = vec(kernel.shape) - kernel_center
# Ensure kernel_center is a 1D array
if isinstance(kernel_center, int):
kernel_center = np.array([kernel_center, kernel_center])
elif len(kernel_center) == 1:
kernel_center = np.array([kernel_center[0], kernel_center[0]])
kernel_radius = max(kernel_center)
padded_img = np.pad(img, kernel_radius, mode='constant', constant_values=0)
img_out = np.zeros_like(img)
img_shape = img.shape
pixel_coords = [(i, j) for i in range(img_shape[0]) for j in range(img_shape[1])]
def weighted_histogram_filter_single(idx):
idx = vec(idx)
min_index = np.maximum(0, idx + kernel_min)
max_index = np.minimum(vec(img.shape), idx + kernel_max)
window_shape = max_index - min_index
"""
Single-pixel weighted histogram calculation.
"""
row, col = idx
idx = (row + kernel_radius, col + kernel_radius)
min_index = np.array(idx) - kernel_center
max_index = min_index + kernel.shape
class WeightedElement:
"""
An element of the histogram, its weight
and bounds.
"""
window = padded_img[min_index[0]:max_index[0], min_index[1]:max_index[1]]
window_values = window.flatten()
window_weights = kernel.flatten()
def __init__(self, value, weight):
self.value: float = value
self.weight: float = weight
self.window_min: float = 0.0
self.window_max: float = 1.0
sorted_indices = np.argsort(window_values)
values = window_values[sorted_indices]
weights = window_weights[sorted_indices]
# Collect the values in the image as WeightedElements,
# weighted by their corresponding kernel values.
values = []
for window_tup in np.ndindex(tuple(window_shape)):
window_index = vec(window_tup)
image_index = window_index + min_index
centered_kernel_index = image_index - idx
kernel_index = centered_kernel_index + kernel_center
element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
values.append(element)
cumulative_weights = np.cumsum(weights)
sum_weights = cumulative_weights[-1]
window_min = max(0, sum_weights * percentile_min)
window_max = min(sum_weights, sum_weights * percentile_max)
def sort_key(x: WeightedElement):
return x.value
values.sort(key=sort_key)
# Calculate the height of the stack (sum)
# and each sample's range they occupy in the stack
sum = 0
for i in range(len(values)):
values[i].window_min = sum
sum += values[i].weight
values[i].window_max = sum
# Calculate what range of this stack ("window")
# we want to get the weighted average across.
window_min = sum * percentile_min
window_max = sum * percentile_max
window_width = window_max - window_min
# Ensure the window is within the stack and at least a certain size.
if window_width < min_width:
window_center = (window_min + window_max) / 2
window_min = window_center - min_width / 2
window_max = window_center + min_width / 2
window_min = max(0, window_center - min_width / 2)
window_max = min(sum_weights, window_center + min_width / 2)
if window_max > sum:
window_max = sum
window_min = sum - min_width
overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
overlap_end = np.minimum(window_max, cumulative_weights)
overlap = np.maximum(0, overlap_end - overlap_start)
if window_min < 0:
window_min = 0
window_max = min_width
return np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
value = 0
value_weight = 0
# Split pixel_coords into equal chunks based on n_jobs
n_jobs = -1
if cpu_count() > 6:
n_jobs = 6 # More than 6 isn't worth unless it's more than 3000x3000px
# Get the weighted average of all the samples
# that overlap with the window, weighted
# by the size of their overlap.
for i in range(len(values)):
if window_min >= values[i].window_max:
continue
if window_max <= values[i].window_min:
break
chunk_size = len(pixel_coords) // n_jobs
pixel_chunks = [pixel_coords[i:i + chunk_size] for i in range(0, len(pixel_coords), chunk_size)]
s = max(window_min, values[i].window_min)
e = min(window_max, values[i].window_max)
w = e - s
# joblib to process chunks in parallel
def process_chunk(chunk):
chunk_result = {}
for idx in chunk:
chunk_result[idx] = weighted_histogram_filter_single(idx)
return chunk_result
value += values[i].value * w
value_weight += w
results = Parallel(n_jobs=n_jobs, backend="loky")( # loky is fastest in my configuration
delayed(process_chunk)(chunk) for chunk in pixel_chunks
)
return value / value_weight if value_weight != 0 else 0
img_out = img.copy()
# Apply the kernel operation over each pixel.
for index in np.ndindex(img.shape):
img_out[index] = weighted_histogram_filter_single(index)
# Combine results into the output image
for chunk_result in results:
for (row, col), value in chunk_result.items():
img_out[row, col] = value
return img_out
@ -485,7 +530,7 @@ el_ids = SoftInpaintingSettings(
class Script(scripts.Script):
def __init__(self):
# self.section = "inpaint"
self.section = "inpaint"
self.masks_for_overlay = None
self.overlay_images = None

View File

@ -98,4 +98,6 @@ class Script(scripts.Script):
processed = Processed(p, result_images, seed, initial_info)
p.n_iter = upscale_count
return processed