From cc6ee589397f087c11096074bf6ea3a22ba1b0d6 Mon Sep 17 00:00:00 2001 From: maybleMyers Date: Tue, 2 Sep 2025 04:02:24 -0700 Subject: [PATCH] fix 16s --- backend/modules/k_diffusion_extra.py | 190 +++++++++++++++++++-------- 1 file changed, 135 insertions(+), 55 deletions(-) diff --git a/backend/modules/k_diffusion_extra.py b/backend/modules/k_diffusion_extra.py index 041bbd30..f162e755 100644 --- a/backend/modules/k_diffusion_extra.py +++ b/backend/modules/k_diffusion_extra.py @@ -4,6 +4,7 @@ import torch import sys import os import math +from itertools import permutations, combinations from tqdm import trange @@ -74,6 +75,25 @@ def phi(j: int, neg_h: float): phi_ = math.exp(neg_h) * (neg_h**-j) * (1 - incomp_gamma_/gamma_) return phi_ +# Additional phi functions for higher orders +def res_phi_4(h): + """Fourth phi function for RES samplers.""" + if h.abs().max() < 1e-6: + return 1/24 - h / 120 + h**2 / 720 + return (torch.exp(h) - 1 - h - h**2 / 2 - h**3 / 6) / (h**4) + +def res_phi_5(h): + """Fifth phi function for RES samplers.""" + if h.abs().max() < 1e-6: + return 1/120 - h / 720 + h**2 / 5040 + return (torch.exp(h) - 1 - h - h**2 / 2 - h**3 / 6 - h**4 / 24) / (h**5) + +def res_phi_6(h): + """Sixth phi function for RES samplers.""" + if h.abs().max() < 1e-6: + return 1/720 - h / 5040 + h**2 / 40320 + return (torch.exp(h) - 1 - h - h**2 / 2 - h**3 / 6 - h**4 / 24 - h**5 / 120) / (h**6) + class Phi: """RES4LYF Phi class - copied from working implementation""" def __init__(self, h, c, analytic_solution=False): @@ -122,6 +142,47 @@ def res_phi_3(h): return (torch.exp(h) - 1 - h - h**2 / 2) / (h**3) +# RES4LYF helper functions - EXACT copies +def theta_numerator(j, cd, ci, ck, cj, cl): + if j == 2: + numerator = -cj * cd * ck * cl + if j == 3: + numerator = 2 * (cj * ck * cd + cj*ck*cl + ck*cd*cl + cd*cl*cj) + if j == 4: + numerator = -6*(cj*ck + cj*cd + cj*cl + ck*cd + ck*cl + cd*cl) + if j == 5: + numerator = 24 * (cj + ck + cl + cd) + if j == 6: + numerator = -120 + return numerator + +def theta(j, cd, ci, ck, cj, cl): + if j == 2: + numerator = -cj * cd * ck * cl + if j == 3: + numerator = 2 * (cj * ck * cd + cj*ck*cl + ck*cd*cl + cd*cl*cj) + if j == 4: + numerator = -6*(cj*ck + cj*cd + cj*cl + ck*cd + ck*cl + cd*cl) + if j == 5: + numerator = 24 * (cj + ck + cl + cd) + if j == 6: + numerator = -120 + return numerator / (ci * (ci - cj) * (ci - ck) * (ci - cl) * (ci - cd)) + +def prod_diff(cj, ck, cl=None, cd=None): + if cl is None and cd is None: + return cj * (cj - ck) + if cd is None: + return cj * (cj - ck) * (cj - cl) + else: + return cj * (cj - ck) * (cj - cl) * (cj - cd) + +def denominator(ci, *args): + result = ci + for arg in args: + result *= (ci - arg) + return result + def get_res_6s_coefficients(h): """Get RES 6s coefficients - copied exactly from RES4LYF""" # Original c-values from RES4LYF (with division by zero issue) @@ -180,6 +241,51 @@ def get_res_6s_coefficients(h): return a, b, ci +def get_res_16s_coefficients(h): + """Get RES 16s coefficients - EXACT copy from RES4LYF""" + use_analytic_solution = False # Use same as RES4LYF default + + c1 = 0 + c2 = c3 = c5 = c8 = c12 = 1/2 + c4 = c11 = c15 = 1/3 + c6 = c9 = c13 = 1/5 + c7 = c10 = c14 = 1/4 + c16 = 1 + ci = [c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16] + φ = Phi(h, ci, analytic_solution=use_analytic_solution) + + a3_2 = (1/2) * φ(2,3) + + a = [[0.0 for _ in range(16)] for _ in range(16)] + b = [[0.0 for _ in range(16)]] + + for i in range(3, 5): # i=3,4 j=2 + j=2 + a[i-1][j-1] = (ci[i-1]**2 / ci[j-1]) * φ(j,i) + + for i in range(5, 8): # i=5,6,7 j,k ∈ {3, 4}, j != k + jk = list(permutations([3, 4], 2)) + for j,k in jk: + a[i-1][j-1] = (-ci[i-1]**2 * ci[k-1] * φ(2,i) + 2*ci[i-1]**3 * φ(3,i)) / prod_diff(ci[j-1], ci[k-1]) + + for i in range(8, 12): # i=8,9,10,11 j,k,l ∈ {5, 6, 7}, j != k != l + jkl = list(permutations([5, 6, 7], 3)) + for j,k,l in jkl: + a[i-1][j-1] = (ci[i-1]**2 * ci[k-1] * ci[l-1] * φ(2,i) - 2*ci[i-1]**3 * (ci[k-1] + ci[l-1]) * φ(3,i) + 6*ci[i-1]**4 * φ(4,i)) / (ci[j-1] * (ci[j-1] - ci[k-1]) * (ci[j-1] - ci[l-1])) + + for i in range(12,16): # i=12,13,14,15 + jkld = list(permutations([8,9,10,11], 4)) + for j,k,l,d in jkld: + numerator = -ci[i-1]**2 * ci[d-1]*ci[k-1]*ci[l-1] * φ(2,i) + 2*ci[i-1]**3 * (ci[d-1]*ci[k-1] + ci[d-1]*ci[l-1] + ci[k-1]*ci[l-1]) * φ(3,i) - 6*ci[i-1]**4 * (ci[d-1] + ci[k-1] + ci[l-1]) * φ(4,i) + 24*ci[i-1]**5 * φ(5,i) + a[i-1][j-1] = numerator / prod_diff(ci[j-1], ci[k-1], ci[l-1], ci[d-1]) + + ijdkl = list(permutations([12,13,14,15,16], 5)) + for i,j,d,k,l in ijdkl: + b[0][i-1] = theta(2, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1]) * φ(2) + theta(3, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1])*φ(3) + theta(4, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1])*φ(4) + theta(5, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1])*φ(5) + theta(6, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1]) * φ(6) + + return a, b, ci + + # RES Samplers - Runge-Kutta Exponential Samplers @torch.no_grad() def sample_res_2s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): @@ -260,75 +366,49 @@ def sample_res_6s(model, x, sigmas, extra_args=None, callback=None, disable=None @torch.no_grad() def sample_res_16s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): - """RES 16-stage sampler - high-order exponential Runge-Kutta method.""" + """RES 16-stage sampler - EXACT copy of RES4LYF math.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): sigma = sigmas[i] sigma_next = sigmas[i + 1] - h = sigma_next - sigma + h = float(sigma_next - sigma) - # Get phi functions - phi_1 = res_phi_1(h) - phi_2 = res_phi_2(h) - phi_3 = res_phi_3(h) + # Get coefficients using EXACT RES4LYF calculation + a, b, ci = get_res_16s_coefficients(h) - # Stage 1 - denoised = model(x, sigma * s_in, **extra_args) - d_1 = to_d(x, sigma, denoised) + # Convert to proper format + num_stages = len(ci) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised}) + # Stage computations - exact RK method + k = [] # Stage derivatives - # High-order multi-stage method with multiple intermediate evaluations - # Using a simplified 8-stage approach that approximates 16-stage behavior - stages = [] - c_vals = [0, 1/8, 1/4, 3/8, 1/2, 5/8, 3/4, 7/8, 1.0] - - for stage in range(1, 9): # 8 stages - c = c_vals[stage] - - if stage == 1: - # First intermediate stage - x_stage = denoised + sigma * c * phi_1 * d_1 - sigma_stage = sigma + h * c - denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args) - d_stage = to_d(x_stage, sigma_stage, denoised_stage) - stages.append(d_stage) - - elif stage == 2: - # Second stage using first intermediate - a21 = c * phi_1 - a22 = (c**2 / c_vals[1]) * phi_2 - x_stage = denoised + sigma * (a21 * d_1 + a22 * stages[0]) - sigma_stage = sigma + h * c - denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args) - d_stage = to_d(x_stage, sigma_stage, denoised_stage) - stages.append(d_stage) + for stage in range(num_stages): + if stage == 0: + # First stage at current point + denoised = model(x, sigma * s_in, **extra_args) + k_i = to_d(x, sigma, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised}) else: - # Higher stages - simplified combination - weights = [phi_1 * c / stage] # Weight for d_1 - x_stage = denoised + sigma * weights[0] * d_1 + # Intermediate stages + x_stage = x + for j in range(stage): + x_stage = x_stage + h * a[stage][j] * k[j] - # Add contributions from previous stages - for j, prev_d in enumerate(stages[:min(stage-1, 3)]): # Limit to avoid instability - weight = phi_2 * c * (0.5 ** (j + 1)) / stage - x_stage += sigma * weight * prev_d - - sigma_stage = sigma + h * c + sigma_stage = sigma + h * ci[stage] denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args) - d_stage = to_d(x_stage, sigma_stage, denoised_stage) - stages.append(d_stage) - - # Final combination with high-order weights - final_d = phi_1 * d_1 - for j, stage_d in enumerate(stages[:6]): # Use first 6 stages - weight = phi_2 * (0.6 ** j) / (j + 2) # Decreasing weights - final_d += weight * stage_d + k_i = to_d(x_stage, sigma_stage, denoised_stage) - # Final step - x = denoised + sigma_next * final_d + k.append(k_i) + + # Final integration step using RK formula: x_new = x + h * sum(b_i * k_i) + x_new = x + for j in range(num_stages): + x_new = x_new + h * b[0][j] * k[j] # Note: b is nested list for RES 16s + + x = x_new return x