Skip to content

Commit

Permalink
test inplace update together with xm.save
Browse files Browse the repository at this point in the history
  • Loading branch information
yaochengji committed Dec 12, 2024
1 parent 31ca070 commit 5b6917e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
23 changes: 22 additions & 1 deletion test/test_inplace_update.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import io
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from test_utils import temporary_env


class InplaceUpdateTest(unittest.TestCase):
Expand Down Expand Up @@ -49,6 +50,26 @@ def test_non_aten_op_after_partial_update(self):
xm.mark_step()
self.assertTrue(torch.all(torch.eq(y, expected)))

def test_xm_save(self):
with temporary_env(
XLA_DISABLE_FUNCTIONALIZATION="0", XLA_ENABLE_PARAM_ALIASING="0"):
xla_device = xm.xla_device()
t1 = torch.tensor([1], device=xla_device)
t2 = t1.detach()
xm.mark_step()

t2.add_(t2)
xm.mark_step()

# mark_step() causes t1 and t2 to be out of sync on the XLA side.

fobj = io.BytesIO()
xm.save({'t1': t1}, fobj)
fobj.seek(0)
saved = torch.load(fobj)

self.assertEqual(t1.item(), saved['t1'].item())


if __name__ == '__main__':
test = unittest.main()
Expand Down
30 changes: 30 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
from contextlib import contextmanager
import itertools
import math
import os
Expand Down Expand Up @@ -390,3 +391,32 @@ def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5):
results = xu.as_list(fn(*tensors))
xla_results = xu.as_list(fn(*xla_tensors))
self.compareResults(results, xla_results, rel_err=rel_err, abs_err=abs_err)


@contextmanager
def temporary_env(**kwargs):
"""
Temporarily set environment variables within the context.
Args:
**kwargs: Key-value pairs representing environment variables to set.
For example: temporary_env(PATH='/new/path', DEBUG='1')
"""
original_env = {}

# Store original values and set new ones
for key, value in kwargs.items():
original_env[key] = os.environ.get(key, None)
os.environ[key] = value

try:
yield
finally:
# Restore original environment variables
for key, old_value in original_env.items():
if old_value is None:
# The variable was not originally set
del os.environ[key]
else:
# Restore the original value
os.environ[key] = old_value

0 comments on commit 5b6917e

Please sign in to comment.