Promote int to float for tanh operation (consistent with PyTorch)
ManfeiBai committed Dec 14, 2023
1 parent c4f8772 commit 5a6e57f
Showing 4 changed files with 24 additions and 2 deletions.
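In eager PyTorch, `tanh` on an integral tensor promotes the input to the default floating-point dtype rather than failing or truncating; this commit makes torch_xla match. A minimal sketch of the eager behavior being matched:

```python
import torch

a = torch.randint(0, 10, (2, 2), dtype=torch.int32)
b = torch.tanh(a)  # eager PyTorch promotes integral inputs to float
print(b.dtype)     # torch.float32 (the default float dtype)
```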
14 changes: 14 additions & 0 deletions test/cpp/test_aten_xla_tensor_2.cpp
@@ -2411,6 +2411,20 @@ TEST_F(AtenXlaTensorTest, TestTanh) {
   });
 }
 
+// In torch, tanh works with integer inputs. The same should be true for
+// torch_xla.
+TEST_F(AtenXlaTensorTest, TestTanhWithInt) {
+  torch::Tensor a = torch::randint(0, 10, {2, 2}, torch::kInt);
+  torch::Tensor b = torch::tanh(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::tanh(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::tanh", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestClampMinMax) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Scalar min_val(0.311);
1 change: 0 additions & 1 deletion test/test_core_aten_ops.py
@@ -4419,7 +4419,6 @@ def test_aten_tanh_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs)
 
-  @unittest.skip
   def test_aten_tanh_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
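For reference, `run_export_and_compare` is a helper in this test suite that, as the name suggests, exports the op and compares the result against eager execution. A rough standalone sketch of the scenario `test_aten_tanh_2` covers; `TanhModule` is illustrative, not a name from the repo:

```python
import torch

class TanhModule(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.tanh(x)

# int32 input, as in test_aten_tanh_2; the eager reference output is float32.
args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
exported = torch.export.export(TanhModule(), args)
assert exported.module()(*args).dtype == torch.float32
```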
5 changes: 5 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -751,6 +751,11 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
 
 torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+  if (xla::primitive_util::IsIntegralType(input_type)) {
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
+                          /*device=*/nullptr);
+  }
   return ReturnOp(xla::Tanh(xla_input), loctx);
 }
 
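The lowering now casts integral operands to F32 before emitting the XLA `tanh`. A rough Python analogue of the promotion logic, not the actual `ConvertTo` implementation:

```python
import torch

def lower_tanh(x: torch.Tensor) -> torch.Tensor:
    # Mirror of the C++ lowering above: promote integral inputs to
    # float32, then apply tanh.
    if not x.dtype.is_floating_point and not x.dtype.is_complex:
        x = x.to(torch.float32)
    return torch.tanh(x)

print(lower_tanh(torch.arange(4, dtype=torch.int64)).dtype)  # torch.float32
```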
6 changes: 5 additions & 1 deletion torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -818,7 +818,11 @@ xla::Shape TakeOutputShape(const torch::lazy::Value& input,
 }
 
 xla::Shape TanhOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape TrilOutputShape(const torch::lazy::Value& input) {
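The shape function has to agree with the lowering because a lazy tensor reports its dtype from the inferred XLA shape before any graph executes. A quick check, assuming this commit is applied and an XLA device is available:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randint(0, 10, (2, 2), dtype=torch.int32, device=device)
b = torch.tanh(a)
# dtype comes from TanhOutputShape at trace time, before execution.
assert b.dtype == torch.float32
xm.mark_step()  # materialize; the lowering then produces F32 as well
```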
