diff --git a/test/run_tests.sh b/test/run_tests.sh
index c1720b53e99..7e78c56a54a 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -252,6 +252,7 @@ function run_xla_op_tests3 {
   # NOTE: this line below is testing export and don't care about GPU
   PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
   run_test "$CDIR/test_pallas.py"
+  XLA_DISABLE_FUNCTIONALIZATION=0 run_test "$CDIR/test_functionalization.py"
 
   # CUDA tests
   if [ -x "$(command -v nvidia-smi)" ]; then
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
new file mode 100644
index 00000000000..5e524910d6f
--- /dev/null
+++ b/test/test_functionalization.py
@@ -0,0 +1,39 @@
+import io
+import os
+
+os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "0"
+os.environ["XLA_ENABLE_PARAM_ALIASING"] = "0"
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import unittest
+
+
+class TestFunctionalization(unittest.TestCase):
+
+  def test_xm_save(self):
+    """
+    Test that xm.save() does torch._functionalize_sync()
+    """
+    xla_device = xm.xla_device()
+    t1 = torch.tensor([1], device=xla_device)
+    t2 = t1.detach()
+    xm.mark_step()
+
+    t2.add_(t2)
+    xm.mark_step()
+
+    # mark_step() causes t1 and t2 to be out of sync on the XLA side.
+    # _functionalize_sync() is needed to get them back in sync.
+
+    fobj = io.BytesIO()
+    xm.save({'t1': t1}, fobj)
+    fobj.seek(0)
+    saved = torch.load(fobj)
+
+    self.assertEqual(t1.item(), saved['t1'].item())
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 6e1936c258a..48af793d114 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -1281,6 +1281,9 @@ def convert_fn(tensors):
         tensors, devices=[], wait=True, sync_xla_data=True)
     if not convert:
       return tensors
+    for t in tensors:
+      if torch._is_functional_tensor(t):
+        torch._functionalize_sync(t)
     return torch_xla._XLAC._xla_get_cpu_tensors(tensors)
 
   def select_fn(v):