From 878816724934bc34f2c7b6b5124141116dd20078 Mon Sep 17 00:00:00 2001 From: Mike Cui Date: Wed, 11 Dec 2024 18:52:48 +0000 Subject: [PATCH] xm.save() should not set sync_xla_data=True when sync'ing. Setting sync_xla_data=True performs tensor graph sync as if it's a mark step, which triggers buffer aliasing to be performed. However, it's not safe to do so unless all live tensors are being sync'd. Also fix torch_xla.utils.serialization.save() which has the same issue. This fixes #8422 --- test/test_input_output_aliases.py | 23 +++++++++++++++++++++++ torch_xla/core/xla_model.py | 2 +- torch_xla/utils/serialization.py | 2 +- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py index 8e7b0a7ed6e..cae60b7889d 100644 --- a/test/test_input_output_aliases.py +++ b/test/test_input_output_aliases.py @@ -1,3 +1,4 @@ +import os import sys import torch @@ -162,6 +163,28 @@ def test_separate_graphs(self): self.assertEqual(t1.item(), 3) + def test_xm_save_no_aliasing(self): + """ + Test that xm.save() does not perform aliasing. + """ + xla_device = xm.xla_device() + t0 = torch.tensor([1], device=xla_device) + t1 = torch.tensor([2], device=xla_device) + xm.mark_step() + + t2 = t0 + t1 + t1.add_(1) + + # Save the new value of t1 should not result in the old value + # being donated... + xm.save(t1, os.devnull) + + # otherwise this mark_step could crash, or compute the wrong value + # for t2. + xm.mark_step() + + self.assertEqual(t2.item(), 3) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index fb6b6bac634..213f5d319dd 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -1293,7 +1293,7 @@ def _maybe_convert_to_cpu(data: Any, convert: bool = True) -> ToXlaTensorArena: def convert_fn(tensors): torch_xla._XLAC._xla_sync_multi( - tensors, devices=[], wait=True, sync_xla_data=True) + tensors, devices=[], wait=True, sync_xla_data=False) if not convert: return tensors return torch_xla._XLAC._xla_get_cpu_tensors(tensors) diff --git a/torch_xla/utils/serialization.py b/torch_xla/utils/serialization.py index ed3797ad945..05cfa93e2ea 100644 --- a/torch_xla/utils/serialization.py +++ b/torch_xla/utils/serialization.py @@ -25,7 +25,7 @@ def _rewrite_data(path, data, save_tensors): def convert_fn(tensors): torch_xla._XLAC._xla_sync_multi( - tensors, devices=[], wait=True, sync_xla_data=True) + tensors, devices=[], wait=True, sync_xla_data=False) rewritten_tensors = [] for i, t in enumerate(tensors): if save_tensors: