diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index dcb3f7bf00c..92d34f96c92 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -103,7 +103,6 @@ def test_aten_acosh_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.acosh, args, kwargs) - @unittest.skip def test_aten_unsqueeze_0(self): args = ( torch.randn((1, 3, 10)).to(torch.float32), @@ -112,7 +111,6 @@ def test_aten_unsqueeze_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - @unittest.skip def test_aten_unsqueeze_1(self): args = ( torch.randn((1, 3, 10)).to(torch.float16), @@ -121,7 +119,6 @@ def test_aten_unsqueeze_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - @unittest.skip def test_aten_unsqueeze_2(self): args = ( torch.randint(0, 10, (1, 3, 10)).to(torch.int32), @@ -130,7 +127,6 @@ def test_aten_unsqueeze_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - @unittest.skip def test_aten_unsqueeze_3(self): args = ( torch.randn((1, 3, 10)).to(torch.float32), @@ -139,7 +135,6 @@ def test_aten_unsqueeze_3(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - @unittest.skip def test_aten_unsqueeze_4(self): args = ( torch.randn((1, 3, 10)).to(torch.float16), @@ -148,7 +143,6 @@ def test_aten_unsqueeze_4(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - @unittest.skip def test_aten_unsqueeze_5(self): args = ( torch.randint(0, 10, (1, 3, 10)).to(torch.int32), @@ -209,7 +203,6 @@ def test_aten__adaptive_avg_pool2d_1(self): self, torch.ops.aten._adaptive_avg_pool2d, args, kwargs ) - @unittest.skip def test_aten_squeeze_dim_0(self): args = ( torch.randn((1, 3, 1, 5)).to(torch.float32), @@ -218,7 +211,6 @@ def test_aten_squeeze_dim_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - @unittest.skip def test_aten_squeeze_dim_1(self): args = ( torch.randn((1, 3, 1, 5)).to(torch.float32), @@ -227,7 +219,6 @@ def test_aten_squeeze_dim_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - @unittest.skip def test_aten_squeeze_dim_2(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -236,7 +227,6 @@ def test_aten_squeeze_dim_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - @unittest.skip def test_aten_squeeze_dim_3(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -245,7 +235,6 @@ def test_aten_squeeze_dim_3(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - @unittest.skip def test_aten_squeeze_dim_4(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -332,7 +321,6 @@ def test_aten_add_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.add.Tensor, args, kwargs) - @unittest.skip def test_aten_addmm_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -342,7 +330,6 @@ def test_aten_addmm_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.addmm, args, kwargs) - @unittest.skip def test_aten_addmm_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -350,9 +337,8 @@ def test_aten_addmm_1(self): torch.randn((10, 10)).to(torch.float16), ) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.addmm, args, kwargs) + run_export_and_compare(self, torch.ops.aten.addmm, args, kwargs, atol=0.001, rtol=0.001) - @unittest.skip def test_aten_addmm_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -663,7 +649,6 @@ def test_aten_atan_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.atan, args, kwargs) - @unittest.skip def test_aten_atan_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -829,7 +814,6 @@ def test_aten_bitwise_xor_Scalar_0(self): self, torch.ops.aten.bitwise_xor.Scalar, args, kwargs ) - @unittest.skip def test_aten_bmm_0(self): args = ( torch.randn((10, 10, 10)).to(torch.float32), @@ -838,7 +822,6 @@ def test_aten_bmm_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.bmm, args, kwargs) - @unittest.skip def test_aten_bmm_1(self): args = ( torch.randn((10, 10, 10)).to(torch.float16), @@ -847,7 +830,6 @@ def test_aten_bmm_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.bmm, args, kwargs) - @unittest.skip def test_aten_bmm_2(self): args = ( torch.randint(0, 10, (10, 10, 10)).to(torch.int32), @@ -856,7 +838,6 @@ def test_aten_bmm_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.bmm, args, kwargs) - @unittest.skip def test_aten_cat_0(self): args = ( [ @@ -867,7 +848,6 @@ def test_aten_cat_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.cat, args, kwargs) - @unittest.skip def test_aten_cat_1(self): args = ( [ @@ -878,7 +858,6 @@ def test_aten_cat_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.cat, args, kwargs) - @unittest.skip def test_aten_cat_2(self): args = ( [ @@ -1040,7 +1019,6 @@ def test_aten_convolution_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs) - @unittest.skip def test_aten_convolution_1(self): args = ( torch.randn((3, 2, 10)).to(torch.float16), @@ -1117,7 +1095,6 @@ def test_aten_cosh_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.cosh, args, kwargs) - @unittest.skip def test_aten_cumsum_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1126,16 +1103,14 @@ def test_aten_cumsum_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs) - @unittest.skip def test_aten_cumsum_1(self): args = ( torch.randn((10, 10)).to(torch.float16), 1, ) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs) + run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs, atol=1e-2, rtol=1e-3) - @unittest.skip def test_aten_cumsum_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -1144,19 +1119,16 @@ def test_aten_cumsum_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs) - @unittest.skip def test_aten_diagonal_0(self): args = (torch.randn((10, 20)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.diagonal, args, kwargs) - @unittest.skip def test_aten_diagonal_1(self): args = (torch.randn((10, 20)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.diagonal, args, kwargs) - @unittest.skip def test_aten_diagonal_2(self): args = (torch.randint(0, 10, (10, 20)).to(torch.int32),) kwargs = dict() @@ -1186,7 +1158,6 @@ def test_aten_div_Scalar_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.div.Scalar, args, kwargs) - @unittest.skip def test_aten_div_Scalar_mode_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1197,7 +1168,6 @@ def test_aten_div_Scalar_mode_0(self): } run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs) - @unittest.skip def test_aten_div_Scalar_mode_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1206,7 +1176,7 @@ def test_aten_div_Scalar_mode_1(self): kwargs = { "rounding_mode": "trunc", } - run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs) + run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs, rtol=0.1) def test_aten_div_Scalar_mode_2(self): args = ( @@ -1242,7 +1212,6 @@ def test_aten_div_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs) - @unittest.skip def test_aten_div_Tensor_mode_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1253,7 +1222,6 @@ def test_aten_div_Tensor_mode_0(self): } run_export_and_compare(self, torch.ops.aten.div.Tensor_mode, args, kwargs) - @unittest.skip def test_aten_div_Tensor_mode_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1402,7 +1370,6 @@ def test_aten_expand_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.expand, args, kwargs) - @unittest.skip def test_aten_expand_copy_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1414,7 +1381,6 @@ def test_aten_expand_copy_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.expand_copy, args, kwargs) - @unittest.skip def test_aten_expand_copy_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1426,7 +1392,6 @@ def test_aten_expand_copy_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.expand_copy, args, kwargs) - @unittest.skip def test_aten_expand_copy_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -1555,25 +1520,6 @@ def test_aten_floor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.floor, args, kwargs) - @unittest.skip - def test_aten_floor_divide_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.floor_divide, args, kwargs) - - @unittest.skip - def test_aten_floor_divide_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.floor_divide, args, kwargs) - - @unittest.skip def test_aten_fmod_Scalar_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1582,16 +1528,14 @@ def test_aten_fmod_Scalar_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.fmod.Scalar, args, kwargs) - @unittest.skip def test_aten_fmod_Scalar_1(self): args = ( torch.randn((10, 10)).to(torch.float16), 0.123, ) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fmod.Scalar, args, kwargs) + run_export_and_compare(self, torch.ops.aten.fmod.Scalar, args, kwargs, rtol=0.1, atol=0.2) - @unittest.skip def test_aten_fmod_Scalar_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -1600,7 +1544,6 @@ def test_aten_fmod_Scalar_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.fmod.Scalar, args, kwargs) - @unittest.skip def test_aten_fmod_Tensor_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1609,7 +1552,6 @@ def test_aten_fmod_Tensor_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.fmod.Tensor, args, kwargs) - @unittest.skip def test_aten_fmod_Tensor_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1675,7 +1617,6 @@ def test_aten_gather_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.gather, args, kwargs) - @unittest.skip def test_aten_ge_Scalar_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1684,7 +1625,6 @@ def test_aten_ge_Scalar_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ge.Scalar, args, kwargs) - @unittest.skip def test_aten_ge_Scalar_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1702,7 +1642,6 @@ def test_aten_ge_Scalar_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ge.Scalar, args, kwargs) - @unittest.skip def test_aten_ge_Tensor_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1711,7 +1650,6 @@ def test_aten_ge_Tensor_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ge.Tensor, args, kwargs) - @unittest.skip def test_aten_ge_Tensor_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -1720,7 +1658,6 @@ def test_aten_ge_Tensor_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ge.Tensor, args, kwargs) - @unittest.skip def test_aten_ge_Tensor_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -1729,17 +1666,15 @@ def test_aten_ge_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ge.Tensor, args, kwargs) - @unittest.skip def test_aten_gelu_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs) - @unittest.skip def test_aten_gelu_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs) + run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs, atol=0.01, rtol=0.01) @unittest.skip def test_aten_glu_0(self): @@ -1903,10 +1838,9 @@ def test_aten_index_select_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) - @unittest.skip def test_aten_index_Tensor_0(self): args = ( - torch.randn((2, 10)).to(torch.float32), + torch.randn((10, 10)).to(torch.float32), [ torch.randint(0, 10, (2,)).to(torch.int64), ], @@ -1914,10 +1848,9 @@ def test_aten_index_Tensor_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.index.Tensor, args, kwargs) - @unittest.skip def test_aten_index_Tensor_1(self): args = ( - torch.randn((2, 10)).to(torch.float16), + torch.randn((10, 10)).to(torch.float16), [ torch.randint(0, 10, (2,)).to(torch.int64), ], @@ -1925,10 +1858,9 @@ def test_aten_index_Tensor_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.index.Tensor, args, kwargs) - @unittest.skip def test_aten_index_Tensor_2(self): args = ( - torch.randint(0, 10, (2, 10)).to(torch.int32), + torch.randint(0, 10, (10, 10)).to(torch.int32), [ torch.randint(0, 10, (2,)).to(torch.int64), ], @@ -2014,7 +1946,6 @@ def test_aten_le_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.le.Tensor, args, kwargs) - @unittest.skip def test_aten_leaky_relu_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -2023,7 +1954,6 @@ def test_aten_leaky_relu_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.leaky_relu, args, kwargs) - @unittest.skip def test_aten_leaky_relu_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2032,19 +1962,16 @@ def test_aten_leaky_relu_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.leaky_relu, args, kwargs) - @unittest.skip def test_aten_lift_fresh_copy_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) - @unittest.skip def test_aten_lift_fresh_copy_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) - @unittest.skip def test_aten_lift_fresh_copy_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -2146,25 +2073,21 @@ def test_aten_logical_and_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) - @unittest.skip def test_aten_logical_not_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_not, args, kwargs) - @unittest.skip def test_aten_logical_not_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_not, args, kwargs) - @unittest.skip def test_aten_logical_not_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_not, args, kwargs) - @unittest.skip def test_aten_logical_or_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -2173,7 +2096,6 @@ def test_aten_logical_or_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_or, args, kwargs) - @unittest.skip def test_aten_logical_or_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2182,7 +2104,6 @@ def test_aten_logical_or_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_or, args, kwargs) - @unittest.skip def test_aten_logical_or_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -2215,19 +2136,16 @@ def test_aten_logical_xor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_xor, args, kwargs) - @unittest.skip def test_aten_logit_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.logit, args, kwargs) - @unittest.skip def test_aten_logit_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.logit, args, kwargs) - @unittest.skip def test_aten_logit_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -2506,19 +2424,16 @@ def test_aten_maximum_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.maximum, args, kwargs) - @unittest.skip def test_aten_mean_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.mean, args, kwargs) - @unittest.skip def test_aten_mean_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.mean, args, kwargs) - @unittest.skip def test_aten_mean_dim_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -2527,7 +2442,6 @@ def test_aten_mean_dim_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.mean.dim, args, kwargs) - @unittest.skip def test_aten_mean_dim_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2584,7 +2498,6 @@ def test_aten_minimum_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.minimum, args, kwargs) - @unittest.skip def test_aten_mm_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -2593,7 +2506,6 @@ def test_aten_mm_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.mm, args, kwargs) - @unittest.skip def test_aten_mm_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2602,7 +2514,6 @@ def test_aten_mm_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.mm, args, kwargs) - @unittest.skip def test_aten_mm_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -2821,7 +2732,6 @@ def test_aten_native_layer_norm_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs) - @unittest.skip def test_aten_ne_Scalar_0(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -2830,7 +2740,6 @@ def test_aten_ne_Scalar_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ne.Scalar, args, kwargs) - @unittest.skip def test_aten_ne_Scalar_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2839,7 +2748,6 @@ def test_aten_ne_Scalar_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ne.Scalar, args, kwargs) - @unittest.skip def test_aten_ne_Scalar_2(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -2848,7 +2756,6 @@ def test_aten_ne_Scalar_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ne.Scalar, args, kwargs) - @unittest.skip def test_aten_ne_Tensor_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -2857,7 +2764,6 @@ def test_aten_ne_Tensor_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ne.Tensor, args, kwargs) - @unittest.skip def test_aten_ne_Tensor_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2866,7 +2772,6 @@ def test_aten_ne_Tensor_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ne.Tensor, args, kwargs) - @unittest.skip def test_aten_ne_Tensor_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -2875,19 +2780,16 @@ def test_aten_ne_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.ne.Tensor, args, kwargs) - @unittest.skip def test_aten_neg_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.neg, args, kwargs) - @unittest.skip def test_aten_neg_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.neg, args, kwargs) - @unittest.skip def test_aten_neg_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -2953,7 +2855,6 @@ def test_aten_permute_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.permute, args, kwargs) - @unittest.skip def test_aten_permute_copy_0(self): args = ( torch.randn((2, 2, 2)).to(torch.float32), @@ -2966,7 +2867,6 @@ def test_aten_permute_copy_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.permute_copy, args, kwargs) - @unittest.skip def test_aten_permute_copy_1(self): args = ( torch.randn((2, 2, 2)).to(torch.float16), @@ -2979,7 +2879,6 @@ def test_aten_permute_copy_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.permute_copy, args, kwargs) - @unittest.skip def test_aten_permute_copy_2(self): args = ( torch.randint(0, 10, (2, 2, 2)).to(torch.int32), @@ -3019,7 +2918,6 @@ def test_aten_pixel_shuffle_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.pixel_shuffle, args, kwargs) - @unittest.skip def test_aten_pow_Scalar_0(self): args = ( 1.123, @@ -3052,13 +2950,11 @@ def test_aten_pow_Tensor_Scalar_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs) - @unittest.skip def test_aten_pow_Scalar_1(self): args = (10000, torch.randn(16 * 8)) kwargs = dict() run_export_and_compare(self, torch.ops.aten.pow.Scalar, args, kwargs) - @unittest.skip def test_aten_pow_Tensor_Tensor_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3067,7 +2963,6 @@ def test_aten_pow_Tensor_Tensor_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs) - @unittest.skip def test_aten_pow_Tensor_Tensor_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3076,7 +2971,6 @@ def test_aten_pow_Tensor_Tensor_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs) - @unittest.skip def test_aten_pow_Tensor_Tensor_2(self): args = ( torch.randint(0, 5, (10, 10)).to(torch.int32), @@ -3462,13 +3356,11 @@ def test_aten_rsqrt_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.rsqrt, args, kwargs) - @unittest.skip def test_aten_rsqrt_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.rsqrt, args, kwargs) + run_export_and_compare(self, torch.ops.aten.rsqrt, args, kwargs, atol=0.01, rtol=0.01) - @unittest.skip def test_aten_rsqrt_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -3498,7 +3390,6 @@ def test_aten_rsub_Scalar_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.rsub.Scalar, args, kwargs) - @unittest.skip def test_aten_scatter_add_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3509,7 +3400,6 @@ def test_aten_scatter_add_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.scatter_add, args, kwargs) - @unittest.skip def test_aten_scatter_add_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3520,7 +3410,6 @@ def test_aten_scatter_add_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.scatter_add, args, kwargs) - @unittest.skip def test_aten_scatter_add_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3531,7 +3420,6 @@ def test_aten_scatter_add_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.scatter_add, args, kwargs) - @unittest.skip def test_aten_scatter_reduce_two_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3545,35 +3433,32 @@ def test_aten_scatter_reduce_two_0(self): self, torch.ops.aten.scatter_reduce.two, args, kwargs ) - @unittest.skip def test_aten_scatter_reduce_two_1(self): args = ( torch.randn((10, 10)).to(torch.float16), 1, torch.randint(0, 10, (10, 10)).to(torch.int64), torch.randn((10, 10)).to(torch.float16), - "sum", + "amin", ) kwargs = dict() run_export_and_compare( self, torch.ops.aten.scatter_reduce.two, args, kwargs ) - @unittest.skip def test_aten_scatter_reduce_two_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), 1, torch.randint(0, 10, (10, 10)).to(torch.int64), torch.randint(0, 10, (10, 10)).to(torch.int32), - "sum", + "amax", ) kwargs = dict() run_export_and_compare( self, torch.ops.aten.scatter_reduce.two, args, kwargs ) - @unittest.skip def test_aten_scatter_src_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3584,7 +3469,6 @@ def test_aten_scatter_src_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.scatter.src, args, kwargs) - @unittest.skip def test_aten_scatter_src_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3595,7 +3479,6 @@ def test_aten_scatter_src_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.scatter.src, args, kwargs) - @unittest.skip def test_aten_scatter_src_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3606,7 +3489,6 @@ def test_aten_scatter_src_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.scatter.src, args, kwargs) - @unittest.skip def test_aten_scatter_value_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3617,7 +3499,6 @@ def test_aten_scatter_value_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.scatter.value, args, kwargs) - @unittest.skip def test_aten_scatter_value_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3628,7 +3509,6 @@ def test_aten_scatter_value_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.scatter.value, args, kwargs) - @unittest.skip def test_aten_scatter_value_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3696,7 +3576,6 @@ def test_aten_select_int_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.select.int, args, kwargs) - @unittest.skip def test_aten_select_scatter_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3707,7 +3586,6 @@ def test_aten_select_scatter_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) - @unittest.skip def test_aten_select_scatter_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3718,7 +3596,6 @@ def test_aten_select_scatter_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) - @unittest.skip def test_aten_select_scatter_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3813,7 +3690,6 @@ def test_aten_slice_copy_Tensor_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.slice_copy.Tensor, args, kwargs) - @unittest.skip def test_aten_slice_scatter_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3823,7 +3699,6 @@ def test_aten_slice_scatter_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.slice_scatter, args, kwargs) - @unittest.skip def test_aten_slice_scatter_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3833,7 +3708,6 @@ def test_aten_slice_scatter_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.slice_scatter, args, kwargs) - @unittest.skip def test_aten_slice_scatter_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3876,7 +3750,6 @@ def test_aten__softmax_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten._softmax, args, kwargs) - @unittest.skip def test_aten__softmax_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -4001,7 +3874,6 @@ def test_aten_sqrt_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.sqrt, args, kwargs) - @unittest.skip def test_aten_squeeze_copy_dim_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -4010,7 +3882,6 @@ def test_aten_squeeze_copy_dim_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze_copy.dim, args, kwargs) - @unittest.skip def test_aten_squeeze_copy_dim_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -4019,7 +3890,6 @@ def test_aten_squeeze_copy_dim_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze_copy.dim, args, kwargs) - @unittest.skip def test_aten_squeeze_copy_dim_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -4028,7 +3898,6 @@ def test_aten_squeeze_copy_dim_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze_copy.dim, args, kwargs) - @unittest.skip def test_aten_squeeze_dims_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -4040,7 +3909,6 @@ def test_aten_squeeze_dims_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze.dims, args, kwargs) - @unittest.skip def test_aten_squeeze_dims_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -4052,7 +3920,6 @@ def test_aten_squeeze_dims_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.squeeze.dims, args, kwargs) - @unittest.skip def test_aten_squeeze_dims_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -4169,13 +4036,11 @@ def test_aten_sum_dim_IntList_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs) - @unittest.skip def test_aten_tan_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.tan, args, kwargs) - @unittest.skip def test_aten_tan_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() @@ -4188,7 +4053,6 @@ def test_aten_tan_1(self): atol=0.01, ) - @unittest.skip def test_aten_tan_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -4263,7 +4127,7 @@ def test_aten_topk_4(self): ) kwargs = dict() run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) - @unittest.skip + def test_aten_transpose_copy_int_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -4275,7 +4139,6 @@ def test_aten_transpose_copy_int_0(self): self, torch.ops.aten.transpose_copy.int, args, kwargs ) - @unittest.skip def test_aten_transpose_copy_int_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -4287,7 +4150,6 @@ def test_aten_transpose_copy_int_1(self): self, torch.ops.aten.transpose_copy.int, args, kwargs ) - @unittest.skip def test_aten_transpose_copy_int_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -4332,25 +4194,21 @@ def test_aten_trunc_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.trunc, args, kwargs) - @unittest.skip def test_aten_unbind_copy_int_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.unbind_copy.int, args, kwargs) - @unittest.skip def test_aten_unbind_copy_int_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.unbind_copy.int, args, kwargs) - @unittest.skip def test_aten_unbind_copy_int_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.unbind_copy.int, args, kwargs) - @unittest.skip def test_aten_unsqueeze_copy_0(self): args = ( torch.randn((2, 0, 2)).to(torch.float32), @@ -4359,7 +4217,6 @@ def test_aten_unsqueeze_copy_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.unsqueeze_copy, args, kwargs) - @unittest.skip def test_aten_unsqueeze_copy_1(self): args = ( torch.randn((2, 0, 2)).to(torch.float16), @@ -4368,7 +4225,6 @@ def test_aten_unsqueeze_copy_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.unsqueeze_copy, args, kwargs) - @unittest.skip def test_aten_unsqueeze_copy_2(self): args = ( torch.randint(0, 10, (2, 0, 2)).to(torch.int32), @@ -4404,25 +4260,21 @@ def test_aten_upsample_nearest2d_0(self): self, torch.ops.aten.upsample_nearest2d, args, kwargs ) - @unittest.skip def test_aten_var_correction_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) - @unittest.skip def test_aten_var_correction_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) - @unittest.skip def test_aten_var_correction_2(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict(correction=0) run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) - @unittest.skip def test_aten_var_correction_3(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict(correction=0) @@ -4461,7 +4313,6 @@ def test_aten_view_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.view, args, kwargs) - @unittest.skip def test_aten_view_copy_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -4474,7 +4325,6 @@ def test_aten_view_copy_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.view_copy, args, kwargs) - @unittest.skip def test_aten_view_copy_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -4487,7 +4337,6 @@ def test_aten_view_copy_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.view_copy, args, kwargs) - @unittest.skip def test_aten_view_copy_2(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), diff --git a/experimental/torch_xla2/torch_xla2/ops.py b/experimental/torch_xla2/torch_xla2/ops.py index a6f153d72e1..cb9b57cdac5 100644 --- a/experimental/torch_xla2/torch_xla2/ops.py +++ b/experimental/torch_xla2/torch_xla2/ops.py @@ -44,6 +44,7 @@ def inner(func): return inner +@op(torch.ops.aten.view_copy) @op(torch.ops.aten.view) @op(torch.ops.aten._unsafe_view) def _aten_unsafe_view(x, shape): @@ -117,7 +118,7 @@ def _aten_index_select(x, dim, indexes): @op(torch.ops.aten.mean) -def _aten_mean(x, dim, keepdim): +def _aten_mean(x, dim=None, keepdim=False): return jnp.mean(x, dim, keepdims=keepdim) @@ -146,7 +147,6 @@ def _aten_sub(x, y): @op(torch.ops.aten.mm) def _aten_mm(x, y): res = x @ y - assert res.dtype == jnp.bfloat16 return res @@ -166,6 +166,7 @@ def _aten_t(x): @op(torch.ops.aten.transpose) +@op(torch.ops.aten.transpose_copy) def _aten_transpose(x, dim0, dim1): shape = list(range(len(x.shape))) shape[dim0], shape[dim1] = shape[dim1], shape[dim0] @@ -218,8 +219,6 @@ def _aten_softmax(x, dim, halftofloat): def _aten_pow(x, y): if isinstance(y, int): y = float(y) - if isinstance(y, jnp.ndarray): - y = y.astype(jnp.astype(jnp.bfloat16)) return jnp.power(x, y) @@ -233,8 +232,13 @@ def _aten_view_as_complex(input): @op(torch.ops.aten.div) def _aten_div(x, y, rounding_mode=""): + res = x / y if rounding_mode == "trunc": - return jnp.floor_divide(x, y) + res = jnp.trunc(res) + return res + +@op(torch.ops.aten.true_divide) +def _aten_true_divide(x, y): return x / y @@ -253,10 +257,15 @@ def _aten_embedding(a, w, padding_idx=-1): @op(torch.ops.aten.rsqrt) def _aten_rsqrt(x): + if isinstance(x, int): + x = float(x) + if x.dtype == jnp.int32: + x = x.astype(jnp.float32) return jax.lax.rsqrt(x) @op(torch.ops.aten.expand) +@op(torch.ops.aten.expand_copy) def _aten_expand(x, dims): def fix_dims(d, xs): if d == -1: @@ -332,11 +341,13 @@ def make_range(rank, dim, start, end): @op(torch.ops.aten.permute) +@op(torch.ops.aten.permute_copy) def permute(t, dims): return jnp.transpose(t, dims) @op(torch.ops.aten.unsqueeze) +@op(torch.ops.aten.unsqueeze_copy) @op(torch.ops.aten.unsqueeze.default) def _aten_unsqueeze(self, dim): if dim < 0: @@ -412,6 +423,7 @@ def _aten_gelu(self, *, approximate="none"): @op(torch.ops.aten.squeeze) +@op(torch.ops.aten.squeeze_copy) def _aten_squeeze_dim(self, dim): """Squeezes a Jax tensor by removing a single dimension of size 1. @@ -427,16 +439,21 @@ def _aten_squeeze_dim(self, dim): # Validate input arguments if not isinstance(self, jnp.ndarray): raise TypeError(f"Expected a Jax tensor, got {type(self)}.") - if not isinstance(dim, int): - raise TypeError(f"Expected dim to be an int, got {type(dim)}.") + if isinstance(dim, int): + dim = [dim] # Check if the specified dimension has size 1 - if self.shape[dim] != 1: + if all([self.shape[d] != 1 for d in dim]): return self # Use slicing to remove the dimension if it is 1 new_shape = list(self.shape) - new_shape.pop(dim) + def fix_dim(p): + if p < 0: + return p + len(self.shape) + return p + dim = [fix_dim(d) for d in dim] + new_shape = [p for i, p in enumerate(self.shape) if i not in dim or p != 1] return self.reshape(new_shape) @@ -797,6 +814,10 @@ def _aten_sum(self, dim=None, keepdim=False, dtype=None): def _aten_sqrt(self): return jnp.sqrt(self) +@op(torch.ops.aten.tan) +def _aten_tanh(self): + return jnp.tan(self) + # aten.tanh @op(torch.ops.aten.tanh) @@ -824,7 +845,36 @@ def _aten_minimum(self, other): # aten.max_pool2d_backward +def _scatter_index(dim, index): + """Returns a tuple of indexes; + + The first is to select in input (to modify), + the second is to select from the values. + """ + index_shape = list(index.shape) + input_indexes = [] + source_indexes = [] + for i in range(len(index_shape)): + source_indexes.append(slice(0, index_shape[i])) + if i == dim: + input_indexes.append(index) + else: + target_shape = [1] * len(index_shape) + target_shape[i] = index_shape[i] + input_indexes.append( + jnp.broadcast_to(jnp.arange(index_shape[i]).reshape(target_shape), index_shape) + ) + return tuple(input_indexes), tuple(source_indexes) + # aten.scatter_add +@op(torch.ops.aten.scatter_add) +def _aten_scatter_add(input, dim, index, src): + """JAX implementation of scatter, mimicking torch.scatter behavior""" + + input_indexes, source_indexes = _scatter_index(dim, index) + return input.at[input_indexes].add(src[source_indexes]) + + # aten.logical_not # aten.sign @@ -844,6 +894,24 @@ def _aten_atan(self): # aten.scatter_reduce +@op(torch.ops.aten.scatter_reduce) +def _aten_scatter_reduce(input, dim, index, src, reduce, *, include_self=True): + input_indexes, source_indexes = _scatter_index(dim, index) + if reduce == "sum": + return input.at[input_indexes].add(src[source_indexes]) + elif reduce == "prod": + return input.at[input_indexes].multiply(src[source_indexes]) + elif reduce == "mean": + return input.at[input_indexes].add(src[source_indexes]) + elif reduce == "amax": + return input.at[input_indexes].max(src[source_indexes]) + elif reduce == "amin": + return input.at[input_indexes].min(src[source_indexes]) + else: + raise RuntimeError('Unknow reduction type: ', reduce) + + + # aten.acos @op(torch.ops.aten.acos) def _aten_acos(self): @@ -873,6 +941,26 @@ def _aten_lt(self, other): # aten.sym_numel # aten.reciprocal # aten.scatter +@op(torch.ops.aten.select_scatter) +def _aten_select_scatter(input, src, dim, index): + input_indexes = [] + for x in range(len(input.shape)): + if x == dim: + input_indexes.append(index) + else: + input_indexes.append(slice(None, None, None)) + return input.at[tuple(input_indexes)].set(src) + + +@op(torch.ops.aten.scatter.src) +def _aten_scatter_src(input, dim, index, src, reduce=None): + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src[source_indexes]) + +@op(torch.ops.aten.scatter.value) +def _aten_scatter(input, dim, index, src, reduce=None): + input_index, source_indexes = _scatter_index(dim, index) + return input.at[input_index].set(src) # aten.acosh @@ -996,6 +1084,7 @@ def _aten_constant_pad_nd(input, padding, value=0): # aten.convolution_backward @op(torch.ops.aten.copy) +@op(torch.ops.aten.lift_fresh_copy) def _aten_copy(x): return jnp.copy(x) @@ -1013,6 +1102,11 @@ def _aten_cosh(input): # aten.diagonal +@op(torch.ops.aten.diagonal) +def _aten_diagonal(input, offset=0, dim1=0, dim2=1): + return jnp.diagonal(input, offset, dim1, dim2) + + # aten.empty_strided # aten.eq @op(torch.ops.aten.eq) @@ -1050,6 +1144,10 @@ def _aten_floor(input): # aten.fmod +@op(torch.ops.aten.fmod) +def _aten_fmod(input, other): + return input - other*_aten_div(input, other, 'trunc') + # aten.gather # aten.ge @op(torch.ops.aten.ge) @@ -1135,6 +1233,9 @@ def _aten_logical_xor(self, other): # aten.native_dropout # aten.native_group_norm_backward # aten.neg +@op(torch.ops.aten.neg) +def _aten_neg(x): + return -1 * x # aten.nonzero # aten.prod @@ -1148,8 +1249,18 @@ def _aten_logical_xor(self, other): # aten.replication_pad3d # aten.roll # aten.scalar_tensor -# aten.select_scatter # aten.slice_scatter +@op(torch.ops.aten.slice_scatter) +def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): + input_index = [] + for x in range(len(input.shape)): + if x == dim: + input_index.append(slice(start, end, step)) + else: + input_index.append(slice(None, None, None)) + return input.at[tuple(input_index)].set(src) + + # aten.sort @@ -1220,6 +1331,13 @@ def _aten_trunc(a): return jnp.trunc(a) +@op(torch.ops.aten.unbind) +@op(torch.ops.aten.unbind_copy) +def _aten_unbind(a, dim=0): + return tuple(_aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) + for i in range(a.shape[dim])) + + # NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d # despite those being core aten ops, they also have decompositions. # here we are using torch decompositions. diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 6b6ac010f63..6292a37e6fc 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -5,6 +5,7 @@ import numpy import torch import torch._decomp as decomp +import torch._decomp.decompositions from torch_xla2 import ops_registry import torch.utils._python_dispatch as torch_dispatch import torch.utils._pytree as torch_pytree