From 07832b070fe098f378008b8b24bad2320381cd44 Mon Sep 17 00:00:00 2001
From: Pei Zhang
Date: Tue, 23 Jan 2024 13:38:40 -0800
Subject: [PATCH] fix core aten asinh (#6365)

---
 test/cpp/test_aten_xla_tensor_2.cpp     | 12 ++++++++++++
 test/test_core_aten_ops.py              |  2 --
 torch_xla/csrc/ops/ops_lower_fn.cpp     |  4 ++++
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp |  6 +++++-
 4 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp
index a9102217096..fcdf4b11a6f 100644
--- a/test/cpp/test_aten_xla_tensor_2.cpp
+++ b/test/cpp/test_aten_xla_tensor_2.cpp
@@ -2226,6 +2226,18 @@ TEST_F(AtenXlaTensorTest, TestAsin) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestAsinhWithInt) {
+  torch::Tensor a = torch::randint(100, {2, 2}, torch::kInt);
+  torch::Tensor b = torch::asinh(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::asinh(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::asinh", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestAsinh) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::asinh(a);
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 0eb699a51e7..0a4dc82eb95 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -654,7 +654,6 @@ def test_aten_asin_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.asin, args, kwargs)
 
-  @unittest.skip
   def test_aten_asinh_0(self):
     args = (torch.randn((10, 10)).to(torch.float32),)
     kwargs = dict()
@@ -665,7 +664,6 @@ def test_aten_asinh_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.asinh, args, kwargs)
 
-  @unittest.skip
   def test_aten_asinh_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 3a5b1071c61..9a765db749a 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -151,6 +151,10 @@ torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
 
 torch_xla::XlaOpVector Asinh::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+  if (xla::primitive_util::IsIntegralType(input_type)) {
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
+  }
   return ReturnOp(xla::Asinh(xla_input), loctx);
 }
 
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index bd092d2b3b5..b0133da3ec7 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -232,7 +232,11 @@ xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
 }
 
 xla::Shape AsinhOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape AtanOutputShape(const torch::lazy::Value& input) {
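
As a quick end-to-end sanity check of the new behavior, the following
sketch (not part of the patch; it assumes a torch_xla install with an
available XLA device) runs asinh on an integral tensor and verifies the
F32 promotion:

    import torch
    import torch_xla.core.xla_model as xm

    # Mirrors test_aten_asinh_2: an integral input, which previously
    # failed because xla::Asinh is only defined for floating-point types.
    a = torch.randint(0, 10, (10, 10), dtype=torch.int32)

    xla_out = torch.asinh(a.to(xm.xla_device()))

    # AsinhOutputShape now reports F32 for integral inputs, so the lazy
    # tensor's dtype matches eager-mode type promotion of aten::asinh.
    assert xla_out.dtype == torch.float32
    torch.testing.assert_close(torch.asinh(a), xla_out.cpu())

Note that the lowering and the shape function have to change together:
the shape function determines the lazy tensor's metadata, so if it kept
the integral element type while the lowering emitted F32, the dtype
reported to PyTorch would be wrong and the inferred shape would no
longer match the lowered HLO.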