Skip to content

Commit

Permalink
Torch functional tensors need syncing before transferring to CPU.
Browse files Browse the repository at this point in the history
This patch fixes #8482.
  • Loading branch information
mcuiaws committed Dec 11, 2024
1 parent 7dd2697 commit a2ae905
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ function run_xla_op_tests3 {
# NOTE: this line below is testing export and don't care about GPU
PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
run_test "$CDIR/test_pallas.py"
XLA_DISABLE_FUNCTIONALIZATION=0 run_test "$CDIR/test_functionalization.py"

# CUDA tests
if [ -x "$(command -v nvidia-smi)" ]; then
Expand Down
39 changes: 39 additions & 0 deletions test/test_functionalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import io
import os

# These environment variables are set before the torch_xla imports below
# because the test depends on functionalization being ON and parameter
# aliasing being OFF.
# NOTE(review): assumes torch_xla reads these variables at import time —
# confirm against torch_xla initialization code.
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "0"
os.environ["XLA_ENABLE_PARAM_ALIASING"] = "0"

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import unittest


class TestFunctionalization(unittest.TestCase):
  """Regression test: xm.save() must sync functional tensors (issue #8482)."""

  def test_xm_save(self):
    """
    Verify that xm.save() performs torch._functionalize_sync() so a saved
    tensor round-trips with the correct value.
    """
    device = xm.xla_device()
    original = torch.tensor([1], device=device)
    alias = original.detach()
    xm.mark_step()

    alias.add_(alias)
    xm.mark_step()

    # The in-place update followed by mark_step() leaves the functional
    # wrappers of `original` and `alias` out of sync on the XLA side;
    # xm.save() must call _functionalize_sync() to reconcile them before
    # the CPU transfer.
    buffer = io.BytesIO()
    xm.save({'t1': original}, buffer)
    buffer.seek(0)
    restored = torch.load(buffer)

    self.assertEqual(original.item(), restored['t1'].item())


# Run the test suite when executed directly (e.g. via run_tests.sh).
if __name__ == "__main__":
  unittest.main()
3 changes: 3 additions & 0 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,9 @@ def convert_fn(tensors):
tensors, devices=[], wait=True, sync_xla_data=True)
if not convert:
return tensors
for t in tensors:
if torch._is_functional_tensor(t):
torch._functionalize_sync(t)
return torch_xla._XLAC._xla_get_cpu_tensors(tensors)

def select_fn(v):
Expand Down

0 comments on commit a2ae905

Please sign in to comment.