stable-diffusion-webui-forge/backend/diffusion_engine/chroma_dct.py
2025-09-29 07:07:59 -07:00

88 lines
3.5 KiB
Python

import torch
from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.unet import UnetPatcher
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args
from backend.modules.k_prediction import PredictionFlux
from backend import memory_management
class ChromaDCT(ForgeDiffusionEngine):
    """Forge diffusion engine for ChromaDCT checkpoints.

    The engine pairs a T5-XXL text encoder with the checkpoint's transformer
    (wrapped in a ``UnetPatcher``) and registers them in ``ForgeObjects``
    with ``vae=None``.  Per the original author's notes the DCT model works
    directly in pixel space, which is why the first-stage encode/decode
    methods below do not call into any VAE.
    """

    def __init__(self, estimated_config, huggingface_components):
        """Build the text encoder, predictor and transformer patcher.

        Args:
            estimated_config: Model configuration forwarded to the base
                engine and to ``UnetPatcher.from_model``.
            huggingface_components: Mapping that must provide the
                ``'text_encoder'``, ``'tokenizer'`` and ``'transformer'``
                entries.
        """
        super().__init__(estimated_config, huggingface_components)
        self.is_inpaint = False

        # Only a single text encoder (T5-XXL) is registered for this model.
        encoder_by_name = {'t5xxl': huggingface_components['text_encoder']}
        tokenizer_by_name = {'t5xxl': huggingface_components['tokenizer']}
        clip = CLIP(model_dict=encoder_by_name, tokenizer_dict=tokenizer_by_name)

        # No VAE is constructed: the DCT model is stated to operate directly
        # in pixel space.
        flux_predictor = PredictionFlux(mu=1.0)
        unet = UnetPatcher.from_model(
            model=huggingface_components['transformer'],
            diffusers_scheduler=None,
            k_predictor=flux_predictor,
            config=estimated_config,
        )

        # No ChromaDCT-specific memory estimation is configured here; the
        # original author removed it because it reportedly made ChromaDCT
        # about 2x slower than regular Chroma.  Optimized offloading is
        # switched on instead.
        self.use_optimized_offloading = True

        t5_encoder = clip.cond_stage_model.t5xxl
        t5_tokenizer = clip.tokenizer.t5xxl
        self.text_processing_engine_t5 = T5TextProcessingEngine(
            text_encoder=t5_encoder,
            tokenizer=t5_tokenizer,
            emphasis_name=dynamic_args['emphasis_name'],
            min_length=1,
        )

        # Keep the pristine objects plus two shallow copies: one as the
        # "original" snapshot and one as the post-LoRA snapshot.
        self.forge_objects = ForgeObjects(
            unet=unet,
            clip=clip,
            vae=None,
            clipvision=None,
        )
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

    def set_clip_skip(self, clip_skip):
        """Accept a clip-skip setting and ignore it (intentional no-op)."""

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        """Encode prompts with the T5 engine.

        The CLIP patcher is loaded onto the GPU first, then the prompts are
        run through the T5 text processing engine with the attention mask
        requested.

        Args:
            prompt: Prompt strings to encode.

        Returns:
            A dict with the embeddings under ``'crossattn'`` (the key the
            conditioning system is said to expect) and the mask under
            ``'attention_mask'``.
        """
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
        cond, mask = self.text_processing_engine_t5(
            prompt,
            return_attention_mask=True,
        )
        return dict(crossattn=cond, attention_mask=mask)

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        """Return ``(token_count, displayed_limit)`` for a single prompt.

        The displayed limit is never smaller than 255 and grows with the
        prompt once the token count exceeds that.
        """
        tokenized_batch = self.text_processing_engine_t5.tokenize([prompt])
        n_tokens = len(tokenized_batch[0])
        shown_limit = n_tokens if n_tokens > 255 else 255
        return n_tokens, shown_limit

    @torch.inference_mode()
    def encode_first_stage(self, x):
        """Return ``x`` unchanged; no latent encoding is performed.

        NOTE(review): the original comments say the input arrives from the
        UI in the [-1, 1] range, but they are inconclusive about the range
        the model was trained on -- confirm against the training code.
        """
        return x

    @torch.inference_mode()
    def decode_first_stage(self, x):
        """Return ``x`` clamped to [-1, 1]; no latent decoding is performed.

        The clamp keeps the output in the range the original comments
        describe as required for UI compatibility.
        """
        lower, upper = -1.0, 1.0
        return x.clamp(lower, upper)