From 993cd09af8f70582d71fb592a265508bb7fda085 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Mon, 5 Feb 2024 17:17:24 -0800
Subject: [PATCH] Fix some more core aten ops (#6462)

---
 test/test_core_aten_ops.py              | 114 ++++++++----------
 torch_xla/csrc/ops/ops_lower_fn.cpp     |   4 +
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp |   6 +-
 3 files changed, 44 insertions(+), 80 deletions(-)

diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index b1262fdd8cb..93651b9a8e0 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -529,54 +529,6 @@ def test_aten_argmin_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.argmin, args, kwargs)
 
-  @unittest.skip
-  def test_aten_as_strided_0(self):
-    args = (
-        torch.randn((10, 10)).to(torch.float32),
-        [
-            0,
-            1,
-        ],
-        [
-            0,
-            1,
-        ],
-    )
-    kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs)
-
-  @unittest.skip
-  def test_aten_as_strided_1(self):
-    args = (
-        torch.randn((10, 10)).to(torch.float16),
-        [
-            0,
-            1,
-        ],
-        [
-            0,
-            1,
-        ],
-    )
-    kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs)
-
-  @unittest.skip
-  def test_aten_as_strided_2(self):
-    args = (
-        torch.randint(0, 10, (10, 10)).to(torch.int32),
-        [
-            0,
-            1,
-        ],
-        [
-            0,
-            1,
-        ],
-    )
-    kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs)
-
   @unittest.skip
   def test_aten_as_strided_copy_0(self):
     args = (
@@ -1369,41 +1321,44 @@ def test_aten_div_Scalar_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.div.Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_div_Scalar_mode_0(self):
+
+    def aten_div_Scalar_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Scalar_mode(input, other, rounding_mode='floor')
+
     args = (
         torch.randn((10, 10)).to(torch.float32),
         0.123,
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Scalar_mode_rounding_mode_trunc, args,
+                           kwargs)
 
-  @unittest.skip
   def test_aten_div_Scalar_mode_1(self):
+
+    def aten_div_Scalar_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Scalar_mode(input, other, rounding_mode='floor')
+
     args = (
         torch.randn((10, 10)).to(torch.float16),
         0.123,
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Scalar_mode_rounding_mode_trunc, args,
+                           kwargs)
 
-  @unittest.skip
   def test_aten_div_Scalar_mode_2(self):
+
+    def aten_div_Scalar_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Scalar_mode(input, other, rounding_mode='floor')
+
     args = (
         torch.randint(0, 10, (10, 10)).to(torch.int32),
         0.123,
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Scalar_mode_rounding_mode_trunc, args,
+                           kwargs)
 
   def test_aten_div_Tensor_0(self):
     args = (
         torch.randn((10, 10)).to(torch.float32),
         torch.randn((10, 10)).to(torch.float32),
     )
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs)
@@ -1429,29 +1384,31 @@ def test_aten_div_Tensor_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs)
 
-  @unittest.skip
   def test_aten_div_Tensor_mode_0(self):
+
+    def aten_div_Tensor_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Tensor_mode(input, other, rounding_mode='trunc')
+
     args = (
         torch.randn((10, 10)).to(torch.float32),
         torch.randn((10, 10)).to(torch.float32),
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Tensor_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Tensor_mode_rounding_mode_trunc, args,
+                           kwargs)
 
-  @unittest.skip
   def test_aten_div_Tensor_mode_1(self):
+
+    def aten_div_Tensor_mode_rounding_mode_trunc(input, other):
+      return torch.ops.aten.div.Tensor_mode(input, other, rounding_mode='trunc')
+
     args = (
         torch.randn((10, 10)).to(torch.float16),
         torch.randn((10, 10)).to(torch.float16),
     )
-    kwargs = dict((
-        "rounding_mode",
-        "trunc",
-    ))
-    run_export_and_compare(self, torch.ops.aten.div.Tensor_mode, args, kwargs)
+    kwargs = dict()
+    run_export_and_compare(self, aten_div_Tensor_mode_rounding_mode_trunc, args,
+                           kwargs)
 
   def test_aten_embedding_0(self):
     args = (
         torch.randn((10, 10)).to(torch.float32),
@@ -4039,7 +3996,6 @@ def test_aten_sin_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.sin, args, kwargs)
 
-  @unittest.skip
   def test_aten_sin_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.sin, args, kwargs)
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 9a765db749a..f7fec2d65dd 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -756,6 +756,10 @@ torch_xla::XlaOpVector SiluBackward::Lower(LoweringContext* loctx) const {
 
 torch_xla::XlaOpVector Sin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
+    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
+  }
   return ReturnOp(xla::Sin(xla_input), loctx);
 }
 
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index b0133da3ec7..4e7e2f74dbf 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -835,7 +835,11 @@ xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output,
 }
 
 xla::Shape SinOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape SinhOutputShape(const torch::lazy::Value& input) {
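
Note, not part of the patch: a minimal eager-mode sketch of what the rewritten div tests now exercise. The names below (div_tensor_mode_trunc, lhs, rhs) are illustrative only; the tests themselves bind rounding_mode inside a local wrapper so run_export_and_compare can be called with empty kwargs instead of a rounding_mode keyword argument.

import torch

# Illustrative stand-in for the wrappers the tests define inline: the
# rounding_mode keyword is fixed inside the callable that gets exported,
# rather than being passed through kwargs.
def div_tensor_mode_trunc(lhs, rhs):
  return torch.ops.aten.div.Tensor_mode(lhs, rhs, rounding_mode='trunc')

lhs = torch.randn(10, 10)
rhs = torch.randn(10, 10)
# The wrapper should match the public torch.div with the same rounding mode.
assert torch.equal(div_tensor_mode_trunc(lhs, rhs),
                   torch.div(lhs, rhs, rounding_mode='trunc'))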
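
Likewise, a quick sketch (again illustrative, not part of the patch) of the dtype behavior that the Sin::Lower and SinOutputShape changes mirror: aten.sin on an integral tensor produces a floating-point result in eager mode, so the XLA output shape must report F32 rather than the integral input type, which is what lets test_aten_sin_2 run unskipped.

import torch

# aten.sin promotes integral inputs to the default float dtype (float32,
# assuming the default has not been overridden), matching the new
# integral-to-F32 conversion on the XLA side.
int_input = torch.randint(0, 10, (10, 10), dtype=torch.int32)
result = torch.ops.aten.sin(int_input)
assert result.dtype == torch.float32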