Merge pull request #17 from maybleMyers/diffusers_mask

Add diffusers-style T5 attention mask to help with details
This commit is contained in:
benjimon 2025-09-29 07:12:23 -07:00 committed by GitHub
commit 6ae9f548ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 83 additions and 16 deletions

View File

@ -52,7 +52,11 @@ class Chroma(ForgeDiffusionEngine):
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
    """Encode prompts with the T5 engine and return embeddings plus mask.

    The text encoder is loaded onto the GPU first, then the T5 text
    processing engine is asked for both the embeddings and the matching
    attention mask.

    Args:
        prompt: Prompt strings to encode.

    Returns:
        A dict with two entries:
        ``'crossattn'`` - the embeddings returned by the T5 engine;
        ``'attention_mask'`` - the attention mask returned alongside them.
    """
    memory_management.load_model_gpu(self.forge_objects.clip.patcher)

    # Request the mask together with the embeddings so that both come from
    # the same tokenization/padding pass.
    embeddings, attention_mask = self.text_processing_engine_t5(
        prompt, return_attention_mask=True
    )

    # 'crossattn' is the key the conditioning system expects for the
    # embeddings; the mask travels next to it under its own key.
    return {'crossattn': embeddings, 'attention_mask': attention_mask}
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):

View File

@ -60,7 +60,11 @@ class ChromaDCT(ForgeDiffusionEngine):
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
    """Encode prompts with the T5 engine and return embeddings plus mask.

    The text encoder is loaded onto the GPU first, then the T5 text
    processing engine is asked for both the embeddings and the matching
    attention mask.

    Args:
        prompt: Prompt strings to encode.

    Returns:
        A dict with two entries:
        ``'crossattn'`` - the embeddings returned by the T5 engine;
        ``'attention_mask'`` - the attention mask returned alongside them.
    """
    memory_management.load_model_gpu(self.forge_objects.clip.patcher)

    # Request the mask together with the embeddings so that both come from
    # the same tokenization/padding pass.
    embeddings, attention_mask = self.text_processing_engine_t5(
        prompt, return_attention_mask=True
    )

    # 'crossattn' is the key the conditioning system expects for the
    # embeddings; the mask travels next to it under its own key.
    return {'crossattn': embeddings, 'attention_mask': attention_mask}
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):

View File

@ -28,12 +28,27 @@ class KModel(torch.nn.Module):
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
context = c_crossattn
# Handle context which may now be a dict with crossattn and attention_mask
if isinstance(c_crossattn, dict):
context = c_crossattn['crossattn']
attention_mask = c_crossattn.get('attention_mask', None)
else:
# Backward compatibility: if context is just a tensor
context = c_crossattn
attention_mask = None
dtype = self.computation_dtype
xc = xc.to(dtype)
t = self.predictor.timestep(t).float()
context = context.to(dtype)
# Keep attention mask as boolean if present
if attention_mask is not None:
# Store attention mask in transformer options for potential future use
transformer_options = transformer_options.copy() if transformer_options else {}
transformer_options['attention_mask'] = attention_mask
extra_conds = {}
for o in kwargs:
extra = kwargs[o]

View File

@ -179,7 +179,16 @@ class T5Stack(torch.nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
# CRITICAL: Keep attention mask as boolean to avoid dtype bug
# The bug occurs when masks are converted to float16/bfloat16
# PyTorch's scaled_dot_product_attention expects boolean masks
# First ensure mask is boolean
if attention_mask.dtype != torch.bool:
attention_mask = attention_mask.to(torch.bool)
# Build the additive float mask passed on to the attention layers:
# 0.0 where attention_mask is True (attended), -inf where it is False (masked out)
mask = 1.0 - attention_mask.float().reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
past_bias = None

View File

@ -55,14 +55,22 @@ class T5TextProcessingEngine:
tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
return tokenized
def encode_with_transformers(self, tokens, attention_mask=None):
    """Run the T5 text encoder on a batch of token ids.

    Args:
        tokens: Tensor of token ids; moved to the text-encoder device.
        attention_mask: Optional mask tensor. When given it is moved to the
            text-encoder device, cast to ``torch.bool`` and forwarded to the
            encoder. When ``None`` the encoder is called without a mask,
            exactly as before the parameter existed.

    Returns:
        Whatever ``self.text_encoder`` returns for the given inputs.
    """
    device = memory_management.text_encoder_device()
    tokens = tokens.to(device)

    # Shared embedding table is placed on the encoder device in float32
    # before the forward pass.
    self.text_encoder.shared.to(device=device, dtype=torch.float32)

    if attention_mask is None:
        # Backward-compatible path: do not pass the keyword at all.
        return self.text_encoder(input_ids=tokens)

    # Keep the mask boolean; it must not be cast to the compute dtype.
    attention_mask = attention_mask.to(device=device, dtype=torch.bool)
    return self.text_encoder(
        input_ids=tokens,
        attention_mask=attention_mask,
    )
@ -113,18 +121,20 @@ class T5TextProcessingEngine:
return chunks, token_count
def __call__(self, texts):
def __call__(self, texts, return_attention_mask=False):
zs = []
attention_masks = []
cache = {}
self.emphasis = emphasis.get_current_option(opts.emphasis)()
for line in texts:
if line in cache:
line_z_values = cache[line]
line_z_values, line_attention_masks = cache[line]
else:
chunks, token_count = self.tokenize_line(line)
line_z_values = []
line_attention_masks = []
# pad all chunks to length of longest chunk
max_tokens = 0
@ -134,24 +144,49 @@ class T5TextProcessingEngine:
for chunk in chunks:
tokens = chunk.tokens
multipliers = chunk.multipliers
# Track actual token count before padding
actual_token_count = len(tokens)
remaining_count = max_tokens - len(tokens)
if remaining_count > 0:
tokens += [self.id_pad] * remaining_count
multipliers += [1.0] * remaining_count
z = self.process_tokens([tokens], [multipliers])[0]
# Create attention mask - boolean tensor
# Chroma requires attention mask to include one padding token
# So we extend the mask by 1 if there's padding
attention_mask = torch.zeros(len(tokens), dtype=torch.bool)
if remaining_count > 0:
# Include actual tokens plus one padding token
attention_mask[:actual_token_count + 1] = True
else:
# All tokens are real (no padding)
attention_mask[:] = True
z = self.process_tokens([tokens], [multipliers], [attention_mask])[0]
line_z_values.append(z)
cache[line] = line_z_values
line_attention_masks.append(attention_mask)
cache[line] = (line_z_values, line_attention_masks)
zs.extend(line_z_values)
attention_masks.extend(line_attention_masks)
return torch.stack(zs)
if return_attention_mask:
return torch.stack(zs), torch.stack(attention_masks)
else:
return torch.stack(zs)
def process_tokens(self, batch_tokens, batch_multipliers):
def process_tokens(self, batch_tokens, batch_multipliers, batch_attention_masks=None):
tokens = torch.asarray(batch_tokens)
z = self.encode_with_transformers(tokens)
if batch_attention_masks is not None:
attention_mask = torch.stack(batch_attention_masks) if len(batch_attention_masks) > 1 else batch_attention_masks[0].unsqueeze(0)
else:
attention_mask = None
z = self.encode_with_transformers(tokens, attention_mask)
self.emphasis.tokens = batch_tokens
self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)