From a252c63facb22f5d851aadbc58c2c26e8db0f13a Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Mon, 9 Sep 2024 14:50:08 -0700
Subject: [PATCH] [torch_xla2] Fix type promotions by `aten::copy_` and
 `aten::rand_like` (#7960)

---
 .../torch_xla2/test/test_core_aten_ops.py     | 35 +++++++++++++------
 .../torch_xla2/test/test_mutations.py         | 10 ------
 .../torch_xla2/test_dist/test_distributed.py  |  2 +-
 .../torch_xla2/torch_xla2/decompositions.py   |  2 +-
 .../torch_xla2/torch_xla2/ops/jaten.py        |  2 +-
 5 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py
index af025e1ef791..c3108bce1c81 100644
--- a/experimental/torch_xla2/test/test_core_aten_ops.py
+++ b/experimental/torch_xla2/test/test_core_aten_ops.py
@@ -1,3 +1,4 @@
+import math
 import unittest
 
 import torch
@@ -7,20 +8,17 @@
 from torch.utils import _pytree as pytree
 
 
-def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
+def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True, check_dtype=False):
   if isinstance(output1, torch.Tensor):
     testcase.assertIsInstance(output2, torch.Tensor)
     output2_cpu = output2.detach().cpu()
-    if output2_cpu.dtype != output1.dtype:
-      output2_cpu = output2_cpu.to(output1.dtype)
-    testcase.assertTrue(
-        torch.allclose(
-            output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
+    torch.testing.assert_close(
+        output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan, check_dtype=check_dtype)
   elif isinstance(output1, (tuple, list)):
     testcase.assertIsInstance(output2, (tuple, list))
     testcase.assertEqual(len(output1), len(output2))
     for o1, o2 in zip(output1, output2):
-      diff_output(testcase, o1, o2, rtol, atol)
+      diff_output(testcase, o1, o2, rtol, atol, equal_nan=equal_nan, check_dtype=check_dtype)
   else:
     testcase.assertEqual(output1, output2)
 
@@ -32,6 +30,7 @@ def run_export_and_compare(testcase,
                            atol=1e-3,
                            rtol=1e-5,
                            equal_nan=True,
+                           check_dtype=False,
                            ignore_indices=False):
 
   with testcase.subTest("torch_eval"):
@@ -50,10 +49,11 @@ def run_export_and_compare(testcase,
             res2[0],
             atol=atol,
             rtol=rtol,
-            equal_nan=equal_nan)
+            equal_nan=equal_nan,
+            check_dtype=check_dtype)
       else:
         diff_output(
-            testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan)
+            testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan, check_dtype=check_dtype)
 
 
 class TestCoreAtenOps(unittest.TestCase):
@@ -207,7 +207,7 @@ def test_aten__adaptive_avg_pool2d_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten._adaptive_avg_pool2d, args,
                            kwargs)
-  
+
   def test_aten_avg_pool2d_2(self):
     args = (
         torch.randn((1, 3, 6, 6)).to(torch.float32),
@@ -4364,6 +4364,21 @@ def test_aten_where_self_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs)
 
+  def test_aten_copy_dtype(self):
+    args = (
+        torch.ones((3, 3), dtype=torch.int32),
+        torch.zeros((3, 3), dtype=torch.float32),
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.copy_, args, kwargs, check_dtype=True)
+
+  def test_aten_rand_like(self):
+    args = (
+        torch.ones((3, 3), dtype=torch.bfloat16),
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.rand_like, args, kwargs, atol=math.inf, check_dtype=True)
+
 
 if __name__ == "__main__":
   test_base.main()
diff --git a/experimental/torch_xla2/test/test_mutations.py b/experimental/torch_xla2/test/test_mutations.py
index d0c31d5f3ed1..510881431751 100644
--- a/experimental/torch_xla2/test/test_mutations.py
+++ b/experimental/torch_xla2/test/test_mutations.py
@@ -34,16 +34,6 @@ def test_mul(self):
       xt = torch_xla2.tensor.j2t(x._elem)
       self.assertEqual(xt, torch.tensor([4, 10, 18], dtype=torch.int32))
 
-  def test_div(self):
-    with self.env:
-      x = torch.tensor([1, 2, 3], dtype=torch.int32)
-      y = torch.tensor([4, 5, 6], dtype=torch.int32)
-
-      x.div_(y)
-      xt = torch_xla2.tensor.j2t(x._elem)
-      self.assertEqual(xt,
-                       torch.tensor([1. / 4, 2. / 5, 3. / 6], dtype=torch.float))
-
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/experimental/torch_xla2/test_dist/test_distributed.py b/experimental/torch_xla2/test_dist/test_distributed.py
index 92d4322bff06..7875b96dcb62 100644
--- a/experimental/torch_xla2/test_dist/test_distributed.py
+++ b/experimental/torch_xla2/test_dist/test_distributed.py
@@ -73,7 +73,7 @@ def f(index: torch_xla2.tensor.XLATensor2):
     ("op", "expected"),
     [
         (dist.ReduceOp.SUM, sum(range(4))),
-        (dist.ReduceOp.AVG, sum(range(4)) / 4),
+        (dist.ReduceOp.AVG, sum(range(4)) // 4),
         (dist.ReduceOp.MIN, 0),
         (dist.ReduceOp.MAX, 3),
     ],
diff --git a/experimental/torch_xla2/torch_xla2/decompositions.py b/experimental/torch_xla2/torch_xla2/decompositions.py
index 4ef7537ea130..354ac3d93bff 100644
--- a/experimental/torch_xla2/torch_xla2/decompositions.py
+++ b/experimental/torch_xla2/torch_xla2/decompositions.py
@@ -101,7 +101,7 @@ def bernoulli(self, *, generator=None):
 
 
 def rand_like(self, **kwargs):
-  dtype = kwargs.get('dtype')
+  dtype = kwargs.get('dtype', self.dtype)
   return torch.rand(self.shape, dtype=dtype)
 
 def channel_shuffle(self, groups):
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index e094c0bfc792..8c3ead024880 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -82,7 +82,7 @@ def _aten_add(x, y, *, alpha=1):
 
 
 @op(torch.ops.aten.copy_, torch.ops.aten.copy_.default, is_jax_function=False)
 def _aten_copy(x, y, memory_format=None):
-  x._elem = y._elem
+  x._elem = y._elem.astype(x._elem.dtype)
   return x
 
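
For reference, below is a minimal sketch (assuming eager CPU PyTorch; tensor
names are illustrative) of the dtype semantics these changes align torch_xla2
with:

    import torch

    # copy_ keeps the destination dtype and casts the source; jaten's
    # `y._elem.astype(x._elem.dtype)` mirrors this.
    dst = torch.ones((3, 3), dtype=torch.int32)
    src = torch.full((3, 3), 1.9, dtype=torch.float32)
    dst.copy_(src)
    print(dst.dtype)  # torch.int32 (1.9 truncates to 1)

    # rand_like defaults to the input's dtype unless one is passed explicitly,
    # matching the decomposition's `kwargs.get('dtype', self.dtype)`.
    x = torch.ones((3, 3), dtype=torch.bfloat16)
    print(torch.rand_like(x).dtype)                        # torch.bfloat16
    print(torch.rand_like(x, dtype=torch.float32).dtype)   # torch.float32

The same rule explains the two test adjustments: an integer all-reduce AVG now
stays integral (hence `sum(range(4)) // 4`), and the removed `test_div`
asserted an in-place int-to-float promotion that eager PyTorch itself rejects.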