From ef857715d40b1e2ab4becff6690426dcb07ef6b8 Mon Sep 17 00:00:00 2001
From: mcuiaws
Date: Thu, 19 Dec 2024 11:04:18 -0800
Subject: [PATCH] xm.save() should not set sync_xla_data=True when sync'ing.
 (#8484) (#8504)

---
 test/test_input_output_aliases.py | 22 ++++++++++++++++++++++
 torch_xla/core/xla_model.py       |  2 +-
 torch_xla/utils/serialization.py  |  2 +-
 3 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py
index 8e7b0a7ed6e..df21a1ad5e3 100644
--- a/test/test_input_output_aliases.py
+++ b/test/test_input_output_aliases.py
@@ -1,3 +1,4 @@
+import os
 import sys
 
 import torch
@@ -162,6 +163,27 @@ def test_separate_graphs(self):
 
     self.assertEqual(t1.item(), 3)
 
+  def test_xm_save_no_aliasing(self):
+    """
+    Test that xm.save() does not perform aliasing.
+    """
+    xla_device = xm.xla_device()
+    t0 = torch.tensor([1], device=xla_device)
+    t1 = torch.tensor([2], device=xla_device)
+    xm.mark_step()
+
+    t2 = t0 + t1
+    t1.add_(1)
+
+    # Saving the new value of t1 should not result in the old value
+    # being donated...
+    xm.save(t1, os.devnull)
+
+    # ...otherwise this mark_step could crash or compute the wrong value
+    # for t2.
+    xm.mark_step()
+
+    self.assertEqual(t2.item(), 3)
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index fb6b6bac634..213f5d319dd 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -1293,7 +1293,7 @@ def _maybe_convert_to_cpu(data: Any, convert: bool = True) -> ToXlaTensorArena:
 
   def convert_fn(tensors):
     torch_xla._XLAC._xla_sync_multi(
-        tensors, devices=[], wait=True, sync_xla_data=True)
+        tensors, devices=[], wait=True, sync_xla_data=False)
     if not convert:
       return tensors
     return torch_xla._XLAC._xla_get_cpu_tensors(tensors)
diff --git a/torch_xla/utils/serialization.py b/torch_xla/utils/serialization.py
index ed3797ad945..05cfa93e2ea 100644
--- a/torch_xla/utils/serialization.py
+++ b/torch_xla/utils/serialization.py
@@ -25,7 +25,7 @@ def _rewrite_data(path, data, save_tensors):
 
   def convert_fn(tensors):
     torch_xla._XLAC._xla_sync_multi(
-        tensors, devices=[], wait=True, sync_xla_data=True)
+        tensors, devices=[], wait=True, sync_xla_data=False)
     rewritten_tensors = []
     for i, t in enumerate(tensors):
       if save_tensors:
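
Editor's note (illustration, not part of the patch): sync_xla_data=True asks
the sync to also propagate pending in-place updates back into the tensors'
device buffers, which marks the old buffers as candidates for input/output
aliasing (donation) in the next compiled graph. That is appropriate at a step
boundary like xm.mark_step(), but xm.save() is a read-only operation:
donating a buffer that a still-pending graph reads can crash or silently
compute wrong results. The sketch below is a minimal standalone variant of
the test added above; it assumes an XLA device is available and only uses
APIs that appear in the diff (xm.xla_device, xm.mark_step, xm.save).

    import os

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    t0 = torch.tensor([1], device=device)
    t1 = torch.tensor([2], device=device)
    xm.mark_step()   # materialize t0 and t1 on the device

    t2 = t0 + t1     # pending graph reads t1's current device buffer
    t1.add_(1)       # pending in-place update to t1

    # Before this patch, xm.save() synced with sync_xla_data=True, so
    # t1's old buffer could be donated even though the pending
    # `t2 = t0 + t1` graph still needed it.
    xm.save(t1, os.devnull)

    xm.mark_step()
    assert t2.item() == 3  # could be wrong (or crash) if the buffer was donated

With the fix, both xm.save() and torch_xla.utils.serialization sync with
sync_xla_data=False, so saving is a pure read and never invalidates live
device buffers.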