mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-21 21:14:23 +08:00
commit
a5ce2b61a1
@ -39,8 +39,7 @@ class Chroma(ForgeDiffusionEngine):
|
||||
text_encoder=clip.cond_stage_model.t5xxl,
|
||||
tokenizer=clip.tokenizer.t5xxl,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
min_length=1,
|
||||
end_with_pad=True
|
||||
min_length=1
|
||||
)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
@ -53,10 +52,7 @@ class Chroma(ForgeDiffusionEngine):
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
cond_t5 = self.text_processing_engine_t5(prompt)
|
||||
cond = dict(crossattn=cond_t5)
|
||||
cond['guidance'] = torch.FloatTensor([0] * len(prompt))
|
||||
return cond
|
||||
return self.text_processing_engine_t5(prompt)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
|
||||
@ -25,7 +25,7 @@ from backend.diffusion_engine.flux import Flux
|
||||
from backend.diffusion_engine.chroma import Chroma
|
||||
|
||||
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Flux]
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Chroma, Flux]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
@ -481,15 +481,18 @@ def split_state_dict(sd, additional_state_dicts: list = None):
|
||||
|
||||
return state_dict, guess
|
||||
|
||||
class GuessChroma:
|
||||
huggingface_repo = 'Chroma'
|
||||
unet_extra_config = {
|
||||
'guidance_out_dim': 3072,
|
||||
'guidance_hidden_dim': 5120,
|
||||
'guidance_n_layers': 5
|
||||
}
|
||||
unet_remove_config = ['guidance_embed']
|
||||
# To be removed once PR merged on huggingface_guess
|
||||
chroma_is_in_huggingface_guess = hasattr(huggingface_guess.model_list, "Chroma")
|
||||
|
||||
if not chroma_is_in_huggingface_guess:
|
||||
class GuessChroma:
|
||||
huggingface_repo = 'Chroma'
|
||||
unet_extra_config = {
|
||||
'guidance_out_dim': 3072,
|
||||
'guidance_hidden_dim': 5120,
|
||||
'guidance_n_layers': 5
|
||||
}
|
||||
unet_remove_config = ['guidance_embed']
|
||||
@torch.inference_mode()
|
||||
def forge_loader(sd, additional_state_dicts=None):
|
||||
try:
|
||||
@ -497,7 +500,8 @@ def forge_loader(sd, additional_state_dicts=None):
|
||||
except:
|
||||
raise ValueError('Failed to recognize model type!')
|
||||
|
||||
if estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \
|
||||
if not chroma_is_in_huggingface_guess \
|
||||
and estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \
|
||||
and "transformer" in state_dicts \
|
||||
and "distilled_guidance_layer.layers.0.in_layer.bias" in state_dicts["transformer"]:
|
||||
estimated_config.huggingface_repo = GuessChroma.huggingface_repo
|
||||
@ -506,8 +510,7 @@ def forge_loader(sd, additional_state_dicts=None):
|
||||
for x in GuessChroma.unet_remove_config:
|
||||
del estimated_config.unet_config[x]
|
||||
state_dicts['text_encoder'] = state_dicts['text_encoder_2']
|
||||
del state_dicts['text_encoder_2']
|
||||
|
||||
del state_dicts['text_encoder_2']
|
||||
repo_name = estimated_config.huggingface_repo
|
||||
|
||||
local_path = os.path.join(dir_path, 'huggingface', repo_name)
|
||||
@ -562,10 +565,8 @@ def forge_loader(sd, additional_state_dicts=None):
|
||||
else:
|
||||
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
|
||||
|
||||
if estimated_config.huggingface_repo == "Chroma":
|
||||
print("load Chroma model")
|
||||
if not chroma_is_in_huggingface_guess and estimated_config.huggingface_repo == "Chroma":
|
||||
return Chroma(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
||||
|
||||
for M in possible_models:
|
||||
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
|
||||
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
||||
|
||||
@ -9,160 +9,7 @@ from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
from backend.attention import attention_function
|
||||
from backend.utils import fp16_fix, tensor2parameter
|
||||
|
||||
|
||||
def attention(q, k, v, pe):
|
||||
q, k = apply_rope(q, k, pe)
|
||||
x = attention_function(q, k, v, q.shape[1], skip_reshape=True)
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos, dim, theta):
|
||||
if pos.device.type == "mps" or pos.device.type == "xpu":
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
|
||||
else:
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta ** scale)
|
||||
|
||||
# out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
|
||||
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
del cos_out, sin_out
|
||||
|
||||
# out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
b, n, d, _ = out.shape
|
||||
out = out.view(b, n, d, 2, 2)
|
||||
|
||||
return out.float()
|
||||
|
||||
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
del xq_, xk_
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
|
||||
# TODO: Once A trainer for flux get popular, make timestep_embedding consistent to that trainer
|
||||
|
||||
# Do not block CUDA steam, but having about 1e-4 differences with Flux official codes:
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
|
||||
|
||||
# Block CUDA steam, but consistent with official codes:
|
||||
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
del freqs
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
del args
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim, theta, axes_dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids):
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
del ids, n_axes
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.silu(self.in_layer(x))
|
||||
return self.out_layer(x)
|
||||
|
||||
|
||||
if hasattr(torch, 'rms_norm'):
|
||||
functional_rms_norm = torch.rms_norm
|
||||
else:
|
||||
def functional_rms_norm(x, normalized_shape, weight, eps):
|
||||
if x.dtype in [torch.bfloat16, torch.float32]:
|
||||
n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
|
||||
else:
|
||||
n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
|
||||
return x * n
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.weight = None # to trigger module_profile
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
self.eps = 1e-6
|
||||
self.normalized_shape = [dim]
|
||||
|
||||
def forward(self, x):
|
||||
if self.scale.dtype != x.dtype:
|
||||
self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
|
||||
return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)
|
||||
|
||||
|
||||
class QKNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim)
|
||||
self.key_norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
del v
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(k), k.to(q)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x, pe):
|
||||
qkv = self.qkv(x)
|
||||
|
||||
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
B, L, _ = qkv.shape
|
||||
qkv = qkv.view(B, L, 3, self.num_heads, -1)
|
||||
q, k, v = qkv.permute(2, 0, 3, 1, 4)
|
||||
del qkv
|
||||
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
x = attention(q, k, v, pe=pe)
|
||||
del q, k, v
|
||||
|
||||
x = self.proj(x)
|
||||
return x
|
||||
from backend.nn.flux import attention, rope, timestep_embedding, EmbedND, MLPEmbedder, RMSNorm, QKNorm, SelfAttention
|
||||
|
||||
class Approximator(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 4):
|
||||
@ -174,12 +21,9 @@ class Approximator(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = self.in_proj(x)
|
||||
|
||||
for layer, norms in zip(self.layers, self.norms):
|
||||
x = x + layer(norms(x))
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
return x
|
||||
|
||||
@dataclass
|
||||
@ -188,25 +32,12 @@ class ModulationOut:
|
||||
scale: torch.Tensor
|
||||
gate: torch.Tensor
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim, double):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def forward(self, vec):
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
return out
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
|
||||
super().__init__()
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
@ -215,7 +46,6 @@ class DoubleStreamBlock(nn.Module):
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
self.txt_mod = Modulation(hidden_size, double=True)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
@ -268,7 +98,6 @@ class SingleStreamBlock(nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
def forward(self, x, mod, pe):
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
@ -296,7 +125,6 @@ class LastLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x, mod):
|
||||
shift, scale = mod
|
||||
@ -415,18 +243,18 @@ class IntegratedChromaTransformer2DModel(nn.Module):
|
||||
idx += 2 # Advance by 2 vectors
|
||||
return block_dict
|
||||
|
||||
def inner_forward(self, img, img_ids, txt, txt_ids, timesteps, guidance=None):
|
||||
def inner_forward(self, img, img_ids, txt, txt_ids, timesteps):
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
img = self.img_in(img)
|
||||
device = img.device
|
||||
dtype = img.dtype
|
||||
dtype = img.dtype # torch.bfloat16
|
||||
nb_double_block = len(self.double_blocks)
|
||||
nb_single_block = len(self.single_blocks)
|
||||
|
||||
mod_index_length = nb_double_block*12 + nb_single_block*3 + 2
|
||||
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(device=device, dtype=dtype)
|
||||
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(device=device, dtype=dtype)
|
||||
distil_guidance = timestep_embedding(torch.zeros_like(timesteps), 16).to(device=device, dtype=dtype)
|
||||
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(device=device, dtype=dtype)
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
|
||||
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
|
||||
@ -435,7 +263,6 @@ class IntegratedChromaTransformer2DModel(nn.Module):
|
||||
mod_vectors_dict = self.distribute_modulations(mod_vectors, nb_single_block, nb_double_block)
|
||||
|
||||
txt = self.txt_in(txt)
|
||||
del guidance
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
del txt_ids, img_ids
|
||||
pe = self.pe_embedder(ids)
|
||||
@ -455,7 +282,7 @@ class IntegratedChromaTransformer2DModel(nn.Module):
|
||||
img = self.final_layer(img, final_mod)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, **kwargs):
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
input_device = x.device
|
||||
input_dtype = x.dtype
|
||||
@ -473,7 +300,7 @@ class IntegratedChromaTransformer2DModel(nn.Module):
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype)
|
||||
del input_device, input_dtype
|
||||
out = self.inner_forward(img, img_ids, context, txt_ids, timestep, guidance)
|
||||
out = self.inner_forward(img, img_ids, context, txt_ids, timestep)
|
||||
del img, img_ids, txt_ids, timestep, context
|
||||
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
|
||||
del h_len, w_len, bs
|
||||
|
||||
@ -139,7 +139,7 @@ class ForgeOperations:
|
||||
if prefix + 'weight' in state_dict:
|
||||
self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
|
||||
if prefix + 'scale_weight' in state_dict:
|
||||
self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
|
||||
self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
|
||||
if prefix + 'bias' in state_dict:
|
||||
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
|
||||
del self.dummy
|
||||
|
||||
@ -5,6 +5,9 @@ from modules.ui_components import InputAccordion
|
||||
import modules.scripts as scripts
|
||||
from modules.torch_utils import float64
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from scipy.ndimage import convolve
|
||||
from joblib import Parallel, delayed, cpu_count
|
||||
|
||||
class SoftInpaintingSettings:
|
||||
def __init__(self,
|
||||
@ -244,7 +247,76 @@ def apply_masks(
|
||||
return masks_for_overlay
|
||||
|
||||
|
||||
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
|
||||
|
||||
|
||||
def weighted_histogram_filter_single_pixel(idx, img, kernel, kernel_center, percentile_min, percentile_max, min_width):
|
||||
"""
|
||||
Apply the weighted histogram filter to a single pixel.
|
||||
This function is now refactored to be accessible for parallelization.
|
||||
"""
|
||||
idx = np.array(idx)
|
||||
kernel_min = -kernel_center
|
||||
kernel_max = np.array(kernel.shape) - kernel_center
|
||||
|
||||
# Precompute the minimum and maximum valid indices for the kernel
|
||||
min_index = np.maximum(0, idx + kernel_min)
|
||||
max_index = np.minimum(np.array(img.shape), idx + kernel_max)
|
||||
window_shape = max_index - min_index
|
||||
|
||||
# Initialize values and weights arrays
|
||||
values = []
|
||||
weights = []
|
||||
|
||||
for window_tup in np.ndindex(*window_shape):
|
||||
window_index = np.array(window_tup)
|
||||
image_index = window_index + min_index
|
||||
centered_kernel_index = image_index - idx
|
||||
kernel_index = centered_kernel_index + kernel_center
|
||||
values.append(img[tuple(image_index)])
|
||||
weights.append(kernel[tuple(kernel_index)])
|
||||
|
||||
# Convert to NumPy arrays
|
||||
values = np.array(values)
|
||||
weights = np.array(weights)
|
||||
|
||||
# Sort values and weights by values
|
||||
sorted_indices = np.argsort(values)
|
||||
values = values[sorted_indices]
|
||||
weights = weights[sorted_indices]
|
||||
|
||||
# Calculate cumulative weights
|
||||
cumulative_weights = np.cumsum(weights)
|
||||
|
||||
# Define window boundaries
|
||||
sum_weights = cumulative_weights[-1]
|
||||
window_min = sum_weights * percentile_min
|
||||
window_max = sum_weights * percentile_max
|
||||
window_width = window_max - window_min
|
||||
|
||||
# Ensure window is at least `min_width` wide
|
||||
if window_width < min_width:
|
||||
window_center = (window_min + window_max) / 2
|
||||
window_min = window_center - min_width / 2
|
||||
window_max = window_center + min_width / 2
|
||||
|
||||
if window_max > sum_weights:
|
||||
window_max = sum_weights
|
||||
window_min = sum_weights - min_width
|
||||
|
||||
if window_min < 0:
|
||||
window_min = 0
|
||||
window_max = min_width
|
||||
|
||||
# Calculate overlap for each value
|
||||
overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
|
||||
overlap_end = np.minimum(window_max, cumulative_weights)
|
||||
overlap = np.maximum(0, overlap_end - overlap_start)
|
||||
|
||||
# Weighted average calculation
|
||||
result = np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
|
||||
return result
|
||||
|
||||
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0, n_jobs=-1):
|
||||
"""
|
||||
Generalization convolution filter capable of applying
|
||||
weighted mean, median, maximum, and minimum filters
|
||||
@ -271,101 +343,74 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, pe
|
||||
(nparray): A filtered copy of the input image "img", a 2-D array of floats.
|
||||
"""
|
||||
|
||||
# Converts an index tuple into a vector.
|
||||
def vec(x):
|
||||
return np.array(x)
|
||||
|
||||
kernel_min = -kernel_center
|
||||
kernel_max = vec(kernel.shape) - kernel_center
|
||||
# Ensure kernel_center is a 1D array
|
||||
if isinstance(kernel_center, int):
|
||||
kernel_center = np.array([kernel_center, kernel_center])
|
||||
elif len(kernel_center) == 1:
|
||||
kernel_center = np.array([kernel_center[0], kernel_center[0]])
|
||||
kernel_radius = max(kernel_center)
|
||||
padded_img = np.pad(img, kernel_radius, mode='constant', constant_values=0)
|
||||
img_out = np.zeros_like(img)
|
||||
img_shape = img.shape
|
||||
pixel_coords = [(i, j) for i in range(img_shape[0]) for j in range(img_shape[1])]
|
||||
|
||||
def weighted_histogram_filter_single(idx):
|
||||
idx = vec(idx)
|
||||
min_index = np.maximum(0, idx + kernel_min)
|
||||
max_index = np.minimum(vec(img.shape), idx + kernel_max)
|
||||
window_shape = max_index - min_index
|
||||
"""
|
||||
Single-pixel weighted histogram calculation.
|
||||
"""
|
||||
row, col = idx
|
||||
idx = (row + kernel_radius, col + kernel_radius)
|
||||
min_index = np.array(idx) - kernel_center
|
||||
max_index = min_index + kernel.shape
|
||||
|
||||
class WeightedElement:
|
||||
"""
|
||||
An element of the histogram, its weight
|
||||
and bounds.
|
||||
"""
|
||||
window = padded_img[min_index[0]:max_index[0], min_index[1]:max_index[1]]
|
||||
window_values = window.flatten()
|
||||
window_weights = kernel.flatten()
|
||||
|
||||
def __init__(self, value, weight):
|
||||
self.value: float = value
|
||||
self.weight: float = weight
|
||||
self.window_min: float = 0.0
|
||||
self.window_max: float = 1.0
|
||||
sorted_indices = np.argsort(window_values)
|
||||
values = window_values[sorted_indices]
|
||||
weights = window_weights[sorted_indices]
|
||||
|
||||
# Collect the values in the image as WeightedElements,
|
||||
# weighted by their corresponding kernel values.
|
||||
values = []
|
||||
for window_tup in np.ndindex(tuple(window_shape)):
|
||||
window_index = vec(window_tup)
|
||||
image_index = window_index + min_index
|
||||
centered_kernel_index = image_index - idx
|
||||
kernel_index = centered_kernel_index + kernel_center
|
||||
element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
|
||||
values.append(element)
|
||||
cumulative_weights = np.cumsum(weights)
|
||||
sum_weights = cumulative_weights[-1]
|
||||
window_min = max(0, sum_weights * percentile_min)
|
||||
window_max = min(sum_weights, sum_weights * percentile_max)
|
||||
|
||||
def sort_key(x: WeightedElement):
|
||||
return x.value
|
||||
|
||||
values.sort(key=sort_key)
|
||||
|
||||
# Calculate the height of the stack (sum)
|
||||
# and each sample's range they occupy in the stack
|
||||
sum = 0
|
||||
for i in range(len(values)):
|
||||
values[i].window_min = sum
|
||||
sum += values[i].weight
|
||||
values[i].window_max = sum
|
||||
|
||||
# Calculate what range of this stack ("window")
|
||||
# we want to get the weighted average across.
|
||||
window_min = sum * percentile_min
|
||||
window_max = sum * percentile_max
|
||||
window_width = window_max - window_min
|
||||
|
||||
# Ensure the window is within the stack and at least a certain size.
|
||||
if window_width < min_width:
|
||||
window_center = (window_min + window_max) / 2
|
||||
window_min = window_center - min_width / 2
|
||||
window_max = window_center + min_width / 2
|
||||
window_min = max(0, window_center - min_width / 2)
|
||||
window_max = min(sum_weights, window_center + min_width / 2)
|
||||
|
||||
if window_max > sum:
|
||||
window_max = sum
|
||||
window_min = sum - min_width
|
||||
overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
|
||||
overlap_end = np.minimum(window_max, cumulative_weights)
|
||||
overlap = np.maximum(0, overlap_end - overlap_start)
|
||||
|
||||
if window_min < 0:
|
||||
window_min = 0
|
||||
window_max = min_width
|
||||
return np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
|
||||
|
||||
value = 0
|
||||
value_weight = 0
|
||||
# Split pixel_coords into equal chunks based on n_jobs
|
||||
n_jobs = -1
|
||||
if cpu_count() > 6:
|
||||
n_jobs = 6 # More than 6 isn't worth unless it's more than 3000x3000px
|
||||
|
||||
# Get the weighted average of all the samples
|
||||
# that overlap with the window, weighted
|
||||
# by the size of their overlap.
|
||||
for i in range(len(values)):
|
||||
if window_min >= values[i].window_max:
|
||||
continue
|
||||
if window_max <= values[i].window_min:
|
||||
break
|
||||
chunk_size = len(pixel_coords) // n_jobs
|
||||
pixel_chunks = [pixel_coords[i:i + chunk_size] for i in range(0, len(pixel_coords), chunk_size)]
|
||||
|
||||
s = max(window_min, values[i].window_min)
|
||||
e = min(window_max, values[i].window_max)
|
||||
w = e - s
|
||||
# joblib to process chunks in parallel
|
||||
def process_chunk(chunk):
|
||||
chunk_result = {}
|
||||
for idx in chunk:
|
||||
chunk_result[idx] = weighted_histogram_filter_single(idx)
|
||||
return chunk_result
|
||||
|
||||
value += values[i].value * w
|
||||
value_weight += w
|
||||
results = Parallel(n_jobs=n_jobs, backend="loky")( # loky is fastest in my configuration
|
||||
delayed(process_chunk)(chunk) for chunk in pixel_chunks
|
||||
)
|
||||
|
||||
return value / value_weight if value_weight != 0 else 0
|
||||
|
||||
img_out = img.copy()
|
||||
|
||||
# Apply the kernel operation over each pixel.
|
||||
for index in np.ndindex(img.shape):
|
||||
img_out[index] = weighted_histogram_filter_single(index)
|
||||
# Combine results into the output image
|
||||
for chunk_result in results:
|
||||
for (row, col), value in chunk_result.items():
|
||||
img_out[row, col] = value
|
||||
|
||||
return img_out
|
||||
|
||||
@ -485,7 +530,7 @@ el_ids = SoftInpaintingSettings(
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self):
|
||||
# self.section = "inpaint"
|
||||
self.section = "inpaint"
|
||||
self.masks_for_overlay = None
|
||||
self.overlay_images = None
|
||||
|
||||
|
||||
@ -98,4 +98,6 @@ class Script(scripts.Script):
|
||||
|
||||
processed = Processed(p, result_images, seed, initial_info)
|
||||
|
||||
p.n_iter = upscale_count
|
||||
|
||||
return processed
|
||||
|
||||
Loading…
Reference in New Issue
Block a user