Enable cross entropy loss for xla autocast with FP32 precision (#7992)
avizon-aws authored Sep 30, 2024
1 parent 0f645b2 commit 940bee4
Showing 2 changed files with 30 additions and 1 deletion.
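For context, a minimal sketch of the user-facing effect of this change (illustrative names and shapes, assuming a torch_xla build that includes this commit): under torch.autocast("xla"), cross_entropy_loss on bf16 inputs is now upcast and computed in float32.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
logits = torch.randn(8, 5, dtype=torch.bfloat16, device=device)
target = torch.randint(0, 5, (8,), device=device)

with torch.autocast("xla"):
    # cross_entropy_loss is now registered as an fp32 op for XLA autocast,
    # so the bf16 inputs are upcast to float32 before the loss is computed.
    loss = torch.nn.functional.cross_entropy(logits, target)

print(loss.dtype)  # expected: torch.float32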
29 changes: 29 additions & 0 deletions test/test_bf16_autocast.py
@@ -0,0 +1,29 @@
import os
import re
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import unittest

device = xm.xla_device()


class TestAutocastXla(unittest.TestCase):

  def test_cross_entropy_loss(self):
    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
    target = torch.randn(16, 10).to(torch.bfloat16).to(device)
    with torch.autocast("xla"):
      loss = torch.nn.CrossEntropyLoss()(data, target)
      hlo = torch_xla._XLAC._get_xla_tensors_hlo([loss])
      # The HLO should show conversions between bf16 and f32 around the
      # loss computation.
      self.assertTrue(
          re.search(rf".*convert.*f32.*convert.*bf16", hlo) is not None)

      # The softmax exponential and log should be computed in f32.
      self.assertTrue(
          re.search(rf".*exponential.*f32.*exponential.*f32", hlo) is not None)

      self.assertTrue(re.search(rf".*log.*f32.*log.*f32", hlo) is not None)


if __name__ == "__main__":
  unittest.main()
2 changes: 1 addition & 1 deletion torch_xla/csrc/autocast_mode.cpp
@@ -95,7 +95,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
   KERNEL_XLA(hinge_embedding_loss, fp32)
   // KERNEL_XLA(poisson_nll_loss, fp32)
   KERNEL_XLA(smooth_l1_loss, fp32)
-  // KERNEL_XLA(cross_entropy_loss, fp32)
+  KERNEL_XLA(cross_entropy_loss, fp32)
   KERNEL_XLA(l1_loss, fp32)
   // KERNEL_XLA(huber_loss, fp32)
   KERNEL_XLA(margin_ranking_loss, fp32)
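The C++ change above registers aten::cross_entropy_loss under the AutocastXLA dispatch key with an fp32 cast policy, matching the neighboring loss ops. For comparison, outside of autocast the same call keeps the input dtype, which is also what happened inside autocast before this change, while the KERNEL_XLA entry was still commented out (a sketch, using the same tensors and assumptions as the example above):

# Outside autocast no cast policy applies, so the loss keeps the bf16
# input dtype.
loss_bf16 = torch.nn.functional.cross_entropy(logits, target)
print(loss_bf16.dtype)  # expected: torch.bfloat16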
