mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-24 21:07:35 +08:00
Merge pull request #17 from maybleMyers/diffusers_mask
add diffusers style t5 mask to help details
This commit is contained in:
commit
6ae9f548ae
@ -52,7 +52,11 @@ class Chroma(ForgeDiffusionEngine):
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
return self.text_processing_engine_t5(prompt)
|
||||
# Get embeddings with attention mask
|
||||
embeddings, attention_mask = self.text_processing_engine_t5(prompt, return_attention_mask=True)
|
||||
# Store attention mask in a dict along with embeddings
|
||||
# Use 'crossattn' key to match the conditioning system's expectations
|
||||
return {'crossattn': embeddings, 'attention_mask': attention_mask}
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
|
||||
@ -60,7 +60,11 @@ class ChromaDCT(ForgeDiffusionEngine):
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
return self.text_processing_engine_t5(prompt)
|
||||
# Get embeddings with attention mask
|
||||
embeddings, attention_mask = self.text_processing_engine_t5(prompt, return_attention_mask=True)
|
||||
# Store attention mask in a dict along with embeddings
|
||||
# Use 'crossattn' key to match the conditioning system's expectations
|
||||
return {'crossattn': embeddings, 'attention_mask': attention_mask}
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
|
||||
@ -28,12 +28,27 @@ class KModel(torch.nn.Module):
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||
|
||||
context = c_crossattn
|
||||
# Handle context which may now be a dict with crossattn and attention_mask
|
||||
if isinstance(c_crossattn, dict):
|
||||
context = c_crossattn['crossattn']
|
||||
attention_mask = c_crossattn.get('attention_mask', None)
|
||||
else:
|
||||
# Backward compatibility: if context is just a tensor
|
||||
context = c_crossattn
|
||||
attention_mask = None
|
||||
|
||||
dtype = self.computation_dtype
|
||||
|
||||
xc = xc.to(dtype)
|
||||
t = self.predictor.timestep(t).float()
|
||||
context = context.to(dtype)
|
||||
|
||||
# Keep attention mask as boolean if present
|
||||
if attention_mask is not None:
|
||||
# Store attention mask in transformer options for potential future use
|
||||
transformer_options = transformer_options.copy() if transformer_options else {}
|
||||
transformer_options['attention_mask'] = attention_mask
|
||||
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
|
||||
@ -179,7 +179,16 @@ class T5Stack(torch.nn.Module):
|
||||
mask = None
|
||||
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
# CRITICAL: Keep attention mask as boolean to avoid dtype bug
|
||||
# The bug occurs when masks are converted to float16/bfloat16
|
||||
# PyTorch's scaled_dot_product_attention expects boolean masks
|
||||
# First ensure mask is boolean
|
||||
if attention_mask.dtype != torch.bool:
|
||||
attention_mask = attention_mask.to(torch.bool)
|
||||
|
||||
# Create attention mask in the format expected by attention functions
|
||||
# This creates a mask where True values are attended to
|
||||
mask = 1.0 - attention_mask.float().reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
|
||||
past_bias = None
|
||||
|
||||
@ -55,14 +55,22 @@ class T5TextProcessingEngine:
|
||||
tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
def encode_with_transformers(self, tokens, attention_mask=None):
|
||||
device = memory_management.text_encoder_device()
|
||||
tokens = tokens.to(device)
|
||||
self.text_encoder.shared.to(device=device, dtype=torch.float32)
|
||||
|
||||
z = self.text_encoder(
|
||||
input_ids=tokens,
|
||||
)
|
||||
if attention_mask is not None:
|
||||
# CRITICAL: Keep attention mask as boolean to avoid dtype conversion bug
|
||||
attention_mask = attention_mask.to(device=device, dtype=torch.bool)
|
||||
z = self.text_encoder(
|
||||
input_ids=tokens,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
z = self.text_encoder(
|
||||
input_ids=tokens,
|
||||
)
|
||||
|
||||
return z
|
||||
|
||||
@ -113,18 +121,20 @@ class T5TextProcessingEngine:
|
||||
|
||||
return chunks, token_count
|
||||
|
||||
def __call__(self, texts):
|
||||
def __call__(self, texts, return_attention_mask=False):
|
||||
zs = []
|
||||
attention_masks = []
|
||||
cache = {}
|
||||
|
||||
self.emphasis = emphasis.get_current_option(opts.emphasis)()
|
||||
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
line_z_values = cache[line]
|
||||
line_z_values, line_attention_masks = cache[line]
|
||||
else:
|
||||
chunks, token_count = self.tokenize_line(line)
|
||||
line_z_values = []
|
||||
line_attention_masks = []
|
||||
|
||||
# pad all chunks to length of longest chunk
|
||||
max_tokens = 0
|
||||
@ -134,24 +144,49 @@ class T5TextProcessingEngine:
|
||||
for chunk in chunks:
|
||||
tokens = chunk.tokens
|
||||
multipliers = chunk.multipliers
|
||||
|
||||
|
||||
# Track actual token count before padding
|
||||
actual_token_count = len(tokens)
|
||||
|
||||
remaining_count = max_tokens - len(tokens)
|
||||
if remaining_count > 0:
|
||||
tokens += [self.id_pad] * remaining_count
|
||||
multipliers += [1.0] * remaining_count
|
||||
|
||||
z = self.process_tokens([tokens], [multipliers])[0]
|
||||
# Create attention mask - boolean tensor
|
||||
# Chroma requires attention mask to include one padding token
|
||||
# So we extend the mask by 1 if there's padding
|
||||
attention_mask = torch.zeros(len(tokens), dtype=torch.bool)
|
||||
if remaining_count > 0:
|
||||
# Include actual tokens plus one padding token
|
||||
attention_mask[:actual_token_count + 1] = True
|
||||
else:
|
||||
# All tokens are real (no padding)
|
||||
attention_mask[:] = True
|
||||
|
||||
z = self.process_tokens([tokens], [multipliers], [attention_mask])[0]
|
||||
line_z_values.append(z)
|
||||
cache[line] = line_z_values
|
||||
line_attention_masks.append(attention_mask)
|
||||
|
||||
cache[line] = (line_z_values, line_attention_masks)
|
||||
|
||||
zs.extend(line_z_values)
|
||||
attention_masks.extend(line_attention_masks)
|
||||
|
||||
return torch.stack(zs)
|
||||
if return_attention_mask:
|
||||
return torch.stack(zs), torch.stack(attention_masks)
|
||||
else:
|
||||
return torch.stack(zs)
|
||||
|
||||
def process_tokens(self, batch_tokens, batch_multipliers):
|
||||
def process_tokens(self, batch_tokens, batch_multipliers, batch_attention_masks=None):
|
||||
tokens = torch.asarray(batch_tokens)
|
||||
|
||||
z = self.encode_with_transformers(tokens)
|
||||
if batch_attention_masks is not None:
|
||||
attention_mask = torch.stack(batch_attention_masks) if len(batch_attention_masks) > 1 else batch_attention_masks[0].unsqueeze(0)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
z = self.encode_with_transformers(tokens, attention_mask)
|
||||
|
||||
self.emphasis.tokens = batch_tokens
|
||||
self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user