diff --git a/test/test_einsum_autocast.py b/test/test_einsum_autocast.py
new file mode 100644
index 00000000000..72520dc62aa
--- /dev/null
+++ b/test/test_einsum_autocast.py
@@ -0,0 +1,24 @@
+import os
+import re
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import unittest
+
+device = xm.xla_device()
+
+class TestAutocastXla(unittest.TestCase):
+  def test_einsum(self):
+    data = torch.randn(16, 10).to(torch.bfloat16).to(device)
+    target = torch.randn(5, 10).to(torch.bfloat16).to(device)
+    with torch.autocast("xla"):
+      product = torch.einsum("...n,mn->...m", data, target)
+      hlo = torch_xla._XLAC._get_xla_tensors_hlo([product])
+
+      self.assertTrue(re.search(r".*dot.*bf16", hlo) is not None)
+
+      self.assertTrue(re.search(r".*dot.*f32", hlo) is None)
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/torch_xla/csrc/autocast_mode.cpp b/torch_xla/csrc/autocast_mode.cpp
index b151d6460c0..c49caba9295 100644
--- a/torch_xla/csrc/autocast_mode.cpp
+++ b/torch_xla/csrc/autocast_mode.cpp
@@ -48,6 +48,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
   KERNEL_XLA(prelu, lower_precision_fp)
   KERNEL_XLA(relu, lower_precision_fp)
   KERNEL_XLA(max_pool2d, lower_precision_fp)
+  KERNEL_XLA(einsum, lower_precision_fp)
   // Disable `scaled_dot_product_attention` for now since it causes
   // undefined symbol with official torch whl.
   // KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)
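
For context, a minimal user-facing sketch of what registering `einsum` as `lower_precision_fp` enables (not part of the patch; it assumes an XLA device such as a TPU is attached): under `torch.autocast("xla")`, an einsum over float32 inputs is now computed in bfloat16.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Plain float32 inputs on the XLA device.
a = torch.randn(16, 10, device=device)
b = torch.randn(5, 10, device=device)

with torch.autocast("xla"):
    # With einsum on the lower_precision_fp list, autocast casts the
    # float32 inputs down to bfloat16 before dispatching the op.
    out = torch.einsum("...n,mn->...m", a, b)

print(out.dtype)  # torch.bfloat16
```

The test in the patch asserts the same behavior one level lower: the dot op that einsum lowers to in the generated HLO must carry bf16 and no f32.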