From 478fd7b94cedbab3a347d9130b342b467dcb73b7 Mon Sep 17 00:00:00 2001 From: maybleMyers Date: Tue, 2 Sep 2025 03:22:09 -0700 Subject: [PATCH] mice --- backend/modules/k_diffusion_extra.py | 200 +++++++++++++++++++-------- 1 file changed, 142 insertions(+), 58 deletions(-) diff --git a/backend/modules/k_diffusion_extra.py b/backend/modules/k_diffusion_extra.py index 928c9a44..ae4e135d 100644 --- a/backend/modules/k_diffusion_extra.py +++ b/backend/modules/k_diffusion_extra.py @@ -51,20 +51,69 @@ def to_d(x, sigma, denoised): return (x - denoised) / sigma +# RES4LYF phi functions - copied from working implementation +def _gamma(n: int) -> int: + """Gamma function for positive integers: Γ(n) = (n-1)!""" + return math.factorial(n-1) + +def _incomplete_gamma(s: int, x: float, gamma_s=None) -> float: + """Incomplete gamma function for positive integer s""" + if gamma_s is None: + gamma_s = _gamma(s) + sum_ = 0.0 + for k in range(s): + sum_ += (x**k) / math.factorial(k) + return sum_ * math.exp(-x) * gamma_s + +def phi(j: int, neg_h: float): + """RES4LYF phi function implementation""" + assert j > 0 + gamma_ = _gamma(j) + incomp_gamma_ = _incomplete_gamma(j, neg_h, gamma_s=gamma_) + phi_ = math.exp(neg_h) * (neg_h**-j) * (1 - incomp_gamma_/gamma_) + return phi_ + +class Phi: + """RES4LYF Phi class - copied from working implementation""" + def __init__(self, h, c, analytic_solution=False): + self.h = h + self.c = c + self.cache = {} + self.phi_f = phi + + def __call__(self, j, i=-1): + if (j, i) in self.cache: + return self.cache[(j, i)] + + if i < 0: + c = 1 + else: + c = self.c[i - 1] + if c == 0: + self.cache[(j, i)] = 0 + return 0 + + if j == 0: + result = math.exp(float(-self.h * c)) + else: + result = self.phi_f(j, -self.h * c) + + self.cache[(j, i)] = result + return result + +# Legacy phi functions for backward compatibility def res_phi_1(h): """First phi function for RES samplers.""" if h.abs().max() < 1e-6: return 1.0 - h / 2 + h**2 / 12 return (torch.exp(h) - 1) / h - def res_phi_2(h): """Second phi function for RES samplers.""" if h.abs().max() < 1e-6: return 0.5 - h / 6 + h**2 / 24 return (torch.exp(h) - 1 - h) / (h**2) - def res_phi_3(h): """Third phi function for RES samplers.""" if h.abs().max() < 1e-6: @@ -72,6 +121,64 @@ def res_phi_3(h): return (torch.exp(h) - 1 - h - h**2 / 2) / (h**3) +def get_res_6s_coefficients(h): + """Get RES 6s coefficients - copied exactly from RES4LYF""" + # Original c-values from RES4LYF (with division by zero issue) + c1, c2, c3, c4, c5, c6 = 0, 1/2, 1/2, 1/3, 1/3, 5/6 + ci = [c1, c2, c3, c4, c5, c6] + φ = Phi(h, ci, analytic_solution=False) + + # Coefficient calculation - exact copy from RES4LYF + a2_1 = c2 * φ(1,2) + + a3_1 = 0 + a3_2 = (c3**2 / c2) * φ(2,3) + + a4_1 = 0 + a4_2 = (c4**2 / c2) * φ(2,4) + a4_3 = (c4**2 * φ(2,4) - a4_2 * c2) / c3 + + a5_1 = 0 + a5_2 = 0 #zero + # Handle division by zero - use L'Hôpital's rule limit or special case + if abs(c3 - c4) < 1e-10: # c3 == c4 case + # Use limit as c3 -> c4 + a5_3 = 0 # This is what the limit evaluates to + a5_4 = 0 + else: + a5_3 = (-c4 * c5**2 * φ(2,5) + 2*c5**3 * φ(3,5)) / (c3 * (c3 - c4)) + a5_4 = (-c3 * c5**2 * φ(2,5) + 2*c5**3 * φ(3,5)) / (c4 * (c4 - c3)) + + a6_1 = 0 + a6_2 = 0 #zero + if abs(c3 - c4) < 1e-10: # c3 == c4 case + a6_3 = 0 + a6_4 = 0 + else: + a6_3 = (-c4 * c6**2 * φ(2,6) + 2*c6**3 * φ(3,6)) / (c3 * (c3 - c4)) + a6_4 = (-c3 * c6**2 * φ(2,6) + 2*c6**3 * φ(3,6)) / (c4 * (c4 - c3)) + a6_5 = (c6**2 * φ(2,6) - a6_3*c3 - a6_4*c4) / c5 + + b1 = 0 + b2 = 0 + b3 = 0 + b4 = 0 + b5 = (-c6*φ(2) + 2*φ(3)) / (c5 * (c5 - c6)) + b6 = (-c5*φ(2) + 2*φ(3)) / (c6 * (c6 - c5)) + + a = [ + [0, 0, 0, 0, 0, 0], + [a2_1, 0, 0, 0, 0, 0], # First column from gen_first_col_exp + [0, a3_2, 0, 0, 0, 0], + [0, a4_2, a4_3, 0, 0, 0], + [0, a5_2, a5_3, a5_4, 0, 0], + [0, a6_2, a6_3, a6_4, a6_5, 0], + ] + b = [b1, b2, b3, b4, b5, b6] + + return a, b, ci + + # RES Samplers - Runge-Kutta Exponential Samplers @torch.no_grad() def sample_res_2s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): @@ -102,73 +209,50 @@ def sample_res_2s(model, x, sigmas, extra_args=None, callback=None, disable=None @torch.no_grad() def sample_res_6s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): - """RES 6-stage sampler - standalone implementation.""" + """RES 6-stage sampler - exact copy of RES4LYF math.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): sigma = sigmas[i] sigma_next = sigmas[i + 1] - h = sigma_next - sigma + h = float(sigma_next - sigma) - # Get phi functions - phi_1 = res_phi_1(h) - phi_2 = res_phi_2(h) - phi_3 = res_phi_3(h) + # Get coefficients using exact RES4LYF calculation + a, b, ci = get_res_6s_coefficients(h) - # Stage 1 - denoised = model(x, sigma * s_in, **extra_args) - d_1 = to_d(x, sigma, denoised) + # Convert to proper format + num_stages = len(ci) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised}) + # Stage computations - exact RK method + k = [] # Stage derivatives - # Stage 2 at c2 = 1/2 - c2 = 0.5 - a21 = c2 * phi_1 - x2 = denoised + sigma * a21 * d_1 - denoised2 = model(x2, (sigma + h * c2) * s_in, **extra_args) - d_2 = to_d(x2, sigma + h * c2, denoised2) + for stage in range(num_stages): + if stage == 0: + # First stage at current point + denoised = model(x, sigma * s_in, **extra_args) + k_i = to_d(x, sigma, denoised) + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised}) + else: + # Intermediate stages + x_stage = x + for j in range(stage): + x_stage = x_stage + h * a[stage][j] * k[j] + + sigma_stage = sigma + h * ci[stage] + denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args) + k_i = to_d(x_stage, sigma_stage, denoised_stage) + + k.append(k_i) - # Stage 3 at c3 = 1/2 - c3 = 0.5 - a32 = (c3**2 / c2) * phi_2 - x3 = denoised + sigma * a32 * d_2 - denoised3 = model(x3, (sigma + h * c3) * s_in, **extra_args) - d_3 = to_d(x3, sigma + h * c3, denoised3) - - # Stage 4 at c4 = 1/3 - c4 = 1.0/3.0 - a42 = (c4**2 / c2) * phi_2 - a43 = (c4**2 * phi_2 - a42 * c2) / c3 - x4 = denoised + sigma * (a42 * d_2 + a43 * d_3) - denoised4 = model(x4, (sigma + h * c4) * s_in, **extra_args) - d_4 = to_d(x4, sigma + h * c4, denoised4) - - # Stage 5 at c5 = 2/3 (corrected from 1/3) - c5 = 2.0/3.0 - a53 = (-c4 * c5**2 * phi_2 + 2*c5**3 * phi_3) / (c3 * (c3 - c4)) - a54 = (-c3 * c5**2 * phi_2 + 2*c5**3 * phi_3) / (c4 * (c4 - c3)) - x5 = denoised + sigma * (a53 * d_3 + a54 * d_4) - denoised5 = model(x5, (sigma + h * c5) * s_in, **extra_args) - d_5 = to_d(x5, sigma + h * c5, denoised5) - - # Stage 6 at c6 = 5/6 - c6 = 5.0/6.0 - a63 = (-c4 * c6**2 * phi_2 + 2*c6**3 * phi_3) / (c3 * (c3 - c4)) - a64 = (-c3 * c6**2 * phi_2 + 2*c6**3 * phi_3) / (c4 * (c4 - c3)) - a65 = (c6**2 * phi_2 - a63*c3 - a64*c4) / c5 - x6 = denoised + sigma * (a63 * d_3 + a64 * d_4 + a65 * d_5) - - # Final weights - b5 = (-c6*phi_1 + 2*phi_2) / (c5 * (c5 - c6)) - b6 = (-c5*phi_1 + 2*phi_2) / (c6 * (c6 - c5)) - - # Final step - need d_6 from stage 6 - denoised6 = model(x6, (sigma + h * c6) * s_in, **extra_args) - d_6 = to_d(x6, sigma + h * c6, denoised6) - - x = denoised + sigma_next * (phi_1 * d_1 + b5 * d_5 + b6 * d_6) + # Final integration step using RK formula: x_new = x + h * sum(b_i * k_i) + x_new = x + for j in range(num_stages): + x_new = x_new + h * b[j] * k[j] + + x = x_new return x