diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py
index 2ba73e81a51..96efb880e84 100644
--- a/test/test_autocast_xla.py
+++ b/test/test_autocast_xla.py
@@ -22,6 +22,15 @@ def test_cross_entropy_loss(self):
       self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32")
       self.assertRegex(hlo, r".*log.*f32.*log.*f32")
 
+  def test_softmax(self):
+    data = torch.randn(16, 20).to(torch.bfloat16).to(device)
+
+    with torch.autocast("xla"):
+      output = torch.nn.Softmax(dim=1)(data)
+      hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
+      self.assertRegex(hlo, r".*convert.*f32.*convert.*bf16")
+      self.assertRegex(hlo, r".*exponential.*f32.*exponential.*f32")
+
   def test_patchedlinear_autocast(self):
     hidden_size = 10
     intermediate_size = 15
diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp
index c49caba9295..066c432ac37 100644
--- a/torch_xla/csrc/autocast_mode.cpp
+++ b/torch_xla/csrc/autocast_mode.cpp
@@ -57,6 +57,9 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
   // Commented out ops are included in the AutoCastCPU Policy,
   // but not lowered. Enable if op is lowered.
   KERNEL_XLA(batch_norm, fp32)
+  KERNEL_XLA(_softmax, fp32)
+  KERNEL_XLA2(softmax, int, fp32)
+  KERNEL_XLA2(softmax, Dimname, fp32)
   KERNEL_XLA2(log_softmax, int, fp32)
   KERNEL_XLA2(log_softmax, Dimname, fp32)
   KERNEL_XLA(binary_cross_entropy, fp32)
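
Note: the new KERNEL_XLA/KERNEL_XLA2 registrations place `_softmax` and both
`softmax` overloads (`softmax.int`, `softmax.Dimname`) on the fp32 policy of
the AutocastXLA dispatch key, matching the existing `log_softmax` entries.
Below is a minimal sketch of the observable effect from Python, mirroring the
new test; the device setup via `xm.xla_device()` and the crude substring check
at the end are illustrative assumptions, not part of this diff:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # assumes an XLA device is available

    # Start from a bf16 input so the fp32 promotion is visible in the HLO.
    data = torch.randn(16, 20).to(torch.bfloat16).to(device)

    with torch.autocast("xla"):
      # With softmax on the fp32 policy, autocast upcasts the bf16 input
      # to f32 before the op is traced and lowered.
      output = torch.nn.Softmax(dim=1)(data)
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])

    # The captured HLO should contain f32 converts around the exponential;
    # the test above checks this precisely with assertRegex.
    assert "f32" in hlo

Running softmax in fp32 under autocast avoids the numerical issues of
computing exp/sum-reduce in bf16, which is why it follows the same policy
the AutocastCPU list already applies to this op.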