diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp
index 87ae7fd1d33..6cbe623809e 100644
--- a/test/cpp/test_aten_xla_tensor_2.cpp
+++ b/test/cpp/test_aten_xla_tensor_2.cpp
@@ -2411,6 +2411,21 @@ TEST_F(AtenXlaTensorTest, TestTanh) {
   });
 }
 
+// In torch, tanh works with integer inputs. The same should be true for
+// torch_xla
+TEST_F(AtenXlaTensorTest, TestTanhWithInt) {
+  // Use an actual integer tensor; torch::rand would produce floats and
+  // silently skip the integer-promotion path this test exists to cover.
+  torch::Tensor a =
+      torch::randint(0, 10, {2, 2}, torch::TensorOptions(torch::kInt));
+  torch::Tensor b = torch::tanh(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::tanh(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::tanh", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestClampMinMax) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Scalar min_val(0.311);
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index e10fc28a155..ec0a482c1b5 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -4419,7 +4419,6 @@ def test_aten_tanh_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs)
 
-  @unittest.skip
   def test_aten_tanh_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 8b281d0f01a..99250adbdad 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -751,6 +751,13 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
 }
 
 torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  // xla::Tanh only accepts floating-point operands; promote integral inputs
+  // to F32 so tanh on integer tensors matches torch semantics.
+  xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+  if (xla::primitive_util::IsIntegralType(input_type)) {
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
+                          /*device=*/nullptr);
+  }
   return ReturnOp(xla::Tanh(xla_input), loctx);
 }
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index f6de4ae6f89..9091ee44491 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -818,7 +818,13 @@ xla::Shape TakeOutputShape(const torch::lazy::Value& input,
 }
 
 xla::Shape TanhOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  // The lowering promotes integral inputs to F32 (see Tanh::Lower), so the
+  // inferred output shape must carry the promoted element type as well.
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape TrilOutputShape(const torch::lazy::Value& input) {