diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py index 352cc7bed1e..657603b9df7 100644 --- a/test/test_autocast_xla.py +++ b/test/test_autocast_xla.py @@ -10,6 +10,7 @@ class TestAutocastXla(unittest.TestCase): + def test_cross_entropy_loss(self): data = torch.randn(16, 10).to(torch.bfloat16).to(device) target = torch.randn(16, 10).to(torch.bfloat16).to(device)