This commit is contained in:
maybleMyers 2025-09-02 04:02:24 -07:00
parent bb0aaeacb9
commit cc6ee58939

View File

@ -4,6 +4,7 @@ import torch
import sys
import os
import math
from itertools import permutations, combinations
from tqdm import trange
@ -74,6 +75,25 @@ def phi(j: int, neg_h: float):
phi_ = math.exp(neg_h) * (neg_h**-j) * (1 - incomp_gamma_/gamma_)
return phi_
# Additional phi functions for higher orders
def res_phi_4(h):
"""Fourth phi function for RES samplers."""
if h.abs().max() < 1e-6:
return 1/24 - h / 120 + h**2 / 720
return (torch.exp(h) - 1 - h - h**2 / 2 - h**3 / 6) / (h**4)
def res_phi_5(h):
"""Fifth phi function for RES samplers."""
if h.abs().max() < 1e-6:
return 1/120 - h / 720 + h**2 / 5040
return (torch.exp(h) - 1 - h - h**2 / 2 - h**3 / 6 - h**4 / 24) / (h**5)
def res_phi_6(h):
"""Sixth phi function for RES samplers."""
if h.abs().max() < 1e-6:
return 1/720 - h / 5040 + h**2 / 40320
return (torch.exp(h) - 1 - h - h**2 / 2 - h**3 / 6 - h**4 / 24 - h**5 / 120) / (h**6)
class Phi:
"""RES4LYF Phi class - copied from working implementation"""
def __init__(self, h, c, analytic_solution=False):
@ -122,6 +142,47 @@ def res_phi_3(h):
return (torch.exp(h) - 1 - h - h**2 / 2) / (h**3)
# RES4LYF helper functions - EXACT copies
def theta_numerator(j, cd, ci, ck, cj, cl):
if j == 2:
numerator = -cj * cd * ck * cl
if j == 3:
numerator = 2 * (cj * ck * cd + cj*ck*cl + ck*cd*cl + cd*cl*cj)
if j == 4:
numerator = -6*(cj*ck + cj*cd + cj*cl + ck*cd + ck*cl + cd*cl)
if j == 5:
numerator = 24 * (cj + ck + cl + cd)
if j == 6:
numerator = -120
return numerator
def theta(j, cd, ci, ck, cj, cl):
if j == 2:
numerator = -cj * cd * ck * cl
if j == 3:
numerator = 2 * (cj * ck * cd + cj*ck*cl + ck*cd*cl + cd*cl*cj)
if j == 4:
numerator = -6*(cj*ck + cj*cd + cj*cl + ck*cd + ck*cl + cd*cl)
if j == 5:
numerator = 24 * (cj + ck + cl + cd)
if j == 6:
numerator = -120
return numerator / (ci * (ci - cj) * (ci - ck) * (ci - cl) * (ci - cd))
def prod_diff(cj, ck, cl=None, cd=None):
if cl is None and cd is None:
return cj * (cj - ck)
if cd is None:
return cj * (cj - ck) * (cj - cl)
else:
return cj * (cj - ck) * (cj - cl) * (cj - cd)
def denominator(ci, *args):
result = ci
for arg in args:
result *= (ci - arg)
return result
def get_res_6s_coefficients(h):
"""Get RES 6s coefficients - copied exactly from RES4LYF"""
# Original c-values from RES4LYF (with division by zero issue)
@ -180,6 +241,51 @@ def get_res_6s_coefficients(h):
return a, b, ci
def get_res_16s_coefficients(h):
"""Get RES 16s coefficients - EXACT copy from RES4LYF"""
use_analytic_solution = False # Use same as RES4LYF default
c1 = 0
c2 = c3 = c5 = c8 = c12 = 1/2
c4 = c11 = c15 = 1/3
c6 = c9 = c13 = 1/5
c7 = c10 = c14 = 1/4
c16 = 1
ci = [c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16]
φ = Phi(h, ci, analytic_solution=use_analytic_solution)
a3_2 = (1/2) * φ(2,3)
a = [[0.0 for _ in range(16)] for _ in range(16)]
b = [[0.0 for _ in range(16)]]
for i in range(3, 5): # i=3,4 j=2
j=2
a[i-1][j-1] = (ci[i-1]**2 / ci[j-1]) * φ(j,i)
for i in range(5, 8): # i=5,6,7 j,k ∈ {3, 4}, j != k
jk = list(permutations([3, 4], 2))
for j,k in jk:
a[i-1][j-1] = (-ci[i-1]**2 * ci[k-1] * φ(2,i) + 2*ci[i-1]**3 * φ(3,i)) / prod_diff(ci[j-1], ci[k-1])
for i in range(8, 12): # i=8,9,10,11 j,k,l ∈ {5, 6, 7}, j != k != l
jkl = list(permutations([5, 6, 7], 3))
for j,k,l in jkl:
a[i-1][j-1] = (ci[i-1]**2 * ci[k-1] * ci[l-1] * φ(2,i) - 2*ci[i-1]**3 * (ci[k-1] + ci[l-1]) * φ(3,i) + 6*ci[i-1]**4 * φ(4,i)) / (ci[j-1] * (ci[j-1] - ci[k-1]) * (ci[j-1] - ci[l-1]))
for i in range(12,16): # i=12,13,14,15
jkld = list(permutations([8,9,10,11], 4))
for j,k,l,d in jkld:
numerator = -ci[i-1]**2 * ci[d-1]*ci[k-1]*ci[l-1] * φ(2,i) + 2*ci[i-1]**3 * (ci[d-1]*ci[k-1] + ci[d-1]*ci[l-1] + ci[k-1]*ci[l-1]) * φ(3,i) - 6*ci[i-1]**4 * (ci[d-1] + ci[k-1] + ci[l-1]) * φ(4,i) + 24*ci[i-1]**5 * φ(5,i)
a[i-1][j-1] = numerator / prod_diff(ci[j-1], ci[k-1], ci[l-1], ci[d-1])
ijdkl = list(permutations([12,13,14,15,16], 5))
for i,j,d,k,l in ijdkl:
b[0][i-1] = theta(2, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1]) * φ(2) + theta(3, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1])*φ(3) + theta(4, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1])*φ(4) + theta(5, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1])*φ(5) + theta(6, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1]) * φ(6)
return a, b, ci
# RES Samplers - Runge-Kutta Exponential Samplers
@torch.no_grad()
def sample_res_2s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
@ -260,75 +366,49 @@ def sample_res_6s(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_res_16s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
"""RES 16-stage sampler - high-order exponential Runge-Kutta method."""
"""RES 16-stage sampler - EXACT copy of RES4LYF math."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
sigma = sigmas[i]
sigma_next = sigmas[i + 1]
h = sigma_next - sigma
h = float(sigma_next - sigma)
# Get phi functions
phi_1 = res_phi_1(h)
phi_2 = res_phi_2(h)
phi_3 = res_phi_3(h)
# Get coefficients using EXACT RES4LYF calculation
a, b, ci = get_res_16s_coefficients(h)
# Stage 1
denoised = model(x, sigma * s_in, **extra_args)
d_1 = to_d(x, sigma, denoised)
# Convert to proper format
num_stages = len(ci)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
# Stage computations - exact RK method
k = [] # Stage derivatives
# High-order multi-stage method with multiple intermediate evaluations
# Using a simplified 8-stage approach that approximates 16-stage behavior
stages = []
c_vals = [0, 1/8, 1/4, 3/8, 1/2, 5/8, 3/4, 7/8, 1.0]
for stage in range(1, 9): # 8 stages
c = c_vals[stage]
if stage == 1:
# First intermediate stage
x_stage = denoised + sigma * c * phi_1 * d_1
sigma_stage = sigma + h * c
denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args)
d_stage = to_d(x_stage, sigma_stage, denoised_stage)
stages.append(d_stage)
elif stage == 2:
# Second stage using first intermediate
a21 = c * phi_1
a22 = (c**2 / c_vals[1]) * phi_2
x_stage = denoised + sigma * (a21 * d_1 + a22 * stages[0])
sigma_stage = sigma + h * c
denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args)
d_stage = to_d(x_stage, sigma_stage, denoised_stage)
stages.append(d_stage)
for stage in range(num_stages):
if stage == 0:
# First stage at current point
denoised = model(x, sigma * s_in, **extra_args)
k_i = to_d(x, sigma, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
else:
# Higher stages - simplified combination
weights = [phi_1 * c / stage] # Weight for d_1
x_stage = denoised + sigma * weights[0] * d_1
# Intermediate stages
x_stage = x
for j in range(stage):
x_stage = x_stage + h * a[stage][j] * k[j]
# Add contributions from previous stages
for j, prev_d in enumerate(stages[:min(stage-1, 3)]): # Limit to avoid instability
weight = phi_2 * c * (0.5 ** (j + 1)) / stage
x_stage += sigma * weight * prev_d
sigma_stage = sigma + h * c
sigma_stage = sigma + h * ci[stage]
denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args)
d_stage = to_d(x_stage, sigma_stage, denoised_stage)
stages.append(d_stage)
# Final combination with high-order weights
final_d = phi_1 * d_1
for j, stage_d in enumerate(stages[:6]): # Use first 6 stages
weight = phi_2 * (0.6 ** j) / (j + 2) # Decreasing weights
final_d += weight * stage_d
k_i = to_d(x_stage, sigma_stage, denoised_stage)
# Final step
x = denoised + sigma_next * final_d
k.append(k_i)
# Final integration step using RK formula: x_new = x + h * sum(b_i * k_i)
x_new = x
for j in range(num_stages):
x_new = x_new + h * b[0][j] * k[j] # Note: b is nested list for RES 16s
x = x_new
return x