xm.save() should not set sync_xla_data=True when sync'ing.
Setting sync_xla_data=True performs a tensor graph sync as if
it were a mark step, which triggers buffer aliasing. However,
aliasing is not safe unless all live tensors are being sync'd.

Also fix torch_xla.utils.serialization.save(), which has the same
issue.

This fixes #8422
mcuiaws committed Dec 17, 2024
1 parent 0121444 commit 8788167
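
For context, a minimal sketch of the failure mode described above, mirroring the regression test added in this commit; the constant values and the os.devnull destination follow that test and are purely illustrative:

import os

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t0 = torch.tensor([1], device=device)
t1 = torch.tensor([2], device=device)
xm.mark_step()                # materialize t0 and t1

t2 = t0 + t1                  # pending computation that still reads t1's current buffer
t1.add_(1)                    # in-place update of t1

# Before this fix, xm.save() synced t1 with sync_xla_data=True, which could
# donate t1's old buffer even though the pending graph for t2 still needs it;
# the next mark_step could then crash or compute a wrong value for t2.
xm.save(t1, os.devnull)

xm.mark_step()
assert t2.item() == 3
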
Showing 3 changed files with 25 additions and 2 deletions.
23 changes: 23 additions & 0 deletions test/test_input_output_aliases.py
@@ -1,3 +1,4 @@
+import os
 import sys

 import torch
@@ -162,6 +163,28 @@ def test_separate_graphs(self):

     self.assertEqual(t1.item(), 3)

+  def test_xm_save_no_aliasing(self):
+    """
+    Test that xm.save() does not perform aliasing.
+    """
+    xla_device = xm.xla_device()
+    t0 = torch.tensor([1], device=xla_device)
+    t1 = torch.tensor([2], device=xla_device)
+    xm.mark_step()
+
+    t2 = t0 + t1
+    t1.add_(1)
+
+    # Saving the new value of t1 should not result in the old value
+    # being donated...
+    xm.save(t1, os.devnull)
+
+    # otherwise this mark_step could crash, or compute the wrong value
+    # for t2.
+    xm.mark_step()
+
+    self.assertEqual(t2.item(), 3)
+

 if __name__ == '__main__':
   test = unittest.main()
2 changes: 1 addition & 1 deletion torch_xla/core/xla_model.py
@@ -1293,7 +1293,7 @@ def _maybe_convert_to_cpu(data: Any, convert: bool = True) -> ToXlaTensorArena:

   def convert_fn(tensors):
     torch_xla._XLAC._xla_sync_multi(
-        tensors, devices=[], wait=True, sync_xla_data=True)
+        tensors, devices=[], wait=True, sync_xla_data=False)
     if not convert:
       return tensors
     return torch_xla._XLAC._xla_get_cpu_tensors(tensors)
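
The convert_fn above is what xm.save() uses to move XLA tensors to CPU before writing them out. A rough sketch of the kind of checkpointing call that goes through this path, with a placeholder model and output path:

import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(4, 2).to(device)  # placeholder model

# ... run training steps, each ending in a mark step ...

# The sync performed here covers only the tensors being saved; with this fix
# it no longer behaves like a mark step, so it does not trigger buffer
# aliasing for buffers that pending computations still read.
xm.save(model.state_dict(), '/tmp/model_checkpoint.pt')
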
2 changes: 1 addition & 1 deletion torch_xla/utils/serialization.py
@@ -25,7 +25,7 @@ def _rewrite_data(path, data, save_tensors):

  def convert_fn(tensors):
    torch_xla._XLAC._xla_sync_multi(
-        tensors, devices=[], wait=True, sync_xla_data=True)
+        tensors, devices=[], wait=True, sync_xla_data=False)
    rewritten_tensors = []
    for i, t in enumerate(tensors):
      if save_tensors:
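
torch_xla.utils.serialization.save() rewrites XLA tensors through the same kind of convert_fn before serializing them. A sketch of the corresponding usage, assuming the module's public save()/load() entry points and a placeholder path:

import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.utils.serialization as xser

device = xm.xla_device()
model = nn.Linear(4, 2).to(device)  # placeholder model

# The same fix applies here: the sync done while rewriting the saved tensors
# no longer triggers buffer aliasing.
xser.save(model.state_dict(), '/tmp/model_checkpoint_xser.pt')
state_dict = xser.load('/tmp/model_checkpoint_xser.pt')
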
