diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index 92d34f96c92f..83420e976080 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -30,7 +30,8 @@ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True): def run_export_and_compare( - testcase, func, args, kwargs, atol=1e-3, rtol=1e-5, equal_nan=True + testcase, func, args, kwargs, atol=1e-3, rtol=1e-5, equal_nan=True, + ignore_indices=False ): with testcase.subTest("torch_eval"): res = func(*args, **kwargs) @@ -42,9 +43,14 @@ def run_export_and_compare( res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2) # import pdb; pdb.set_trace() with testcase.subTest("torch_xla2_diff:" + str(atol)): - diff_output( - testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan - ) + if ignore_indices and isinstance(res, tuple) and len(res) == 2: + diff_output( + testcase, res[0], res2[0], atol=atol, rtol=rtol, equal_nan=equal_nan + ) + else: + diff_output( + testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan + ) class TestCoreAtenOps(unittest.TestCase): @@ -175,7 +181,6 @@ def test_aten_unsqueeze_8(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - @unittest.skip def test_aten__adaptive_avg_pool2d_0(self): args = ( torch.randn((1, 3, 1, 10)).to(torch.float32), @@ -189,7 +194,6 @@ def test_aten__adaptive_avg_pool2d_0(self): self, torch.ops.aten._adaptive_avg_pool2d, args, kwargs ) - @unittest.skip def test_aten__adaptive_avg_pool2d_1(self): args = ( torch.randn((1, 3, 10, 10)).to(torch.float32), @@ -243,7 +247,6 @@ def test_aten_squeeze_dim_4(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - @unittest.skip def test_aten__adaptive_avg_pool3d_0(self): args = ( torch.randn((1, 3, 10, 10, 10)).to(torch.float32), @@ -258,7 +261,6 @@ def test_aten__adaptive_avg_pool3d_0(self): self, torch.ops.aten._adaptive_avg_pool3d, args, kwargs ) - @unittest.skip def test_aten__adaptive_avg_pool3d_1(self): args = ( torch.randn((1, 3, 10, 10, 10)).to(torch.float16), @@ -513,23 +515,23 @@ def test_aten_argmin_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.argmin, args, kwargs) - @unittest.skip def test_aten_as_strided_0(self): args = ( torch.randn((10, 10)).to(torch.float32), [ - 0, - 1, + 2, + 2, + 2 ], [ - 0, + 8, + 4, 1, ], ) kwargs = dict() run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs) - @unittest.skip def test_aten_as_strided_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -545,7 +547,6 @@ def test_aten_as_strided_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs) - @unittest.skip def test_aten_as_strided_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -561,7 +562,6 @@ def test_aten_as_strided_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs) - @unittest.skip def test_aten_as_strided_copy_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -577,7 +577,6 @@ def test_aten_as_strided_copy_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.as_strided_copy, args, kwargs) - @unittest.skip def test_aten_as_strided_copy_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -593,7 +592,6 @@ def test_aten_as_strided_copy_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.as_strided_copy, args, kwargs) - @unittest.skip def test_aten_as_strided_copy_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -685,7 +683,6 @@ def test_aten_atanh_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.atanh, args, kwargs) - @unittest.skip def test_aten_avg_pool2d_0(self): args = ( torch.randn((1, 3, 1, 10)).to(torch.float32), @@ -701,7 +698,6 @@ def test_aten_avg_pool2d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs) - @unittest.skip def test_aten_avg_pool2d_1(self): args = ( torch.randn((3, 2, 10)).to(torch.float32), @@ -721,7 +717,6 @@ def test_aten_avg_pool2d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs) - @unittest.skip def test_aten_avg_pool3d_0(self): args = ( torch.randn((1, 3, 10, 10, 10)).to(torch.float32), @@ -868,11 +863,10 @@ def test_aten_cat_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.cat, args, kwargs) - @unittest.skip def test_aten__cdist_forward_0(self): args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), + torch.randn((5, 7, 10)).to(torch.float32), + torch.randn((5, 8, 10)).to(torch.float32), 1.0, None, ) @@ -1304,19 +1298,16 @@ def test_aten_eq_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.eq.Tensor, args, kwargs) - @unittest.skip def test_aten_erf_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.erf, args, kwargs) - @unittest.skip def test_aten_erf_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.erf, args, kwargs) - @unittest.skip def test_aten_erf_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -1418,7 +1409,6 @@ def test_aten_expm1_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.expm1, args, kwargs) - @unittest.skip def test_aten_fill_Scalar_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1427,7 +1417,6 @@ def test_aten_fill_Scalar_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.fill.Scalar, args, kwargs) - @unittest.skip def test_aten_fill_Scalar_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1436,7 +1425,6 @@ def test_aten_fill_Scalar_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.fill.Scalar, args, kwargs) - @unittest.skip def test_aten_fill_Scalar_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -1445,33 +1433,6 @@ def test_aten_fill_Scalar_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.fill.Scalar, args, kwargs) - @unittest.skip - def test_aten_fill_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn(()).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fill.Tensor, args, kwargs) - - @unittest.skip - def test_aten_fill_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn(()).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fill.Tensor, args, kwargs) - - @unittest.skip - def test_aten_fill_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, ()).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fill.Tensor, args, kwargs) - def test_aten_flip_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1560,7 +1521,6 @@ def test_aten_fmod_Tensor_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.fmod.Tensor, args, kwargs) - @unittest.skip def test_aten_full_like_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1569,7 +1529,6 @@ def test_aten_full_like_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.full_like, args, kwargs) - @unittest.skip def test_aten_full_like_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1578,7 +1537,6 @@ def test_aten_full_like_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.full_like, args, kwargs) - @unittest.skip def test_aten_full_like_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -1587,7 +1545,6 @@ def test_aten_full_like_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.full_like, args, kwargs) - @unittest.skip def test_aten_gather_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1597,7 +1554,6 @@ def test_aten_gather_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.gather, args, kwargs) - @unittest.skip def test_aten_gather_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1607,7 +1563,6 @@ def test_aten_gather_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.gather, args, kwargs) - @unittest.skip def test_aten_gather_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -1633,7 +1588,6 @@ def test_aten_ge_Scalar_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ge.Scalar, args, kwargs) - @unittest.skip def test_aten_ge_Scalar_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -1676,7 +1630,6 @@ def test_aten_gelu_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs, atol=0.01, rtol=0.01) - @unittest.skip def test_aten_glu_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1685,7 +1638,6 @@ def test_aten_glu_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.glu, args, kwargs) - @unittest.skip def test_aten_glu_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1694,7 +1646,6 @@ def test_aten_glu_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.glu, args, kwargs) - @unittest.skip def test_aten_grid_sampler_2d_0(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float32), @@ -2199,7 +2150,6 @@ def test_aten_lt_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.lt.Tensor, args, kwargs) - @unittest.skip def test_aten_masked_fill_Scalar_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -2211,7 +2161,6 @@ def test_aten_masked_fill_Scalar_0(self): self, torch.ops.aten.masked_fill.Scalar, args, kwargs ) - @unittest.skip def test_aten_masked_fill_Scalar_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2223,7 +2172,6 @@ def test_aten_masked_fill_Scalar_1(self): self, torch.ops.aten.masked_fill.Scalar, args, kwargs ) - @unittest.skip def test_aten_masked_fill_Scalar_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -2259,7 +2207,6 @@ def test_aten_max_dim_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.max.dim, args, kwargs) - @unittest.skip def test_aten_max_pool2d_with_indices_0(self): args = ( torch.randn((3, 2, 10)).to(torch.float32), @@ -2278,10 +2225,10 @@ def test_aten_max_pool2d_with_indices_0(self): ) kwargs = dict() run_export_and_compare( - self, torch.ops.aten.max_pool2d_with_indices, args, kwargs + self, torch.ops.aten.max_pool2d_with_indices, args, kwargs, + ignore_indices=True ) - @unittest.skip def test_aten_max_pool2d_with_indices_1(self): args = ( torch.randn((3, 2, 10)).to(torch.float16), @@ -2300,13 +2247,13 @@ def test_aten_max_pool2d_with_indices_1(self): ) kwargs = dict() run_export_and_compare( - self, torch.ops.aten.max_pool2d_with_indices, args, kwargs + self, torch.ops.aten.max_pool2d_with_indices, args, kwargs, + ignore_indices=True ) - @unittest.skip def test_aten_max_pool2d_with_indices_2(self): args = ( - torch.randint(0, 10, (3, 2, 10)).to(torch.int32), + torch.arange(0, 60).reshape(3, 2, 10), [ 2, 2, @@ -2322,10 +2269,10 @@ def test_aten_max_pool2d_with_indices_2(self): ) kwargs = dict() run_export_and_compare( - self, torch.ops.aten.max_pool2d_with_indices, args, kwargs + self, torch.ops.aten.max_pool2d_with_indices, args, kwargs, + ignore_indices=True ) - @unittest.skip def test_aten_max_pool3d_with_indices_0(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float32), @@ -2350,7 +2297,6 @@ def test_aten_max_pool3d_with_indices_0(self): self, torch.ops.aten.max_pool3d_with_indices, args, kwargs ) - @unittest.skip def test_aten_max_pool3d_with_indices_1(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float16), @@ -2375,10 +2321,9 @@ def test_aten_max_pool3d_with_indices_1(self): self, torch.ops.aten.max_pool3d_with_indices, args, kwargs ) - @unittest.skip def test_aten_max_pool3d_with_indices_2(self): args = ( - torch.randint(0, 10, (1, 3, 2, 10)).to(torch.int32), + torch.arange(0, 60).reshape(1, 3, 2, 10), [ 2, 2, @@ -2570,7 +2515,6 @@ def test_aten_mul_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.mul.Tensor, args, kwargs) - @unittest.skip def test_aten__native_batch_norm_legit_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -2587,7 +2531,6 @@ def test_aten__native_batch_norm_legit_0(self): self, torch.ops.aten._native_batch_norm_legit, args, kwargs ) - @unittest.skip def test_aten__native_batch_norm_legit_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2601,15 +2544,15 @@ def test_aten__native_batch_norm_legit_1(self): ) kwargs = dict() run_export_and_compare( - self, torch.ops.aten._native_batch_norm_legit, args, kwargs + self, torch.ops.aten._native_batch_norm_legit, args, kwargs, + atol=0.01, rtol=0.01, ) - @unittest.skip def test_aten__native_batch_norm_legit_no_stats_0(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float32), - torch.randn((1, 3, 2, 10)).to(torch.float32), - torch.randn((1, 3, 2, 10)).to(torch.float32), + torch.randn((1, 3, 1, 1)).to(torch.float32), + torch.randn((1, 3, 1, 1)).to(torch.float32), True, 0.0, 1.0, @@ -2619,12 +2562,11 @@ def test_aten__native_batch_norm_legit_no_stats_0(self): self, torch.ops.aten._native_batch_norm_legit.no_stats, args, kwargs ) - @unittest.skip def test_aten__native_batch_norm_legit_no_stats_1(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float16), - torch.randn((1, 3, 2, 10)).to(torch.float16), - torch.randn((1, 3, 2, 10)).to(torch.float16), + torch.randn((1, 3, 1, 1)).to(torch.float32), + torch.randn((1, 3, 1, 1)).to(torch.float32), True, 0.0, 1.0, @@ -2649,7 +2591,6 @@ def test_aten__native_batch_norm_legit_no_training_0(self): self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs ) - @unittest.skip def test_aten__native_batch_norm_legit_no_training_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2662,30 +2603,30 @@ def test_aten__native_batch_norm_legit_no_training_1(self): ) kwargs = dict() run_export_and_compare( - self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs + self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs, + atol=0.01, rtol=0.01, ) - @unittest.skip def test_aten_native_dropout_0(self): args = ( torch.randn((10, 10)).to(torch.float32), 1.0, - None, + True, ) kwargs = dict() run_export_and_compare(self, torch.ops.aten.native_dropout, args, kwargs) - @unittest.skip def test_aten_native_dropout_1(self): args = ( torch.randn((10, 10)).to(torch.float16), 1.0, - None, + False, ) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.native_dropout, args, kwargs) + run_export_and_compare(self, torch.ops.aten.native_dropout, args, kwargs, + atol=0.01, rtol=0.01 + ) - @unittest.skip def test_aten_native_group_norm_0(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float32), @@ -2700,7 +2641,6 @@ def test_aten_native_group_norm_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs) - @unittest.skip def test_aten_native_group_norm_1(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float16), @@ -2713,9 +2653,10 @@ def test_aten_native_group_norm_1(self): 0.0, ) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs) + run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs, + atol=0.01, rtol=0.01 + ) - @unittest.skip def test_aten_native_layer_norm_0(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float32), @@ -2795,25 +2736,21 @@ def test_aten_neg_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.neg, args, kwargs) - @unittest.skip def test_aten_nonzero_0(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.nonzero, args, kwargs) - @unittest.skip def test_aten_nonzero_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.nonzero, args, kwargs) - @unittest.skip def test_aten_nonzero_2(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.nonzero, args, kwargs) - @unittest.skip def test_aten__pdist_forward_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -2891,7 +2828,6 @@ def test_aten_permute_copy_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.permute_copy, args, kwargs) - @unittest.skip def test_aten_pixel_shuffle_0(self): args = ( torch.randn((1, 3, 10, 10)).to(torch.float32), @@ -2900,7 +2836,6 @@ def test_aten_pixel_shuffle_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.pixel_shuffle, args, kwargs) - @unittest.skip def test_aten_pixel_shuffle_1(self): args = ( torch.randn((1, 3, 10, 10)).to(torch.float16), @@ -2909,7 +2844,6 @@ def test_aten_pixel_shuffle_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.pixel_shuffle, args, kwargs) - @unittest.skip def test_aten_pixel_shuffle_2(self): args = ( torch.randint(0, 10, (1, 3, 10, 10)).to(torch.int32), @@ -2979,19 +2913,16 @@ def test_aten_pow_Tensor_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs) - @unittest.skip def test_aten_prod_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.prod, args, kwargs) - @unittest.skip def test_aten_prod_1(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.prod, args, kwargs) - @unittest.skip def test_aten_prod_dim_int_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3000,7 +2931,6 @@ def test_aten_prod_dim_int_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs) - @unittest.skip def test_aten_prod_dim_int_1(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3009,19 +2939,16 @@ def test_aten_prod_dim_int_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs) - @unittest.skip def test_aten_reciprocal_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) - @unittest.skip def test_aten_reciprocal_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) - @unittest.skip def test_aten_reciprocal_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -3038,7 +2965,6 @@ def test_aten_reflection_pad1d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs) - @unittest.skip def test_aten_reflection_pad1d_1(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3050,7 +2976,6 @@ def test_aten_reflection_pad1d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs) - @unittest.skip def test_aten_reflection_pad2d_0(self): args = ( torch.randn((3, 2, 10)).to(torch.float32), @@ -3064,7 +2989,6 @@ def test_aten_reflection_pad2d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs) - @unittest.skip def test_aten_reflection_pad2d_1(self): args = ( torch.randint(0, 10, (3, 2, 10)).to(torch.int32), @@ -3078,10 +3002,9 @@ def test_aten_reflection_pad2d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs) - @unittest.skip def test_aten_reflection_pad3d_0(self): args = ( - torch.randn((3, 3, 3, 3, 3, 3)).to(torch.float32), + torch.randn((3, 3, 3, 3)).to(torch.float32), [ 1, 2, @@ -3094,10 +3017,9 @@ def test_aten_reflection_pad3d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) - @unittest.skip def test_aten_reflection_pad3d_1(self): args = ( - torch.randn((3, 3, 3, 3, 3, 3)).to(torch.float16), + torch.randn((3, 3, 3, 3, 3)).to(torch.float16), [ 1, 2, @@ -3110,10 +3032,9 @@ def test_aten_reflection_pad3d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) - @unittest.skip def test_aten_reflection_pad3d_2(self): args = ( - torch.randint(0, 10, (3, 3, 3, 3, 3, 3)).to(torch.int32), + torch.randint(0, 10, (3, 3, 3, 3)).to(torch.int32), [ 1, 2, @@ -3141,7 +3062,6 @@ def test_aten_relu_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.relu, args, kwargs) - @unittest.skip def test_aten_remainder_Scalar_0(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3150,7 +3070,6 @@ def test_aten_remainder_Scalar_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.remainder.Scalar, args, kwargs) - @unittest.skip def test_aten_remainder_Scalar_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3159,7 +3078,6 @@ def test_aten_remainder_Scalar_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.remainder.Scalar, args, kwargs) - @unittest.skip def test_aten_remainder_Scalar_2(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3168,7 +3086,6 @@ def test_aten_remainder_Scalar_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.remainder.Scalar, args, kwargs) - @unittest.skip def test_aten_remainder_Tensor_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3177,7 +3094,6 @@ def test_aten_remainder_Tensor_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.remainder.Tensor, args, kwargs) - @unittest.skip def test_aten_remainder_Tensor_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3186,7 +3102,6 @@ def test_aten_remainder_Tensor_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.remainder.Tensor, args, kwargs) - @unittest.skip def test_aten_replication_pad2d_0(self): args = ( torch.randn((3, 2, 10)).to(torch.float32), @@ -3200,7 +3115,6 @@ def test_aten_replication_pad2d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.replication_pad2d, args, kwargs) - @unittest.skip def test_aten_replication_pad2d_1(self): args = ( torch.randint(0, 10, (3, 2, 10)).to(torch.int32), @@ -3214,7 +3128,6 @@ def test_aten_replication_pad2d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.replication_pad2d, args, kwargs) - @unittest.skip def test_aten_replication_pad3d_0(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float32), @@ -3230,7 +3143,6 @@ def test_aten_replication_pad3d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.replication_pad3d, args, kwargs) - @unittest.skip def test_aten_replication_pad3d_1(self): args = ( torch.randint(0, 10, (1, 3, 2, 10)).to(torch.int32), @@ -3246,46 +3158,6 @@ def test_aten_replication_pad3d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.replication_pad3d, args, kwargs) - @unittest.skip - def test_aten_resize__0(self): - args = ( - torch.randn((2, 5, 10)).to(torch.float32), - [ - 2, - 5, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.resize_, args, kwargs) - - @unittest.skip - def test_aten_resize__1(self): - args = ( - torch.randn((2, 5, 10)).to(torch.float16), - [ - 2, - 5, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.resize_, args, kwargs) - - @unittest.skip - def test_aten_resize__2(self): - args = ( - torch.randint(0, 10, (2, 5, 10)).to(torch.int32), - [ - 2, - 5, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.resize_, args, kwargs) - - @unittest.skip def test_aten_roll_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3301,7 +3173,6 @@ def test_aten_roll_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.roll, args, kwargs) - @unittest.skip def test_aten_roll_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3317,7 +3188,6 @@ def test_aten_roll_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.roll, args, kwargs) - @unittest.skip def test_aten_roll_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3333,19 +3203,16 @@ def test_aten_roll_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.roll, args, kwargs) - @unittest.skip def test_aten_round_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.round, args, kwargs) - @unittest.skip def test_aten_round_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.round, args, kwargs) - @unittest.skip def test_aten_round_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -3519,7 +3386,6 @@ def test_aten_scatter_value_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.scatter.value, args, kwargs) - @unittest.skip def test_aten_select_copy_int_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3529,7 +3395,6 @@ def test_aten_select_copy_int_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.select_copy.int, args, kwargs) - @unittest.skip def test_aten_select_copy_int_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3539,7 +3404,6 @@ def test_aten_select_copy_int_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.select_copy.int, args, kwargs) - @unittest.skip def test_aten_select_copy_int_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3606,37 +3470,31 @@ def test_aten_select_scatter_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) - @unittest.skip def test_aten_sigmoid_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) - @unittest.skip def test_aten_sigmoid_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) - @unittest.skip def test_aten_sigmoid_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) - @unittest.skip def test_aten_sign_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.sign, args, kwargs) - @unittest.skip def test_aten_sign_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.sign, args, kwargs) - @unittest.skip def test_aten_sign_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -3672,19 +3530,16 @@ def test_aten_sinh_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.sinh, args, kwargs) - @unittest.skip def test_aten_slice_copy_Tensor_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.slice_copy.Tensor, args, kwargs) - @unittest.skip def test_aten_slice_copy_Tensor_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.slice_copy.Tensor, args, kwargs) - @unittest.skip def test_aten_slice_copy_Tensor_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -3790,7 +3645,6 @@ def test_aten_sort_2(self): ) self._compare_sorted_result(args) - @unittest.skip def test_aten_split_copy_Tensor_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3799,7 +3653,6 @@ def test_aten_split_copy_Tensor_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.split_copy.Tensor, args, kwargs) - @unittest.skip def test_aten_split_copy_Tensor_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3808,7 +3661,6 @@ def test_aten_split_copy_Tensor_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.split_copy.Tensor, args, kwargs) - @unittest.skip def test_aten_split_copy_Tensor_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3817,7 +3669,6 @@ def test_aten_split_copy_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.split_copy.Tensor, args, kwargs) - @unittest.skip def test_aten_split_with_sizes_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3831,7 +3682,6 @@ def test_aten_split_with_sizes_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.split_with_sizes, args, kwargs) - @unittest.skip def test_aten_split_with_sizes_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3845,7 +3695,6 @@ def test_aten_split_with_sizes_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.split_with_sizes, args, kwargs) - @unittest.skip def test_aten_split_with_sizes_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -4161,19 +4010,16 @@ def test_aten_transpose_copy_int_2(self): self, torch.ops.aten.transpose_copy.int, args, kwargs ) - @unittest.skip def test_aten_tril_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.tril, args, kwargs) - @unittest.skip def test_aten_tril_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.tril, args, kwargs) - @unittest.skip def test_aten_tril_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() diff --git a/experimental/torch_xla2/torch_xla2/ops.py b/experimental/torch_xla2/torch_xla2/ops.py index cb9b57cdac5f..aa0643c61bc4 100644 --- a/experimental/torch_xla2/torch_xla2/ops.py +++ b/experimental/torch_xla2/torch_xla2/ops.py @@ -2,6 +2,7 @@ """Torch ops implemented using jax.""" import sys +import flax import jax from jax import numpy as jnp import numpy as np @@ -107,6 +108,7 @@ def _aten_index_copy(x, dim, indexes, source): @op(torch.ops.aten.select) @op(torch.ops.aten.index_select) +@op(torch.ops.aten.select_copy) def _aten_index_select(x, dim, indexes): dims = [] for i in range(len(x.shape)): @@ -179,6 +181,7 @@ def _aten_triu(m, k): @op(torch.ops.aten.slice) +@op(torch.ops.aten.slice_copy) def _aten_slice(self, dim=0, start=None, end=None, step=1): if end == sys.maxsize: end = self.shape[dim] @@ -311,8 +314,9 @@ def _aten_index(self, indexes): @op(torch.ops.aten.split) +@op(torch.ops.aten.split_copy) @op(torch.ops.aten.split_with_sizes) -def split_with_sizes(x, sizes, dim): +def split_with_sizes(x, sizes, dim=0): """Splits an array `x` into sub-arrays based on static sizes `sizes`. Args: @@ -509,6 +513,14 @@ def create_default_conv_dimension_numbers(num_spatial_dims): res = res + bias return res +# _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) +@op(torch.ops.aten._native_batch_norm_legit) +def _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps +): + return _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) @op(torch.ops.aten._native_batch_norm_legit_no_training) def _aten__native_batch_norm_legit_no_training( @@ -542,59 +554,72 @@ def _aten_cat(tensors, dims=0): @op(torch.ops.aten.max_pool2d_with_indices) +@op(torch.ops.aten.max_pool3d_with_indices) def _aten_max_pool2d_with_indices( - self, kernel_size, stride, padding=0, dilation=1, ceil_mode=False + inputs, kernel_size, strides, padding=0, dilation=1, ceil_mode=False ): - stride = stride if stride else [1, 1] - if not isinstance(padding, (list, tuple)): - padding = [padding, padding] - - def build_ceil_mode_padding(): - ceil_mode_padding = [(0, 0), (0, 0)] - for i in range(len(padding)): - left_padding = padding[0] - input_size = self.shape[2 + i] - output_size_rem = ( - input_size + 2 * left_padding - kernel_size[i] - ) % stride[i] - right_padding = left_padding - if ceil_mode and output_size_rem != 0: - extra_padding = stride[i] - output_size_rem - new_output_size = ( - input_size - + left_padding - + right_padding - + extra_padding - - kernel_size[i] - + stride[i] - - 1 - ) // stride[i] + 1 - if (new_output_size - 1) * stride[i] < input_size + left_padding: - right_padding += extra_padding - ceil_mode_padding.append((left_padding, right_padding)) - return ceil_mode_padding - - ceil_mode_padding = build_ceil_mode_padding() - if not all([p == (0, 0) for p in ceil_mode_padding]): - self = jnp.pad( - self, - ceil_mode_padding, - "constant", - constant_values=-jnp.inf, + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + strides = tuple(strides) + if isinstance(padding, int): + padding = tuple((padding, padding) for _ in range(len(kernel_size))) + elif isinstance(padding, list): + padding = tuple((p, p) for p in padding) + + window_shape = kernel_size + num_batch_dims = inputs.ndim - (len(window_shape) + 1) + strides = strides or (1,) * len(window_shape) + assert len(window_shape) == len( + strides + ), f'len({window_shape}) must equal len({strides})' + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + window_shape + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. + inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + is_single_input = True + + assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})' + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(window_shape), ( + f'padding {padding} must specify pads for same number of dims as ' + f'window_shape {window_shape}' ) - batch_result = jax.lax.reduce_window( - self, - -jnp.inf, - jax.lax.max, - window_dimensions=[1, 1] + kernel_size, - window_strides=[1, 1] + stride, - padding="VALID", - ) - - # TODO: compute indices from batch_result - # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/pooling.cpp#L259 - - return batch_result, None + assert all( + [len(x) == 2 for x in padding] + ), f'each entry in padding {padding} must be length 2' + padding = ((0, 0),(0, 0)) + padding + + indices = jnp.arange(np.prod(inputs.shape)).reshape(inputs.shape) + def reduce_fn(a, b): + ai, av = a + bi, bv = b + which = av > bv + return jnp.where(which, ai, bi), jnp.where(which, av, bv) + init_val = -jnp.inf + if inputs.dtype in (jnp.int32, jnp.int64): + init_val = -(1<<31) + init_val = jnp.array(init_val).astype(inputs.dtype) + + indices, y = jax.lax.reduce_window((indices, inputs), (0, init_val), + reduce_fn, dims, strides, padding) + if is_single_input: + indices = jnp.squeeze(indices, axis=0) + y = jnp.squeeze(y, axis=0) + return y, indices + + + batch_result = pool(inputs, -jnp.inf, jax.lax.max, + kernel_size, strides, padding) + indices = pool(inputs, 0, jnp.argmax, + kernel_size, strides, padding) + return batch_result, indices # TODO add more ops @@ -878,8 +903,16 @@ def _aten_scatter_add(input, dim, index, src): # aten.logical_not # aten.sign -# aten.sigmoid +@op(torch.ops.aten.sign) +def _aten_sign(x): + return jnp.sign(x) +# aten.sigmoid +@op(torch.ops.aten.sigmoid) +def _aten_sigmoid(x): + if x.dtype in (jnp.int32, jnp.int64): + x = x.astype(jnp.float32) + return jax.nn.sigmoid(x) # implement aten.asinh in jax @op(torch.ops.aten.asinh) @@ -930,6 +963,34 @@ def _aten_gt(self, other): # aten.pixel_shuffle +@op(torch.ops.aten.pixel_shuffle) +def _aten_pixel_shuffle(x, upscale_factor): + """PixelShuffle implementation in JAX. + + Args: + x: Input tensor. Typically a feature map. + upscale_factor: Integer by which to upscale the spatial dimensions. + + Returns: + Tensor after PixelShuffle operation. + """ + + batch_size, channels, height, width = x.shape + + if channels % (upscale_factor ** 2) != 0: + raise ValueError('Number of channels must be divisible by the square of the upscale factor.') + + new_channels = channels // (upscale_factor ** 2) + new_height = height * upscale_factor + new_width = width * upscale_factor + + x = x.reshape(batch_size, new_channels, upscale_factor, upscale_factor, height, width) + x = jnp.transpose(x, (0, 1, 2, 4, 3, 5)) # Move channels to spatial dimensions + x = x.reshape(batch_size, new_channels, new_height, new_width) + + return x + + # aten.sym_stride # aten.lt @op(torch.ops.aten.lt) @@ -937,9 +998,133 @@ def _aten_lt(self, other): return self < other +def pool(inputs, init, reduce_fn, window_shape, strides, padding): + """Helper function to define pooling functions. + + Pooling functions are implemented using the ReduceWindow XLA op. + NOTE: Be aware that pooling is not generally differentiable. + That means providing a reduce_fn that is differentiable does not imply that + pool is differentiable. + + Args: + inputs: input data with dimensions (batch, window dims..., features). + init: the initial value for the reduction + reduce_fn: a reduce function of the form ``(T, T) -> T``. + window_shape: a shape tuple defining the window to reduce over. + strides: a sequence of ``n`` integers, representing the inter-window + strides (default: ``(1, ..., 1)``). + padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence + of ``n`` ``(low, high)`` integer pairs that give the padding to apply before + and after each spatial dimension. + Returns: + The output of the reduction for each window slice. + """ + num_batch_dims = inputs.ndim - (len(window_shape) + 1) + strides = strides or (1,) * len(window_shape) + assert len(window_shape) == len( + strides + ), f'len({window_shape}) must equal len({strides})' + strides = (1,) * (1 + num_batch_dims) + strides + dims = (1,) * (1 + num_batch_dims) + window_shape + + is_single_input = False + if num_batch_dims == 0: + # add singleton batch dimension because lax.reduce_window always + # needs a batch dimension. + inputs = inputs[None] + strides = (1,) + strides + dims = (1,) + dims + is_single_input = True + + assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})' + if not isinstance(padding, str): + padding = tuple(map(tuple, padding)) + assert len(padding) == len(window_shape), ( + f'padding {padding} must specify pads for same number of dims as ' + f'window_shape {window_shape}' + ) + assert all( + [len(x) == 2 for x in padding] + ), f'each entry in padding {padding} must be length 2' + padding = ((0, 0),(0, 0)) + padding + y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) + if is_single_input: + y = jnp.squeeze(y, axis=0) + return y + +@op(torch.ops.aten._adaptive_avg_pool3d) +def _aten_adaptive_avg_pool3d(x, output_shape): + return _aten_adaptive_avg_pool(x, output_shape, 3) + +@op(torch.ops.aten._adaptive_avg_pool2d) +def _aten_adaptive_avg_pool3d(x, output_shape): + return _aten_adaptive_avg_pool(x, output_shape, 2) + +def _aten_adaptive_avg_pool(x, output_shape, pool_dim): + def adaptive_kernel_size(input_shape, output_shape): + sizes = [1, 1] + spatial_dim_off = len(input_shape) - pool_dim + for spatial_dim in range(pool_dim): + sizes.append( + input_shape[spatial_dim_off + spatial_dim] // output_shape[spatial_dim] + ) + return tuple(sizes) + kernel_sizes = adaptive_kernel_size(x.shape, output_shape) + y = pool(x, 0.0, jax.lax.add, kernel_sizes, kernel_sizes, + padding='VALID') + + div_shape = list(x.shape) + num_batch_dims = len(x.shape) - pool_dim - 1 + div_shape[num_batch_dims] = 1 + div_shape = tuple(div_shape) + if len(div_shape) - 2 == len(kernel_sizes): + div_shape = (1,) + div_shape[1:] + y = y / pool( + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_sizes, kernel_sizes, 'VALID' + ) + return y + + + + + # aten.avg_pool2d +@op(torch.ops.aten.avg_pool2d) +@op(torch.ops.aten.avg_pool3d) +def _aten_avg_pool( + inputs, kernel_size, strides=None, padding=0, + ceil_mode=False, count_include_pad=True, + divisor_override=None): + + num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 + kernel_size = tuple(kernel_size) + strides = tuple(strides) + if isinstance(padding, int): + padding = tuple((padding, padding) for _ in range(len(kernel_size))) + elif isinstance(padding, list): + padding = tuple((p, p) for p in padding) + + y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) + if count_include_pad: + y = y / np.prod(kernel_size) + else: + div_shape = list(inputs.shape) + div_shape[num_batch_dims] = 1 + div_shape = tuple(div_shape) + if len(div_shape) - 2 == len(kernel_size): + div_shape = (1,) + div_shape[1:] + y = y / pool( + jnp.ones(div_shape), 0.0, jax.lax.add, kernel_size, strides, padding + ) + return y + + # aten.sym_numel # aten.reciprocal +@op(torch.ops.aten.reciprocal) +def _aten_reciprocal(a): + return 1 / a + # aten.scatter @op(torch.ops.aten.select_scatter) def _aten_select_scatter(input, src, dim, index): @@ -973,7 +1158,9 @@ def _aten_acosh(self): # aten.col2im # aten.avg_pool3d # aten.round - +@op(torch.ops.aten.round) +def _aten_round(input, decimals=0): + return jnp.round(input, decimals) # aten.max @op(torch.ops.aten.max) @@ -1036,6 +1223,18 @@ def _aten_argmax(self, dim=None, keepdim=False): # aten.as_strided +@op(torch.ops.aten.as_strided) +@op(torch.ops.aten.as_strided_copy) +def _aten_as_strided(x, sizes, strides, storage_offset=None): + ind = jnp.zeros(sizes, dtype=jnp.int32) + + for i, (size, stride) in enumerate(zip(sizes, strides)): + result_shape = (1, ) * i + (size, ) + (1, ) * (len(sizes) - i - 1) + indexes = (jnp.arange(size) * stride).reshape(result_shape) + ind += indexes + + return jnp.ravel(x)[ind] + # aten.atan2 @@ -1089,6 +1288,23 @@ def _aten_copy(x): return jnp.copy(x) +@op(torch.ops.aten._cdist_forward) +def _aten_cdist_forward(x1, x2, p, compute_mode=''): + # x1 is B x P x M + # x2 is B x Q x M + # res is B x P x Q + x1 = jnp.expand_dims(x1, len(x1.shape)-1) + x2 = jnp.expand_dims(x2, len(x2.shape)-2) + return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) + +@op(torch.ops.aten._pdist_forward) +def _aten__pdist_forward(x, p): + pairwise_dists = _aten_cdist_forward(x, x, p) + condensed_dists = pairwise_dists[jnp.triu_indices(pairwise_dists.shape[0], k=1)] + return condensed_dists + + + # aten.cos @op(torch.ops.aten.cos) def _aten_cos(input): @@ -1115,6 +1331,13 @@ def _aten_eq(input1, input2): # aten.erf +@op(torch.ops.aten.erf) +def _aten_erf(x): + if x.dtype in (jnp.int32, jnp.int64): + x = x.astype(jnp.float32) + return jax.lax.erf(x) + + # aten.exp @op(torch.ops.aten.exp) def _aten_exp(input): @@ -1128,6 +1351,15 @@ def _aten_expm1(input): # aten.fill +@op(torch.ops.aten.fill) +@op(torch.ops.aten.full_like) +def _aten_fill(x, value, dtype=None, pin_memory=None, memory_format=None): + if dtype is None: + dtype = x.dtype + else: + dtype = tensor.t2j_dtype(dtype) + return jnp.full(x.shape, value, dtype) + # aten.flip @op(torch.ops.aten.flip) def _aten_flip(input, dims): @@ -1149,11 +1381,21 @@ def _aten_fmod(input, other): return input - other*_aten_div(input, other, 'trunc') # aten.gather +@op(torch.ops.aten.gather) +def _aten_gather(input, dim, index): + input_indexes, source_indexes = _scatter_index(dim, index) + return input[input_indexes] + # aten.ge @op(torch.ops.aten.ge) def _aten_ge(self, other): return self >= other +@op(torch.ops.aten.glu) +@op(torch.ops.aten.glu.default) +def _aten_glu(x, dim=-1): + return jax.nn.glu(x, dim) + # aten.hardtanh @op(torch.ops.aten.hardtanh) @@ -1237,17 +1479,39 @@ def _aten_logical_xor(self, other): def _aten_neg(x): return -1 * x # aten.nonzero +@op(torch.ops.aten.nonzero) +def _aten_nonzero(x): + index_tuple = jnp.nonzero(x) + index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] + return jnp.concatenate(index_tuple, axis=-1) + # aten.prod +@op(torch.ops.aten.prod) +def _aten_prod(self, dim=None, keepdim=False): + return jnp.prod(self, axis=dim, keepdims=keepdim) + # aten.rand # aten.randn # aten.randperm # aten.reflection_pad3d # aten.remainder +@op(torch.ops.aten.remainder) +def _aten_remainder(inputs, other): + return inputs % other + # aten.repeat +@op(torch.ops.aten.repeat) +def _aten_repeat(x, reps): + return jnp.tile(x, reps) + # aten.replication_pad2d # aten.replication_pad3d # aten.roll +@op(torch.ops.aten.roll) +def _aten_roll(input, shifts, dims=None): + return jnp.roll(input, shifts, dims) + # aten.scalar_tensor # aten.slice_scatter @op(torch.ops.aten.slice_scatter) @@ -1261,8 +1525,6 @@ def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): return input.at[tuple(input_index)].set(src) - - # aten.sort # torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) @op(torch.ops.aten.sort) @@ -1350,10 +1612,17 @@ def _aten_where(condition, x, y): # aten.to.dtype +#Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None @op(torch.ops.aten.to.dtype) -def _aten_to_dtype(a, dtype): +def _aten_to_dtype(a, dtype, non_blocking=False, copy=False, memory_format=None): jaxdtype = tensor.t2j_dtype(dtype) return a.astype(jaxdtype) # aten.to.device + +#Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False +@op(torch.ops.aten.var_mean.correction) +def _aten_var_mean_correction(self, dim=None, correction=None, keepdim=False): + return (jnp.var(self, axis=dim, ddof=correction, keepdims=keepdim), + jnp.mean(self, dim, keepdims=keepdim)) diff --git a/experimental/torch_xla2/torch_xla2/ops_registry.py b/experimental/torch_xla2/torch_xla2/ops_registry.py index 7915c360783b..3532b80525d1 100644 --- a/experimental/torch_xla2/torch_xla2/ops_registry.py +++ b/experimental/torch_xla2/torch_xla2/ops_registry.py @@ -1,18 +1,26 @@ import torch +import torch._decomp as decomp class LoweringRegistry: def __init__(self): self.registered_ops = {} + self.decomps = {} def lookup(self, op_or_name): - candidate = self.registered_ops.get(op_or_name) + candidate = self._lookup(op_or_name) if candidate is None: if isinstance(op_or_name, torch._ops.OpOverloadPacket): - candidate = self.registered_ops.get(op_or_name.default) + candidate = self._lookup(op_or_name.default) if isinstance(op_or_name, torch._ops.OpOverload): - candidate = self.registered_ops.get(op_or_name.overloadpacket) + candidate = self._lookup(op_or_name.overloadpacket) + return candidate + + def _lookup(self, op): + candidate = self.registered_ops.get(op) + if candidate is None: + candidate = self.decomp.get(op) return candidate def register(self, op, lowering): @@ -20,6 +28,17 @@ def register(self, op, lowering): lowerings = LoweringRegistry() +EXTRA_DECOMP = decomp.get_decompositions([ + torch.ops.aten.upsample_nearest2d, + torch.ops.aten._native_batch_norm_legit.no_stats, + torch.ops.aten._adaptive_avg_pool2d, + torch.ops.aten._adaptive_avg_pool3d, + torch.ops.aten.grid_sampler_2d, + torch.ops.aten.native_dropout, +]) +CORE_ATEN_DECOMP = decomp.core_aten_decompositions() +CORE_ATEN_DECOMP.update(EXTRA_DECOMP) +lowerings.decomp = CORE_ATEN_DECOMP def _all_core_ops(): diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 6292a37e6fcf..77a2ad7d3281 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -120,11 +120,6 @@ def move_to_device(t): return XLATensor2(t2j(t)) -EXTRA_DECOMP = decomp.get_decompositions([torch.ops.aten.upsample_nearest2d]) -CORE_ATEN_DECOMP = decomp.core_aten_decompositions() -CORE_ATEN_DECOMP.update(EXTRA_DECOMP) - - class XLATensor2(torch.Tensor): @staticmethod @@ -189,14 +184,12 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): else: print(" ", a) lowering = ops_registry.lowerings.lookup(func) + if lowering is None: - if func in CORE_ATEN_DECOMP: - with XLADispatchMode(): - return CORE_ATEN_DECOMP[func](*args, **kwargs) - else: - print(func.name(), func.tags) raise RuntimeError("No lowering found for", func.name()) - res = lowering(*args, **kwargs) + + with XLADispatchMode(): + res = lowering(*args, **kwargs) print("output:") for a in torch_pytree.tree_flatten(res)[0]: if isinstance(a, XLATensor2):