From 0fca8b28c013ea2cda4e3ade539414518e50be97 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 15 Dec 2023 09:28:40 -0800
Subject: [PATCH] Promote int to float for tanh operation (consistent with
 PyTorch) (#6166)

---
 test/cpp/test_aten_xla_tensor_2.cpp     | 14 ++++++++++++++
 test/test_core_aten_ops.py              |  1 -
 torch_xla/csrc/ops/ops_lower_fn.cpp     |  5 +++++
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp |  6 +++++-
 4 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp
index 4a86537d824..a9102217096 100644
--- a/test/cpp/test_aten_xla_tensor_2.cpp
+++ b/test/cpp/test_aten_xla_tensor_2.cpp
@@ -2425,6 +2425,20 @@ TEST_F(AtenXlaTensorTest, TestTanh) {
   });
 }
 
+// In torch, tanh works with integer inputs. The same should be true for
+// torch_xla
+TEST_F(AtenXlaTensorTest, TestTanhWithInt) {
+  torch::Tensor a = torch::randint(0, 10, {2, 2}, torch::TensorOptions(torch::kInt));
+  torch::Tensor b = torch::tanh(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::tanh(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::tanh", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestClampMinMax) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Scalar min_val(0.311);
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 449e54f08cc..15c7a12b64c 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -4413,7 +4413,6 @@ def test_aten_tanh_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs)
 
-  @unittest.skip
   def test_aten_tanh_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 1f62d388959..6aa4857e334 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -718,6 +718,11 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
 
 torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
+    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
+                          /*device=*/nullptr);
+  }
   return ReturnOp(xla::Tanh(xla_input), loctx);
 }
 
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 21c6e0e8dac..825c9b59cf8 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -803,7 +803,11 @@ xla::Shape TakeOutputShape(const torch::lazy::Value& input,
 }
 
 xla::Shape TanhOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape TrilOutputShape(const torch::lazy::Value& input) {
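
Note (illustrative only, not part of the patch): the Python sketch below shows the end-to-end behavior this change is meant to enable, assuming a working torch_xla install with an available XLA device (using the public torch_xla.core.xla_model.xla_device() API). After the patch, torch.tanh on an integer XLA tensor should produce a float32 result matching eager PyTorch; previously the integer case was skipped in test_core_aten_ops.py.

# Hedged sketch, not part of the patch: demonstrates int -> float promotion
# for tanh on an XLA device, mirroring eager PyTorch behavior. Assumes torch
# and torch_xla are installed and an XLA device is available.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randint(0, 10, (10, 10), dtype=torch.int32)

cpu_out = torch.tanh(a)             # eager PyTorch promotes int32 -> float32
xla_out = torch.tanh(a.to(device))  # with this patch, torch_xla does the same

assert cpu_out.dtype == torch.float32
assert xla_out.dtype == torch.float32
print(torch.allclose(cpu_out, xla_out.cpu(), rtol=1e-3, atol=1e-5))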