fix for XPU (#1997)

use float32 for XPU, same as previous fix for MPS
This commit is contained in:
DenOfEquity 2024-10-06 14:33:47 +01:00 committed by GitHub
parent 4f7f815b9f
commit 2467c88c50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,7 +19,7 @@ def attention(q, k, v, pe):
def rope(pos, dim, theta):
if pos.device.type == "mps":
if pos.device.type == "mps" or pos.device.type == "xpu":
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
else:
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim