stable-diffusion-webui-forge/modules_forge/alter_samplers.py
2025-09-02 02:58:17 -07:00

26 lines
1.0 KiB
Python

from modules import sd_samplers_kdiffusion, sd_samplers_common
from backend.modules import k_diffusion_extra
class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
    """KDiffusionSampler whose sampling function comes from ``k_diffusion_extra``.

    The function is looked up by name as ``k_diffusion_extra.sample_<sampler_name>``.
    It is passed as the first argument to the base-class constructor, together
    with ``sd_model`` and ``None``.
    """

    def __init__(self, sd_model, sampler_name):
        # Both attributes are set before the base-class initializer runs, as in
        # the original ordering.
        # NOTE(review): confirm whether the base initializer relies on them.
        self.sampler_name = sampler_name
        self.unet = sd_model.forge_objects.unet

        # getattr without a default: an unknown sampler name raises AttributeError.
        extra_fn_name = f"sample_{sampler_name}"
        extra_sampler_fn = getattr(k_diffusion_extra, extra_fn_name)

        super().__init__(extra_sampler_fn, sd_model, None)
def build_constructor(sampler_name):
    """Return a one-argument factory producing an ``AlterSampler`` for *sampler_name*.

    Calling the returned factory with a model ``m`` yields
    ``AlterSampler(m, sampler_name)``. It is used when registering the entries
    of ``samplers_data_alter``.
    """
    # sampler_name is a parameter of this call, so every factory captures its
    # own value (no late-binding sharing between entries).
    return lambda m: AlterSampler(m, sampler_name)
# Extra samplers, in registration order. Each spec row has three parts:
#   * display label;
#   * suffix of the k_diffusion_extra ``sample_<suffix>`` function;
#   * list of aliases.
# The trailing {} handed to SamplerData is a fresh, empty mapping per entry.
# NOTE(review): presumably it fills SamplerData's options field; confirm in
# sd_samplers_common.
samplers_data_alter = [
    sd_samplers_common.SamplerData(label, build_constructor(sampler_name=fn_suffix), alias_list, {})
    for label, fn_suffix, alias_list in (
        ('DDPM', 'ddpm', ['ddpm']),
        ('RES 2s', 'res_2s', ['res_2s', 'res2s']),
        ('RES 6s', 'res_6s', ['res_6s', 'res6s']),
        ('RES 16s', 'res_16s', ['res_16s', 'res16s']),
    )
]