From e1a95ce72652800678bdb9b83243eb2b9f9071ea Mon Sep 17 00:00:00 2001 From: maybleMyers Date: Mon, 29 Sep 2025 07:07:59 -0700 Subject: [PATCH] add diffusers style t5 mask --- backend/diffusion_engine/chroma.py | 6 ++- backend/diffusion_engine/chroma_dct.py | 6 ++- backend/modules/k_model.py | 17 +++++++- backend/nn/t5.py | 11 ++++- backend/text_processing/t5_engine.py | 59 ++++++++++++++++++++------ 5 files changed, 83 insertions(+), 16 deletions(-) diff --git a/backend/diffusion_engine/chroma.py b/backend/diffusion_engine/chroma.py index a10256d7..25f98af1 100644 --- a/backend/diffusion_engine/chroma.py +++ b/backend/diffusion_engine/chroma.py @@ -52,7 +52,11 @@ class Chroma(ForgeDiffusionEngine): @torch.inference_mode() def get_learned_conditioning(self, prompt: list[str]): memory_management.load_model_gpu(self.forge_objects.clip.patcher) - return self.text_processing_engine_t5(prompt) + # Get embeddings with attention mask + embeddings, attention_mask = self.text_processing_engine_t5(prompt, return_attention_mask=True) + # Store attention mask in a dict along with embeddings + # Use 'crossattn' key to match the conditioning system's expectations + return {'crossattn': embeddings, 'attention_mask': attention_mask} @torch.inference_mode() def get_prompt_lengths_on_ui(self, prompt): diff --git a/backend/diffusion_engine/chroma_dct.py b/backend/diffusion_engine/chroma_dct.py index a782eff7..dd7058e3 100644 --- a/backend/diffusion_engine/chroma_dct.py +++ b/backend/diffusion_engine/chroma_dct.py @@ -60,7 +60,11 @@ class ChromaDCT(ForgeDiffusionEngine): @torch.inference_mode() def get_learned_conditioning(self, prompt: list[str]): memory_management.load_model_gpu(self.forge_objects.clip.patcher) - return self.text_processing_engine_t5(prompt) + # Get embeddings with attention mask + embeddings, attention_mask = self.text_processing_engine_t5(prompt, return_attention_mask=True) + # Store attention mask in a dict along with embeddings + # Use 'crossattn' key to match the conditioning system's expectations + return {'crossattn': embeddings, 'attention_mask': attention_mask} @torch.inference_mode() def get_prompt_lengths_on_ui(self, prompt): diff --git a/backend/modules/k_model.py b/backend/modules/k_model.py index 7ed1c5a0..62e0269b 100644 --- a/backend/modules/k_model.py +++ b/backend/modules/k_model.py @@ -28,12 +28,27 @@ class KModel(torch.nn.Module): if c_concat is not None: xc = torch.cat([xc] + [c_concat], dim=1) - context = c_crossattn + # Handle context which may now be a dict with crossattn and attention_mask + if isinstance(c_crossattn, dict): + context = c_crossattn['crossattn'] + attention_mask = c_crossattn.get('attention_mask', None) + else: + # Backward compatibility: if context is just a tensor + context = c_crossattn + attention_mask = None + dtype = self.computation_dtype xc = xc.to(dtype) t = self.predictor.timestep(t).float() context = context.to(dtype) + + # Keep attention mask as boolean if present + if attention_mask is not None: + # Store attention mask in transformer options for potential future use + transformer_options = transformer_options.copy() if transformer_options else {} + transformer_options['attention_mask'] = attention_mask + extra_conds = {} for o in kwargs: extra = kwargs[o] diff --git a/backend/nn/t5.py b/backend/nn/t5.py index 74e0ab70..1adeff0e 100644 --- a/backend/nn/t5.py +++ b/backend/nn/t5.py @@ -179,7 +179,16 @@ class T5Stack(torch.nn.Module): mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + # CRITICAL: Keep attention mask as boolean to avoid dtype bug + # The bug occurs when masks are converted to float16/bfloat16 + # PyTorch's scaled_dot_product_attention expects boolean masks + # First ensure mask is boolean + if attention_mask.dtype != torch.bool: + attention_mask = attention_mask.to(torch.bool) + + # Create attention mask in the format expected by attention functions + # This creates a mask where True values are attended to + mask = 1.0 - attention_mask.float().reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) past_bias = None diff --git a/backend/text_processing/t5_engine.py b/backend/text_processing/t5_engine.py index 0d5b115d..78204cc1 100644 --- a/backend/text_processing/t5_engine.py +++ b/backend/text_processing/t5_engine.py @@ -55,14 +55,22 @@ class T5TextProcessingEngine: tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] return tokenized - def encode_with_transformers(self, tokens): + def encode_with_transformers(self, tokens, attention_mask=None): device = memory_management.text_encoder_device() tokens = tokens.to(device) self.text_encoder.shared.to(device=device, dtype=torch.float32) - z = self.text_encoder( - input_ids=tokens, - ) + if attention_mask is not None: + # CRITICAL: Keep attention mask as boolean to avoid dtype conversion bug + attention_mask = attention_mask.to(device=device, dtype=torch.bool) + z = self.text_encoder( + input_ids=tokens, + attention_mask=attention_mask, + ) + else: + z = self.text_encoder( + input_ids=tokens, + ) return z @@ -113,18 +121,20 @@ class T5TextProcessingEngine: return chunks, token_count - def __call__(self, texts): + def __call__(self, texts, return_attention_mask=False): zs = [] + attention_masks = [] cache = {} self.emphasis = emphasis.get_current_option(opts.emphasis)() for line in texts: if line in cache: - line_z_values = cache[line] + line_z_values, line_attention_masks = cache[line] else: chunks, token_count = self.tokenize_line(line) line_z_values = [] + line_attention_masks = [] # pad all chunks to length of longest chunk max_tokens = 0 @@ -134,24 +144,49 @@ class T5TextProcessingEngine: for chunk in chunks: tokens = chunk.tokens multipliers = chunk.multipliers - + + # Track actual token count before padding + actual_token_count = len(tokens) + remaining_count = max_tokens - len(tokens) if remaining_count > 0: tokens += [self.id_pad] * remaining_count multipliers += [1.0] * remaining_count - z = self.process_tokens([tokens], [multipliers])[0] + # Create attention mask - boolean tensor + # Chroma requires attention mask to include one padding token + # So we extend the mask by 1 if there's padding + attention_mask = torch.zeros(len(tokens), dtype=torch.bool) + if remaining_count > 0: + # Include actual tokens plus one padding token + attention_mask[:actual_token_count + 1] = True + else: + # All tokens are real (no padding) + attention_mask[:] = True + + z = self.process_tokens([tokens], [multipliers], [attention_mask])[0] line_z_values.append(z) - cache[line] = line_z_values + line_attention_masks.append(attention_mask) + + cache[line] = (line_z_values, line_attention_masks) zs.extend(line_z_values) + attention_masks.extend(line_attention_masks) - return torch.stack(zs) + if return_attention_mask: + return torch.stack(zs), torch.stack(attention_masks) + else: + return torch.stack(zs) - def process_tokens(self, batch_tokens, batch_multipliers): + def process_tokens(self, batch_tokens, batch_multipliers, batch_attention_masks=None): tokens = torch.asarray(batch_tokens) - z = self.encode_with_transformers(tokens) + if batch_attention_masks is not None: + attention_mask = torch.stack(batch_attention_masks) if len(batch_attention_masks) > 1 else batch_attention_masks[0].unsqueeze(0) + else: + attention_mask = None + + z = self.encode_with_transformers(tokens, attention_mask) self.emphasis.tokens = batch_tokens self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)