diff --git a/backend/diffusion_engine/chroma.py b/backend/diffusion_engine/chroma.py index de708347..a10256d7 100644 --- a/backend/diffusion_engine/chroma.py +++ b/backend/diffusion_engine/chroma.py @@ -39,8 +39,7 @@ class Chroma(ForgeDiffusionEngine): text_encoder=clip.cond_stage_model.t5xxl, tokenizer=clip.tokenizer.t5xxl, emphasis_name=dynamic_args['emphasis_name'], - min_length=1, - end_with_pad=True + min_length=1 ) self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None) @@ -53,10 +52,7 @@ class Chroma(ForgeDiffusionEngine): @torch.inference_mode() def get_learned_conditioning(self, prompt: list[str]): memory_management.load_model_gpu(self.forge_objects.clip.patcher) - cond_t5 = self.text_processing_engine_t5(prompt) - cond = dict(crossattn=cond_t5) - cond['guidance'] = torch.FloatTensor([0] * len(prompt)) - return cond + return self.text_processing_engine_t5(prompt) @torch.inference_mode() def get_prompt_lengths_on_ui(self, prompt): diff --git a/backend/loader.py b/backend/loader.py index 2d802adf..aab5bf19 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -25,7 +25,7 @@ from backend.diffusion_engine.flux import Flux from backend.diffusion_engine.chroma import Chroma -possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Flux] +possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Chroma, Flux] logging.getLogger("diffusers").setLevel(logging.ERROR) @@ -481,15 +481,18 @@ def split_state_dict(sd, additional_state_dicts: list = None): return state_dict, guess -class GuessChroma: - huggingface_repo = 'Chroma' - unet_extra_config = { - 'guidance_out_dim': 3072, - 'guidance_hidden_dim': 5120, - 'guidance_n_layers': 5 - } - unet_remove_config = ['guidance_embed'] +# To be removed once PR merged on huggingface_guess +chroma_is_in_huggingface_guess = hasattr(huggingface_guess.model_list, "Chroma") +if not chroma_is_in_huggingface_guess: + class GuessChroma: + huggingface_repo = 'Chroma' + unet_extra_config = { + 'guidance_out_dim': 3072, + 'guidance_hidden_dim': 5120, + 'guidance_n_layers': 5 + } + unet_remove_config = ['guidance_embed'] @torch.inference_mode() def forge_loader(sd, additional_state_dicts=None): try: @@ -497,7 +500,8 @@ def forge_loader(sd, additional_state_dicts=None): except: raise ValueError('Failed to recognize model type!') - if estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \ + if not chroma_is_in_huggingface_guess \ + and estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \ and "transformer" in state_dicts \ and "distilled_guidance_layer.layers.0.in_layer.bias" in state_dicts["transformer"]: estimated_config.huggingface_repo = GuessChroma.huggingface_repo @@ -506,8 +510,7 @@ def forge_loader(sd, additional_state_dicts=None): for x in GuessChroma.unet_remove_config: del estimated_config.unet_config[x] state_dicts['text_encoder'] = state_dicts['text_encoder_2'] - del state_dicts['text_encoder_2'] - + del state_dicts['text_encoder_2'] repo_name = estimated_config.huggingface_repo local_path = os.path.join(dir_path, 'huggingface', repo_name) @@ -562,10 +565,8 @@ def forge_loader(sd, additional_state_dicts=None): else: huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type) - if estimated_config.huggingface_repo == "Chroma": - print("load Chroma model") + if not chroma_is_in_huggingface_guess and estimated_config.huggingface_repo == "Chroma": return Chroma(estimated_config=estimated_config, huggingface_components=huggingface_components) - for M in possible_models: if any(isinstance(estimated_config, x) for x in M.matched_guesses): return M(estimated_config=estimated_config, huggingface_components=huggingface_components) diff --git a/backend/nn/chroma.py b/backend/nn/chroma.py index 64bae0a7..5aa12e7a 100644 --- a/backend/nn/chroma.py +++ b/backend/nn/chroma.py @@ -9,160 +9,7 @@ from torch import nn from einops import rearrange, repeat from backend.attention import attention_function from backend.utils import fp16_fix, tensor2parameter - - -def attention(q, k, v, pe): - q, k = apply_rope(q, k, pe) - x = attention_function(q, k, v, q.shape[1], skip_reshape=True) - return x - - -def rope(pos, dim, theta): - if pos.device.type == "mps" or pos.device.type == "xpu": - scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim - else: - scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim - omega = 1.0 / (theta ** scale) - - # out = torch.einsum("...n,d->...nd", pos, omega) - out = pos.unsqueeze(-1) * omega.unsqueeze(0) - - cos_out = torch.cos(out) - sin_out = torch.sin(out) - out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) - del cos_out, sin_out - - # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) - b, n, d, _ = out.shape - out = out.view(b, n, d, 2, 2) - - return out.float() - - -def apply_rope(xq, xk, freqs_cis): - xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - del xq_, xk_ - return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - - -def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0): - t = time_factor * t - half = dim // 2 - - # TODO: Once A trainer for flux get popular, make timestep_embedding consistent to that trainer - - # Do not block CUDA steam, but having about 1e-4 differences with Flux official codes: - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) - - # Block CUDA steam, but consistent with official codes: - # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) - - args = t[:, None].float() * freqs[None] - del freqs - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - del args - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - if torch.is_floating_point(t): - embedding = embedding.to(t) - return embedding - - -class EmbedND(nn.Module): - def __init__(self, dim, theta, axes_dim): - super().__init__() - self.dim = dim - self.theta = theta - self.axes_dim = axes_dim - - def forward(self, ids): - n_axes = ids.shape[-1] - emb = torch.cat( - [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], - dim=-3, - ) - del ids, n_axes - return emb.unsqueeze(1) - - -class MLPEmbedder(nn.Module): - def __init__(self, in_dim, hidden_dim): - super().__init__() - self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) - self.silu = nn.SiLU() - self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) - - def forward(self, x): - x = self.silu(self.in_layer(x)) - return self.out_layer(x) - - -if hasattr(torch, 'rms_norm'): - functional_rms_norm = torch.rms_norm -else: - def functional_rms_norm(x, normalized_shape, weight, eps): - if x.dtype in [torch.bfloat16, torch.float32]: - n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight - else: - n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight - return x * n - - -class RMSNorm(nn.Module): - def __init__(self, dim): - super().__init__() - self.weight = None # to trigger module_profile - self.scale = nn.Parameter(torch.ones(dim)) - self.eps = 1e-6 - self.normalized_shape = [dim] - - def forward(self, x): - if self.scale.dtype != x.dtype: - self.scale = tensor2parameter(self.scale.to(dtype=x.dtype)) - return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps) - - -class QKNorm(nn.Module): - def __init__(self, dim): - super().__init__() - self.query_norm = RMSNorm(dim) - self.key_norm = RMSNorm(dim) - - def forward(self, q, k, v): - del v - q = self.query_norm(q) - k = self.key_norm(k) - return q.to(k), k.to(q) - - -class SelfAttention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.norm = QKNorm(head_dim) - self.proj = nn.Linear(dim, dim) - - def forward(self, x, pe): - qkv = self.qkv(x) - - # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) - B, L, _ = qkv.shape - qkv = qkv.view(B, L, 3, self.num_heads, -1) - q, k, v = qkv.permute(2, 0, 3, 1, 4) - del qkv - - q, k = self.norm(q, k, v) - - x = attention(q, k, v, pe=pe) - del q, k, v - - x = self.proj(x) - return x +from backend.nn.flux import attention, rope, timestep_embedding, EmbedND, MLPEmbedder, RMSNorm, QKNorm, SelfAttention class Approximator(nn.Module): def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 4): @@ -174,12 +21,9 @@ class Approximator(nn.Module): def forward(self, x): x = self.in_proj(x) - for layer, norms in zip(self.layers, self.norms): x = x + layer(norms(x)) - x = self.out_proj(x) - return x @dataclass @@ -188,25 +32,12 @@ class ModulationOut: scale: torch.Tensor gate: torch.Tensor -class Modulation(nn.Module): - def __init__(self, dim, double): - super().__init__() - self.is_double = double - self.multiplier = 6 if double else 3 - self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) - - def forward(self, vec): - out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) - return out - - class DoubleStreamBlock(nn.Module): def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size - self.img_mod = Modulation(hidden_size, double=True) self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -215,7 +46,6 @@ class DoubleStreamBlock(nn.Module): nn.GELU(approximate="tanh"), nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) - self.txt_mod = Modulation(hidden_size, double=True) self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -268,7 +98,6 @@ class SingleStreamBlock(nn.Module): self.hidden_size = hidden_size self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.mlp_act = nn.GELU(approximate="tanh") - self.modulation = Modulation(hidden_size, double=False) def forward(self, x, mod, pe): x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift @@ -296,7 +125,6 @@ class LastLayer(nn.Module): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) def forward(self, x, mod): shift, scale = mod @@ -415,18 +243,18 @@ class IntegratedChromaTransformer2DModel(nn.Module): idx += 2 # Advance by 2 vectors return block_dict - def inner_forward(self, img, img_ids, txt, txt_ids, timesteps, guidance=None): + def inner_forward(self, img, img_ids, txt, txt_ids, timesteps): if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") img = self.img_in(img) device = img.device - dtype = img.dtype + dtype = img.dtype # torch.bfloat16 nb_double_block = len(self.double_blocks) nb_single_block = len(self.single_blocks) mod_index_length = nb_double_block*12 + nb_single_block*3 + 2 distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(device=device, dtype=dtype) - distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(device=device, dtype=dtype) + distil_guidance = timestep_embedding(torch.zeros_like(timesteps), 16).to(device=device, dtype=dtype) modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(device=device, dtype=dtype) modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1) @@ -435,7 +263,6 @@ class IntegratedChromaTransformer2DModel(nn.Module): mod_vectors_dict = self.distribute_modulations(mod_vectors, nb_single_block, nb_double_block) txt = self.txt_in(txt) - del guidance ids = torch.cat((txt_ids, img_ids), dim=1) del txt_ids, img_ids pe = self.pe_embedder(ids) @@ -455,7 +282,7 @@ class IntegratedChromaTransformer2DModel(nn.Module): img = self.final_layer(img, final_mod) return img - def forward(self, x, timestep, context, guidance=None, **kwargs): + def forward(self, x, timestep, context, **kwargs): bs, c, h, w = x.shape input_device = x.device input_dtype = x.dtype @@ -473,7 +300,7 @@ class IntegratedChromaTransformer2DModel(nn.Module): img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype) del input_device, input_dtype - out = self.inner_forward(img, img_ids, context, txt_ids, timestep, guidance) + out = self.inner_forward(img, img_ids, context, txt_ids, timestep) del img, img_ids, txt_ids, timestep, context out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w] del h_len, w_len, bs diff --git a/backend/operations.py b/backend/operations.py index acd54b32..7a7e6500 100644 --- a/backend/operations.py +++ b/backend/operations.py @@ -139,7 +139,7 @@ class ForgeOperations: if prefix + 'weight' in state_dict: self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy)) if prefix + 'scale_weight' in state_dict: - self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight']) + self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight']) if prefix + 'bias' in state_dict: self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy)) del self.dummy diff --git a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py index d5716d37..d55dd3d0 100644 --- a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py +++ b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py @@ -5,6 +5,9 @@ from modules.ui_components import InputAccordion import modules.scripts as scripts from modules.torch_utils import float64 +from concurrent.futures import ThreadPoolExecutor +from scipy.ndimage import convolve +from joblib import Parallel, delayed, cpu_count class SoftInpaintingSettings: def __init__(self, @@ -244,7 +247,76 @@ def apply_masks( return masks_for_overlay -def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + + +def weighted_histogram_filter_single_pixel(idx, img, kernel, kernel_center, percentile_min, percentile_max, min_width): + """ + Apply the weighted histogram filter to a single pixel. + This function is now refactored to be accessible for parallelization. + """ + idx = np.array(idx) + kernel_min = -kernel_center + kernel_max = np.array(kernel.shape) - kernel_center + + # Precompute the minimum and maximum valid indices for the kernel + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(np.array(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + # Initialize values and weights arrays + values = [] + weights = [] + + for window_tup in np.ndindex(*window_shape): + window_index = np.array(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + values.append(img[tuple(image_index)]) + weights.append(kernel[tuple(kernel_index)]) + + # Convert to NumPy arrays + values = np.array(values) + weights = np.array(weights) + + # Sort values and weights by values + sorted_indices = np.argsort(values) + values = values[sorted_indices] + weights = weights[sorted_indices] + + # Calculate cumulative weights + cumulative_weights = np.cumsum(weights) + + # Define window boundaries + sum_weights = cumulative_weights[-1] + window_min = sum_weights * percentile_min + window_max = sum_weights * percentile_max + window_width = window_max - window_min + + # Ensure window is at least `min_width` wide + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum_weights: + window_max = sum_weights + window_min = sum_weights - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + # Calculate overlap for each value + overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1]))) + overlap_end = np.minimum(window_max, cumulative_weights) + overlap = np.maximum(0, overlap_end - overlap_start) + + # Weighted average calculation + result = np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0 + return result + +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0, n_jobs=-1): """ Generalization convolution filter capable of applying weighted mean, median, maximum, and minimum filters @@ -271,101 +343,74 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, pe (nparray): A filtered copy of the input image "img", a 2-D array of floats. """ - # Converts an index tuple into a vector. - def vec(x): - return np.array(x) - - kernel_min = -kernel_center - kernel_max = vec(kernel.shape) - kernel_center + # Ensure kernel_center is a 1D array + if isinstance(kernel_center, int): + kernel_center = np.array([kernel_center, kernel_center]) + elif len(kernel_center) == 1: + kernel_center = np.array([kernel_center[0], kernel_center[0]]) + kernel_radius = max(kernel_center) + padded_img = np.pad(img, kernel_radius, mode='constant', constant_values=0) + img_out = np.zeros_like(img) + img_shape = img.shape + pixel_coords = [(i, j) for i in range(img_shape[0]) for j in range(img_shape[1])] def weighted_histogram_filter_single(idx): - idx = vec(idx) - min_index = np.maximum(0, idx + kernel_min) - max_index = np.minimum(vec(img.shape), idx + kernel_max) - window_shape = max_index - min_index + """ + Single-pixel weighted histogram calculation. + """ + row, col = idx + idx = (row + kernel_radius, col + kernel_radius) + min_index = np.array(idx) - kernel_center + max_index = min_index + kernel.shape - class WeightedElement: - """ - An element of the histogram, its weight - and bounds. - """ + window = padded_img[min_index[0]:max_index[0], min_index[1]:max_index[1]] + window_values = window.flatten() + window_weights = kernel.flatten() - def __init__(self, value, weight): - self.value: float = value - self.weight: float = weight - self.window_min: float = 0.0 - self.window_max: float = 1.0 + sorted_indices = np.argsort(window_values) + values = window_values[sorted_indices] + weights = window_weights[sorted_indices] - # Collect the values in the image as WeightedElements, - # weighted by their corresponding kernel values. - values = [] - for window_tup in np.ndindex(tuple(window_shape)): - window_index = vec(window_tup) - image_index = window_index + min_index - centered_kernel_index = image_index - idx - kernel_index = centered_kernel_index + kernel_center - element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) - values.append(element) + cumulative_weights = np.cumsum(weights) + sum_weights = cumulative_weights[-1] + window_min = max(0, sum_weights * percentile_min) + window_max = min(sum_weights, sum_weights * percentile_max) - def sort_key(x: WeightedElement): - return x.value - - values.sort(key=sort_key) - - # Calculate the height of the stack (sum) - # and each sample's range they occupy in the stack - sum = 0 - for i in range(len(values)): - values[i].window_min = sum - sum += values[i].weight - values[i].window_max = sum - - # Calculate what range of this stack ("window") - # we want to get the weighted average across. - window_min = sum * percentile_min - window_max = sum * percentile_max window_width = window_max - window_min - - # Ensure the window is within the stack and at least a certain size. if window_width < min_width: window_center = (window_min + window_max) / 2 - window_min = window_center - min_width / 2 - window_max = window_center + min_width / 2 + window_min = max(0, window_center - min_width / 2) + window_max = min(sum_weights, window_center + min_width / 2) - if window_max > sum: - window_max = sum - window_min = sum - min_width + overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1]))) + overlap_end = np.minimum(window_max, cumulative_weights) + overlap = np.maximum(0, overlap_end - overlap_start) - if window_min < 0: - window_min = 0 - window_max = min_width + return np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0 - value = 0 - value_weight = 0 + # Split pixel_coords into equal chunks based on n_jobs + n_jobs = -1 + if cpu_count() > 6: + n_jobs = 6 # More than 6 isn't worth unless it's more than 3000x3000px - # Get the weighted average of all the samples - # that overlap with the window, weighted - # by the size of their overlap. - for i in range(len(values)): - if window_min >= values[i].window_max: - continue - if window_max <= values[i].window_min: - break + chunk_size = len(pixel_coords) // n_jobs + pixel_chunks = [pixel_coords[i:i + chunk_size] for i in range(0, len(pixel_coords), chunk_size)] - s = max(window_min, values[i].window_min) - e = min(window_max, values[i].window_max) - w = e - s + # joblib to process chunks in parallel + def process_chunk(chunk): + chunk_result = {} + for idx in chunk: + chunk_result[idx] = weighted_histogram_filter_single(idx) + return chunk_result - value += values[i].value * w - value_weight += w + results = Parallel(n_jobs=n_jobs, backend="loky")( # loky is fastest in my configuration + delayed(process_chunk)(chunk) for chunk in pixel_chunks + ) - return value / value_weight if value_weight != 0 else 0 - - img_out = img.copy() - - # Apply the kernel operation over each pixel. - for index in np.ndindex(img.shape): - img_out[index] = weighted_histogram_filter_single(index) + # Combine results into the output image + for chunk_result in results: + for (row, col), value in chunk_result.items(): + img_out[row, col] = value return img_out @@ -485,7 +530,7 @@ el_ids = SoftInpaintingSettings( class Script(scripts.Script): def __init__(self): - # self.section = "inpaint" + self.section = "inpaint" self.masks_for_overlay = None self.overlay_images = None diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index e614c23b..64e34cd9 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -98,4 +98,6 @@ class Script(scripts.Script): processed = Processed(p, result_images, seed, initial_info) + p.n_iter = upscale_count + return processed