[torch_xla2] Fix type promotions by aten::copy_ and `aten::rand_like` (pytorch#7960)
will-cromar authored and guyao committed Sep 9, 2024
1 parent 3daf9b3 commit a252c63
Showing 5 changed files with 28 additions and 23 deletions.
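
For context, the eager-mode semantics this commit makes torch_xla2 match: in-place aten::copy_ casts the source into the destination's dtype instead of promoting the destination, and aten::rand_like defaults to the input tensor's dtype. A minimal eager-PyTorch sketch of both behaviors (illustrative; not part of the diff):

import torch

dst = torch.ones((3, 3), dtype=torch.int32)
src = torch.zeros((3, 3), dtype=torch.float32)
dst.copy_(src)  # the source is cast to int32; dst keeps its dtype
assert dst.dtype == torch.int32

r = torch.rand_like(torch.ones((3, 3), dtype=torch.bfloat16))
assert r.dtype == torch.bfloat16  # rand_like follows the input's dtype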
35 changes: 25 additions & 10 deletions experimental/torch_xla2/test/test_core_aten_ops.py
@@ -1,3 +1,4 @@
+import math
 import unittest
 
 import torch
@@ -7,20 +8,17 @@
 from torch.utils import _pytree as pytree
 
 
-def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
+def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True, check_dtype=False):
   if isinstance(output1, torch.Tensor):
     testcase.assertIsInstance(output2, torch.Tensor)
     output2_cpu = output2.detach().cpu()
-    if output2_cpu.dtype != output1.dtype:
-      output2_cpu = output2_cpu.to(output1.dtype)
-    testcase.assertTrue(
-        torch.allclose(
-            output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
+    torch.testing.assert_close(
+        output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan, check_dtype=check_dtype)
   elif isinstance(output1, (tuple, list)):
     testcase.assertIsInstance(output2, (tuple, list))
     testcase.assertEqual(len(output1), len(output2))
     for o1, o2 in zip(output1, output2):
-      diff_output(testcase, o1, o2, rtol, atol)
+      diff_output(testcase, o1, o2, rtol, atol, equal_nan=equal_nan, check_dtype=check_dtype)
   else:
     testcase.assertEqual(output1, output2)
 
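Previously the helper cast output2 to output1's dtype before torch.allclose, which silently hid dtype mismatches; torch.testing.assert_close makes dtype checking an explicit, opt-in knob. A small sketch of the difference (illustrative; not from the diff):

import torch

a = torch.zeros(3, dtype=torch.float32)
b = torch.zeros(3, dtype=torch.float64)

torch.testing.assert_close(a, b, check_dtype=False)  # compared in a common dtype: passes
# torch.testing.assert_close(a, b, check_dtype=True)  # would raise: dtypes differ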
@@ -32,6 +30,7 @@ def run_export_and_compare(testcase,
                            atol=1e-3,
                            rtol=1e-5,
                            equal_nan=True,
+                           check_dtype=False,
                            ignore_indices=False):
 
   with testcase.subTest("torch_eval"):
@@ -50,10 +49,11 @@ def run_export_and_compare(testcase,
           res2[0],
           atol=atol,
           rtol=rtol,
-          equal_nan=equal_nan)
+          equal_nan=equal_nan,
+          check_dtype=check_dtype)
     else:
       diff_output(
-          testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan)
+          testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan, check_dtype=check_dtype)
 
 
 class TestCoreAtenOps(unittest.TestCase):
@@ -207,7 +207,7 @@ def test_aten__adaptive_avg_pool2d_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten._adaptive_avg_pool2d, args,
                            kwargs)
-  
+
   def test_aten_avg_pool2d_2(self):
     args = (
         torch.randn((1, 3, 6, 6)).to(torch.float32),
@@ -4364,6 +4364,21 @@ def test_aten_where_self_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs)
 
+  def test_aten_copy_dtype(self):
+    args = (
+        torch.ones((3, 3), dtype=torch.int32),
+        torch.zeros((3, 3), dtype=torch.float32),
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.copy_, args, kwargs, check_dtype=True)
+
+  def test_aten_rand_like(self):
+    args = (
+        torch.ones((3, 3), dtype=torch.bfloat16),
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.rand_like, args, kwargs, atol=math.inf, check_dtype=True)
+
 
 if __name__ == "__main__":
   test_base.main()
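
Because test_aten_rand_like compares random outputs, it passes atol=math.inf; combined with check_dtype=True, the assertion effectively reduces to shape and dtype. A sketch of why the value comparison becomes vacuous (illustrative):

import math
import torch

a = torch.rand((3, 3), dtype=torch.bfloat16)
b = torch.rand((3, 3), dtype=torch.bfloat16)

# Any finite values pass under an infinite tolerance; shape and dtype must still match.
torch.testing.assert_close(a, b, atol=math.inf, rtol=0, check_dtype=True)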
10 changes: 0 additions & 10 deletions experimental/torch_xla2/test/test_mutations.py
@@ -34,16 +34,6 @@ def test_mul(self):
       xt = torch_xla2.tensor.j2t(x._elem)
       self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32))
 
-  def test_div(self):
-    with self.env:
-      x = torch.tensor([1, 2, 3], dtype=torch.int32)
-      y = torch.tensor([4, 5, 6], dtype=torch.int32)
-
-      x.div_(y)
-      xt = torch_xla2.tensor.j2t(x._elem)
-      self.assertEqual(xt,
-                       torch.tensor([1. / 4, 2. / 5, 3. / 6], dtype=torch.float))
-
 
 if __name__ == '__main__':
   unittest.main()
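
The deleted test_div asserted that an in-place div_ on int32 inputs leaves a float result behind, which conflicts with the dtype-preserving copy_ semantics introduced above; in eager PyTorch the same call is an error rather than a dtype change. A sketch of the eager behavior (illustrative):

import torch

x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([4, 5, 6], dtype=torch.int32)
try:
    x.div_(y)  # true division yields float and cannot be written into an int32 tensor
except RuntimeError as e:
    print(e)  # roughly: result type Float can't be cast to the desired output type Int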
2 changes: 1 addition & 1 deletion experimental/torch_xla2/test_dist/test_distributed.py
@@ -73,7 +73,7 @@ def f(index: torch_xla2.tensor.XLATensor2):
     ("op", "expected"),
     [
         (dist.ReduceOp.SUM, sum(range(4))),
-        (dist.ReduceOp.AVG, sum(range(4)) / 4),
+        (dist.ReduceOp.AVG, sum(range(4)) // 4),
         (dist.ReduceOp.MIN, 0),
         (dist.ReduceOp.MAX, 3),
     ],
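
Each of the four ranks contributes its index, so the all-reduce sees 0 + 1 + 2 + 3 = 6. With integer-typed inputs the AVG result presumably stays integral under the dtype-preserving copy_, so the expectation becomes the floor 6 // 4 = 1 rather than 1.5. The arithmetic (a sketch):

ranks = range(4)
assert sum(ranks) == 6
assert sum(ranks) / 4 == 1.5   # old expectation: float average
assert sum(ranks) // 4 == 1    # new expectation: integer (floor) average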
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/decompositions.py
@@ -101,7 +101,7 @@ def bernoulli(self, *, generator=None):
 
 
 def rand_like(self, **kwargs):
-  dtype = kwargs.get('dtype')
+  dtype = kwargs.get('dtype', self.dtype)
   return torch.rand(self.shape, dtype=dtype)
 
 def channel_shuffle(self, groups):
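
kwargs.get('dtype') returns None whenever the caller omits dtype, and torch.rand(..., dtype=None) falls back to the global default dtype (float32 unless changed via torch.set_default_dtype), silently dropping e.g. a bfloat16 input's dtype; the two-argument form falls back to self.dtype instead. A quick sketch of the distinction (illustrative):

import torch

kwargs = {}
print(kwargs.get('dtype'))                   # None, so torch.rand would use the default dtype
print(kwargs.get('dtype', torch.bfloat16))   # torch.bfloat16, the input's dtype is preserved
print(torch.rand((2, 2), dtype=None).dtype)  # torch.float32 under default settings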
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -82,7 +82,7 @@ def _aten_add(x, y, *, alpha=1):
 
 @op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False)
 def _aten_copy(x, y, memory_format=None):
-  x._elem = y._elem
+  x._elem = y._elem.astype(x._elem.dtype)
   return x
 
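torch_xla2 tensors wrap jax arrays in _elem, so the fix casts the source array to the destination's dtype with jax.Array.astype, mirroring the eager copy_ semantics. A minimal sketch outside the wrapper (illustrative; assumes jax is installed):

import jax.numpy as jnp

dst = jnp.ones((3, 3), dtype=jnp.int32)
src = jnp.zeros((3, 3), dtype=jnp.float32)

elem = src.astype(dst.dtype)  # cast the source to the destination's dtype, as eager copy_ does
assert elem.dtype == jnp.int32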
