mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-06-30 21:09:33 +08:00
Merge pull request #14559 from Nuullll/ipex-sdpa-fix
[IPEX] Fix SDPA attn_mask dtype
This commit is contained in:
commit
b00b429477
@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention(
|
||||
# cast to same dtype first
|
||||
key = key.to(query.dtype)
|
||||
value = value.to(query.dtype)
|
||||
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask.to(query.dtype)
|
||||
|
||||
N = query.shape[:-2] # Batch size
|
||||
L = query.size(-2) # Target sequence length
|
||||
|
||||
Loading…
Reference in New Issue
Block a user