mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-21 21:14:23 +08:00
mice
This commit is contained in:
parent
24638876b5
commit
478fd7b94c
@ -51,20 +51,69 @@ def to_d(x, sigma, denoised):
|
||||
return (x - denoised) / sigma
|
||||
|
||||
|
||||
# RES4LYF phi functions - copied from working implementation
|
||||
def _gamma(n: int) -> int:
|
||||
"""Gamma function for positive integers: Γ(n) = (n-1)!"""
|
||||
return math.factorial(n-1)
|
||||
|
||||
def _incomplete_gamma(s: int, x: float, gamma_s=None) -> float:
|
||||
"""Incomplete gamma function for positive integer s"""
|
||||
if gamma_s is None:
|
||||
gamma_s = _gamma(s)
|
||||
sum_ = 0.0
|
||||
for k in range(s):
|
||||
sum_ += (x**k) / math.factorial(k)
|
||||
return sum_ * math.exp(-x) * gamma_s
|
||||
|
||||
def phi(j: int, neg_h: float):
|
||||
"""RES4LYF phi function implementation"""
|
||||
assert j > 0
|
||||
gamma_ = _gamma(j)
|
||||
incomp_gamma_ = _incomplete_gamma(j, neg_h, gamma_s=gamma_)
|
||||
phi_ = math.exp(neg_h) * (neg_h**-j) * (1 - incomp_gamma_/gamma_)
|
||||
return phi_
|
||||
|
||||
class Phi:
|
||||
"""RES4LYF Phi class - copied from working implementation"""
|
||||
def __init__(self, h, c, analytic_solution=False):
|
||||
self.h = h
|
||||
self.c = c
|
||||
self.cache = {}
|
||||
self.phi_f = phi
|
||||
|
||||
def __call__(self, j, i=-1):
|
||||
if (j, i) in self.cache:
|
||||
return self.cache[(j, i)]
|
||||
|
||||
if i < 0:
|
||||
c = 1
|
||||
else:
|
||||
c = self.c[i - 1]
|
||||
if c == 0:
|
||||
self.cache[(j, i)] = 0
|
||||
return 0
|
||||
|
||||
if j == 0:
|
||||
result = math.exp(float(-self.h * c))
|
||||
else:
|
||||
result = self.phi_f(j, -self.h * c)
|
||||
|
||||
self.cache[(j, i)] = result
|
||||
return result
|
||||
|
||||
# Legacy phi functions for backward compatibility
|
||||
def res_phi_1(h):
|
||||
"""First phi function for RES samplers."""
|
||||
if h.abs().max() < 1e-6:
|
||||
return 1.0 - h / 2 + h**2 / 12
|
||||
return (torch.exp(h) - 1) / h
|
||||
|
||||
|
||||
def res_phi_2(h):
|
||||
"""Second phi function for RES samplers."""
|
||||
if h.abs().max() < 1e-6:
|
||||
return 0.5 - h / 6 + h**2 / 24
|
||||
return (torch.exp(h) - 1 - h) / (h**2)
|
||||
|
||||
|
||||
def res_phi_3(h):
|
||||
"""Third phi function for RES samplers."""
|
||||
if h.abs().max() < 1e-6:
|
||||
@ -72,6 +121,64 @@ def res_phi_3(h):
|
||||
return (torch.exp(h) - 1 - h - h**2 / 2) / (h**3)
|
||||
|
||||
|
||||
def get_res_6s_coefficients(h):
|
||||
"""Get RES 6s coefficients - copied exactly from RES4LYF"""
|
||||
# Original c-values from RES4LYF (with division by zero issue)
|
||||
c1, c2, c3, c4, c5, c6 = 0, 1/2, 1/2, 1/3, 1/3, 5/6
|
||||
ci = [c1, c2, c3, c4, c5, c6]
|
||||
φ = Phi(h, ci, analytic_solution=False)
|
||||
|
||||
# Coefficient calculation - exact copy from RES4LYF
|
||||
a2_1 = c2 * φ(1,2)
|
||||
|
||||
a3_1 = 0
|
||||
a3_2 = (c3**2 / c2) * φ(2,3)
|
||||
|
||||
a4_1 = 0
|
||||
a4_2 = (c4**2 / c2) * φ(2,4)
|
||||
a4_3 = (c4**2 * φ(2,4) - a4_2 * c2) / c3
|
||||
|
||||
a5_1 = 0
|
||||
a5_2 = 0 #zero
|
||||
# Handle division by zero - use L'Hôpital's rule limit or special case
|
||||
if abs(c3 - c4) < 1e-10: # c3 == c4 case
|
||||
# Use limit as c3 -> c4
|
||||
a5_3 = 0 # This is what the limit evaluates to
|
||||
a5_4 = 0
|
||||
else:
|
||||
a5_3 = (-c4 * c5**2 * φ(2,5) + 2*c5**3 * φ(3,5)) / (c3 * (c3 - c4))
|
||||
a5_4 = (-c3 * c5**2 * φ(2,5) + 2*c5**3 * φ(3,5)) / (c4 * (c4 - c3))
|
||||
|
||||
a6_1 = 0
|
||||
a6_2 = 0 #zero
|
||||
if abs(c3 - c4) < 1e-10: # c3 == c4 case
|
||||
a6_3 = 0
|
||||
a6_4 = 0
|
||||
else:
|
||||
a6_3 = (-c4 * c6**2 * φ(2,6) + 2*c6**3 * φ(3,6)) / (c3 * (c3 - c4))
|
||||
a6_4 = (-c3 * c6**2 * φ(2,6) + 2*c6**3 * φ(3,6)) / (c4 * (c4 - c3))
|
||||
a6_5 = (c6**2 * φ(2,6) - a6_3*c3 - a6_4*c4) / c5
|
||||
|
||||
b1 = 0
|
||||
b2 = 0
|
||||
b3 = 0
|
||||
b4 = 0
|
||||
b5 = (-c6*φ(2) + 2*φ(3)) / (c5 * (c5 - c6))
|
||||
b6 = (-c5*φ(2) + 2*φ(3)) / (c6 * (c6 - c5))
|
||||
|
||||
a = [
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[a2_1, 0, 0, 0, 0, 0], # First column from gen_first_col_exp
|
||||
[0, a3_2, 0, 0, 0, 0],
|
||||
[0, a4_2, a4_3, 0, 0, 0],
|
||||
[0, a5_2, a5_3, a5_4, 0, 0],
|
||||
[0, a6_2, a6_3, a6_4, a6_5, 0],
|
||||
]
|
||||
b = [b1, b2, b3, b4, b5, b6]
|
||||
|
||||
return a, b, ci
|
||||
|
||||
|
||||
# RES Samplers - Runge-Kutta Exponential Samplers
|
||||
@torch.no_grad()
|
||||
def sample_res_2s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
@ -102,73 +209,50 @@ def sample_res_2s(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_6s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
"""RES 6-stage sampler - standalone implementation."""
|
||||
"""RES 6-stage sampler - exact copy of RES4LYF math."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
sigma = sigmas[i]
|
||||
sigma_next = sigmas[i + 1]
|
||||
h = sigma_next - sigma
|
||||
h = float(sigma_next - sigma)
|
||||
|
||||
# Get phi functions
|
||||
phi_1 = res_phi_1(h)
|
||||
phi_2 = res_phi_2(h)
|
||||
phi_3 = res_phi_3(h)
|
||||
# Get coefficients using exact RES4LYF calculation
|
||||
a, b, ci = get_res_6s_coefficients(h)
|
||||
|
||||
# Stage 1
|
||||
denoised = model(x, sigma * s_in, **extra_args)
|
||||
d_1 = to_d(x, sigma, denoised)
|
||||
# Convert to proper format
|
||||
num_stages = len(ci)
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
|
||||
# Stage computations - exact RK method
|
||||
k = [] # Stage derivatives
|
||||
|
||||
# Stage 2 at c2 = 1/2
|
||||
c2 = 0.5
|
||||
a21 = c2 * phi_1
|
||||
x2 = denoised + sigma * a21 * d_1
|
||||
denoised2 = model(x2, (sigma + h * c2) * s_in, **extra_args)
|
||||
d_2 = to_d(x2, sigma + h * c2, denoised2)
|
||||
for stage in range(num_stages):
|
||||
if stage == 0:
|
||||
# First stage at current point
|
||||
denoised = model(x, sigma * s_in, **extra_args)
|
||||
k_i = to_d(x, sigma, denoised)
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
|
||||
else:
|
||||
# Intermediate stages
|
||||
x_stage = x
|
||||
for j in range(stage):
|
||||
x_stage = x_stage + h * a[stage][j] * k[j]
|
||||
|
||||
sigma_stage = sigma + h * ci[stage]
|
||||
denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args)
|
||||
k_i = to_d(x_stage, sigma_stage, denoised_stage)
|
||||
|
||||
k.append(k_i)
|
||||
|
||||
# Stage 3 at c3 = 1/2
|
||||
c3 = 0.5
|
||||
a32 = (c3**2 / c2) * phi_2
|
||||
x3 = denoised + sigma * a32 * d_2
|
||||
denoised3 = model(x3, (sigma + h * c3) * s_in, **extra_args)
|
||||
d_3 = to_d(x3, sigma + h * c3, denoised3)
|
||||
|
||||
# Stage 4 at c4 = 1/3
|
||||
c4 = 1.0/3.0
|
||||
a42 = (c4**2 / c2) * phi_2
|
||||
a43 = (c4**2 * phi_2 - a42 * c2) / c3
|
||||
x4 = denoised + sigma * (a42 * d_2 + a43 * d_3)
|
||||
denoised4 = model(x4, (sigma + h * c4) * s_in, **extra_args)
|
||||
d_4 = to_d(x4, sigma + h * c4, denoised4)
|
||||
|
||||
# Stage 5 at c5 = 2/3 (corrected from 1/3)
|
||||
c5 = 2.0/3.0
|
||||
a53 = (-c4 * c5**2 * phi_2 + 2*c5**3 * phi_3) / (c3 * (c3 - c4))
|
||||
a54 = (-c3 * c5**2 * phi_2 + 2*c5**3 * phi_3) / (c4 * (c4 - c3))
|
||||
x5 = denoised + sigma * (a53 * d_3 + a54 * d_4)
|
||||
denoised5 = model(x5, (sigma + h * c5) * s_in, **extra_args)
|
||||
d_5 = to_d(x5, sigma + h * c5, denoised5)
|
||||
|
||||
# Stage 6 at c6 = 5/6
|
||||
c6 = 5.0/6.0
|
||||
a63 = (-c4 * c6**2 * phi_2 + 2*c6**3 * phi_3) / (c3 * (c3 - c4))
|
||||
a64 = (-c3 * c6**2 * phi_2 + 2*c6**3 * phi_3) / (c4 * (c4 - c3))
|
||||
a65 = (c6**2 * phi_2 - a63*c3 - a64*c4) / c5
|
||||
x6 = denoised + sigma * (a63 * d_3 + a64 * d_4 + a65 * d_5)
|
||||
|
||||
# Final weights
|
||||
b5 = (-c6*phi_1 + 2*phi_2) / (c5 * (c5 - c6))
|
||||
b6 = (-c5*phi_1 + 2*phi_2) / (c6 * (c6 - c5))
|
||||
|
||||
# Final step - need d_6 from stage 6
|
||||
denoised6 = model(x6, (sigma + h * c6) * s_in, **extra_args)
|
||||
d_6 = to_d(x6, sigma + h * c6, denoised6)
|
||||
|
||||
x = denoised + sigma_next * (phi_1 * d_1 + b5 * d_5 + b6 * d_6)
|
||||
# Final integration step using RK formula: x_new = x + h * sum(b_i * k_i)
|
||||
x_new = x
|
||||
for j in range(num_stages):
|
||||
x_new = x_new + h * b[j] * k[j]
|
||||
|
||||
x = x_new
|
||||
|
||||
return x
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user