Skip to content

Commit

Permalink
Added softmax to autocast policy (pytorch#8509)
Browse files Browse the repository at this point in the history
  • Loading branch information
avizon-aws authored Dec 20, 2024
1 parent e4a9e12 commit 38d0868
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
9 changes: 9 additions & 0 deletions test/test_autocast_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ def test_cross_entropy_loss(self):
self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32")
self.assertRegex(hlo, r".*log.*f32.*log.*f32")

def test_softmax(self):
data = torch.randn(16, 20).to(torch.bfloat16).to(device)

with torch.autocast("xla"):
output = torch.nn.Softmax(dim=1)(data)
hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16")
self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32")

def test_patchedlinear_autocast(self):
hidden_size = 10
intermediate_size = 15
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/autocast_mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
// Commented out ops are included in the AutoCastCPU Policy,
// but not lowered. Enable if op is lowered.
KERNEL_XLA(batch_norm, fp32)
KERNEL_XLA(_softmax, fp32)
KERNEL_XLA2(softmax, int, fp32)
KERNEL_XLA2(softmax, Dimname, fp32)
KERNEL_XLA2(log_softmax, int, fp32)
KERNEL_XLA2(log_softmax, Dimname, fp32)
KERNEL_XLA(binary_cross_entropy, fp32)
Expand Down

0 comments on commit 38d0868

Please sign in to comment.