mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-21 21:14:23 +08:00
fix 16s
This commit is contained in:
parent
bb0aaeacb9
commit
cc6ee58939
@ -4,6 +4,7 @@ import torch
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
from itertools import permutations, combinations
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
@ -74,6 +75,25 @@ def phi(j: int, neg_h: float):
|
||||
phi_ = math.exp(neg_h) * (neg_h**-j) * (1 - incomp_gamma_/gamma_)
|
||||
return phi_
|
||||
|
||||
# Additional phi functions for higher orders
|
||||
def res_phi_4(h):
|
||||
"""Fourth phi function for RES samplers."""
|
||||
if h.abs().max() < 1e-6:
|
||||
return 1/24 - h / 120 + h**2 / 720
|
||||
return (torch.exp(h) - 1 - h - h**2 / 2 - h**3 / 6) / (h**4)
|
||||
|
||||
def res_phi_5(h):
|
||||
"""Fifth phi function for RES samplers."""
|
||||
if h.abs().max() < 1e-6:
|
||||
return 1/120 - h / 720 + h**2 / 5040
|
||||
return (torch.exp(h) - 1 - h - h**2 / 2 - h**3 / 6 - h**4 / 24) / (h**5)
|
||||
|
||||
def res_phi_6(h):
|
||||
"""Sixth phi function for RES samplers."""
|
||||
if h.abs().max() < 1e-6:
|
||||
return 1/720 - h / 5040 + h**2 / 40320
|
||||
return (torch.exp(h) - 1 - h - h**2 / 2 - h**3 / 6 - h**4 / 24 - h**5 / 120) / (h**6)
|
||||
|
||||
class Phi:
|
||||
"""RES4LYF Phi class - copied from working implementation"""
|
||||
def __init__(self, h, c, analytic_solution=False):
|
||||
@ -122,6 +142,47 @@ def res_phi_3(h):
|
||||
return (torch.exp(h) - 1 - h - h**2 / 2) / (h**3)
|
||||
|
||||
|
||||
# RES4LYF helper functions - EXACT copies
|
||||
def theta_numerator(j, cd, ci, ck, cj, cl):
|
||||
if j == 2:
|
||||
numerator = -cj * cd * ck * cl
|
||||
if j == 3:
|
||||
numerator = 2 * (cj * ck * cd + cj*ck*cl + ck*cd*cl + cd*cl*cj)
|
||||
if j == 4:
|
||||
numerator = -6*(cj*ck + cj*cd + cj*cl + ck*cd + ck*cl + cd*cl)
|
||||
if j == 5:
|
||||
numerator = 24 * (cj + ck + cl + cd)
|
||||
if j == 6:
|
||||
numerator = -120
|
||||
return numerator
|
||||
|
||||
def theta(j, cd, ci, ck, cj, cl):
|
||||
if j == 2:
|
||||
numerator = -cj * cd * ck * cl
|
||||
if j == 3:
|
||||
numerator = 2 * (cj * ck * cd + cj*ck*cl + ck*cd*cl + cd*cl*cj)
|
||||
if j == 4:
|
||||
numerator = -6*(cj*ck + cj*cd + cj*cl + ck*cd + ck*cl + cd*cl)
|
||||
if j == 5:
|
||||
numerator = 24 * (cj + ck + cl + cd)
|
||||
if j == 6:
|
||||
numerator = -120
|
||||
return numerator / (ci * (ci - cj) * (ci - ck) * (ci - cl) * (ci - cd))
|
||||
|
||||
def prod_diff(cj, ck, cl=None, cd=None):
|
||||
if cl is None and cd is None:
|
||||
return cj * (cj - ck)
|
||||
if cd is None:
|
||||
return cj * (cj - ck) * (cj - cl)
|
||||
else:
|
||||
return cj * (cj - ck) * (cj - cl) * (cj - cd)
|
||||
|
||||
def denominator(ci, *args):
|
||||
result = ci
|
||||
for arg in args:
|
||||
result *= (ci - arg)
|
||||
return result
|
||||
|
||||
def get_res_6s_coefficients(h):
|
||||
"""Get RES 6s coefficients - copied exactly from RES4LYF"""
|
||||
# Original c-values from RES4LYF (with division by zero issue)
|
||||
@ -180,6 +241,51 @@ def get_res_6s_coefficients(h):
|
||||
return a, b, ci
|
||||
|
||||
|
||||
def get_res_16s_coefficients(h):
|
||||
"""Get RES 16s coefficients - EXACT copy from RES4LYF"""
|
||||
use_analytic_solution = False # Use same as RES4LYF default
|
||||
|
||||
c1 = 0
|
||||
c2 = c3 = c5 = c8 = c12 = 1/2
|
||||
c4 = c11 = c15 = 1/3
|
||||
c6 = c9 = c13 = 1/5
|
||||
c7 = c10 = c14 = 1/4
|
||||
c16 = 1
|
||||
ci = [c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16]
|
||||
φ = Phi(h, ci, analytic_solution=use_analytic_solution)
|
||||
|
||||
a3_2 = (1/2) * φ(2,3)
|
||||
|
||||
a = [[0.0 for _ in range(16)] for _ in range(16)]
|
||||
b = [[0.0 for _ in range(16)]]
|
||||
|
||||
for i in range(3, 5): # i=3,4 j=2
|
||||
j=2
|
||||
a[i-1][j-1] = (ci[i-1]**2 / ci[j-1]) * φ(j,i)
|
||||
|
||||
for i in range(5, 8): # i=5,6,7 j,k ∈ {3, 4}, j != k
|
||||
jk = list(permutations([3, 4], 2))
|
||||
for j,k in jk:
|
||||
a[i-1][j-1] = (-ci[i-1]**2 * ci[k-1] * φ(2,i) + 2*ci[i-1]**3 * φ(3,i)) / prod_diff(ci[j-1], ci[k-1])
|
||||
|
||||
for i in range(8, 12): # i=8,9,10,11 j,k,l ∈ {5, 6, 7}, j != k != l
|
||||
jkl = list(permutations([5, 6, 7], 3))
|
||||
for j,k,l in jkl:
|
||||
a[i-1][j-1] = (ci[i-1]**2 * ci[k-1] * ci[l-1] * φ(2,i) - 2*ci[i-1]**3 * (ci[k-1] + ci[l-1]) * φ(3,i) + 6*ci[i-1]**4 * φ(4,i)) / (ci[j-1] * (ci[j-1] - ci[k-1]) * (ci[j-1] - ci[l-1]))
|
||||
|
||||
for i in range(12,16): # i=12,13,14,15
|
||||
jkld = list(permutations([8,9,10,11], 4))
|
||||
for j,k,l,d in jkld:
|
||||
numerator = -ci[i-1]**2 * ci[d-1]*ci[k-1]*ci[l-1] * φ(2,i) + 2*ci[i-1]**3 * (ci[d-1]*ci[k-1] + ci[d-1]*ci[l-1] + ci[k-1]*ci[l-1]) * φ(3,i) - 6*ci[i-1]**4 * (ci[d-1] + ci[k-1] + ci[l-1]) * φ(4,i) + 24*ci[i-1]**5 * φ(5,i)
|
||||
a[i-1][j-1] = numerator / prod_diff(ci[j-1], ci[k-1], ci[l-1], ci[d-1])
|
||||
|
||||
ijdkl = list(permutations([12,13,14,15,16], 5))
|
||||
for i,j,d,k,l in ijdkl:
|
||||
b[0][i-1] = theta(2, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1]) * φ(2) + theta(3, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1])*φ(3) + theta(4, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1])*φ(4) + theta(5, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1])*φ(5) + theta(6, ci[d-1], ci[i-1], ci[k-1], ci[j-1], ci[l-1]) * φ(6)
|
||||
|
||||
return a, b, ci
|
||||
|
||||
|
||||
# RES Samplers - Runge-Kutta Exponential Samplers
|
||||
@torch.no_grad()
|
||||
def sample_res_2s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
@ -260,75 +366,49 @@ def sample_res_6s(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_res_16s(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
"""RES 16-stage sampler - high-order exponential Runge-Kutta method."""
|
||||
"""RES 16-stage sampler - EXACT copy of RES4LYF math."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
sigma = sigmas[i]
|
||||
sigma_next = sigmas[i + 1]
|
||||
h = sigma_next - sigma
|
||||
h = float(sigma_next - sigma)
|
||||
|
||||
# Get phi functions
|
||||
phi_1 = res_phi_1(h)
|
||||
phi_2 = res_phi_2(h)
|
||||
phi_3 = res_phi_3(h)
|
||||
# Get coefficients using EXACT RES4LYF calculation
|
||||
a, b, ci = get_res_16s_coefficients(h)
|
||||
|
||||
# Stage 1
|
||||
denoised = model(x, sigma * s_in, **extra_args)
|
||||
d_1 = to_d(x, sigma, denoised)
|
||||
# Convert to proper format
|
||||
num_stages = len(ci)
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
|
||||
# Stage computations - exact RK method
|
||||
k = [] # Stage derivatives
|
||||
|
||||
# High-order multi-stage method with multiple intermediate evaluations
|
||||
# Using a simplified 8-stage approach that approximates 16-stage behavior
|
||||
stages = []
|
||||
c_vals = [0, 1/8, 1/4, 3/8, 1/2, 5/8, 3/4, 7/8, 1.0]
|
||||
|
||||
for stage in range(1, 9): # 8 stages
|
||||
c = c_vals[stage]
|
||||
|
||||
if stage == 1:
|
||||
# First intermediate stage
|
||||
x_stage = denoised + sigma * c * phi_1 * d_1
|
||||
sigma_stage = sigma + h * c
|
||||
denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args)
|
||||
d_stage = to_d(x_stage, sigma_stage, denoised_stage)
|
||||
stages.append(d_stage)
|
||||
|
||||
elif stage == 2:
|
||||
# Second stage using first intermediate
|
||||
a21 = c * phi_1
|
||||
a22 = (c**2 / c_vals[1]) * phi_2
|
||||
x_stage = denoised + sigma * (a21 * d_1 + a22 * stages[0])
|
||||
sigma_stage = sigma + h * c
|
||||
denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args)
|
||||
d_stage = to_d(x_stage, sigma_stage, denoised_stage)
|
||||
stages.append(d_stage)
|
||||
for stage in range(num_stages):
|
||||
if stage == 0:
|
||||
# First stage at current point
|
||||
denoised = model(x, sigma * s_in, **extra_args)
|
||||
k_i = to_d(x, sigma, denoised)
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
|
||||
else:
|
||||
# Higher stages - simplified combination
|
||||
weights = [phi_1 * c / stage] # Weight for d_1
|
||||
x_stage = denoised + sigma * weights[0] * d_1
|
||||
# Intermediate stages
|
||||
x_stage = x
|
||||
for j in range(stage):
|
||||
x_stage = x_stage + h * a[stage][j] * k[j]
|
||||
|
||||
# Add contributions from previous stages
|
||||
for j, prev_d in enumerate(stages[:min(stage-1, 3)]): # Limit to avoid instability
|
||||
weight = phi_2 * c * (0.5 ** (j + 1)) / stage
|
||||
x_stage += sigma * weight * prev_d
|
||||
|
||||
sigma_stage = sigma + h * c
|
||||
sigma_stage = sigma + h * ci[stage]
|
||||
denoised_stage = model(x_stage, sigma_stage * s_in, **extra_args)
|
||||
d_stage = to_d(x_stage, sigma_stage, denoised_stage)
|
||||
stages.append(d_stage)
|
||||
|
||||
# Final combination with high-order weights
|
||||
final_d = phi_1 * d_1
|
||||
for j, stage_d in enumerate(stages[:6]): # Use first 6 stages
|
||||
weight = phi_2 * (0.6 ** j) / (j + 2) # Decreasing weights
|
||||
final_d += weight * stage_d
|
||||
k_i = to_d(x_stage, sigma_stage, denoised_stage)
|
||||
|
||||
# Final step
|
||||
x = denoised + sigma_next * final_d
|
||||
k.append(k_i)
|
||||
|
||||
# Final integration step using RK formula: x_new = x + h * sum(b_i * k_i)
|
||||
x_new = x
|
||||
for j in range(num_stages):
|
||||
x_new = x_new + h * b[0][j] * k[j] # Note: b is nested list for RES 16s
|
||||
|
||||
x = x_new
|
||||
|
||||
return x
|
||||
|
||||
Loading…
Reference in New Issue
Block a user