fix core aten asinh (#6365)
zpcore authored Jan 23, 2024
1 parent c4b5ab6 commit 07832b0
Showing 4 changed files with 21 additions and 3 deletions.
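In short: core ATen semantics require torch.asinh on an integral tensor to return a floating-point result, but the XLA lowering previously passed integral operands straight through. This commit promotes integral inputs to F32 in the lowering (ops_lower_fn.cpp), reports the promoted element type in shape inference (ops_xla_shape_fn.cpp), adds a C++ test for the integer path, and re-enables two previously skipped Python tests.

A minimal repro sketch of the behavior under test, assuming torch_xla is installed and an XLA device is available (the device setup below is illustrative, not part of this commit):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    x = torch.randint(0, 10, (10, 10), dtype=torch.int32, device=device)
    y = torch.asinh(x)
    # Core ATen promotes integral inputs to the default float dtype.
    assert y.dtype == torch.float32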
12 changes: 12 additions & 0 deletions test/cpp/test_aten_xla_tensor_2.cpp
@@ -2226,6 +2226,18 @@ TEST_F(AtenXlaTensorTest, TestAsin) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestAsinhWithInt) {
+  // Integer input, per the test name: randint rather than rand, so the
+  // integral promotion path is actually exercised.
+  torch::Tensor a = torch::randint(0, 10, {2, 2}, torch::TensorOptions(torch::kInt));
+  torch::Tensor b = torch::asinh(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::asinh(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::asinh", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestAsinh) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::asinh(a);
2 changes: 0 additions & 2 deletions test/test_core_aten_ops.py
@@ -654,7 +654,6 @@ def test_aten_asin_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.asin, args, kwargs)
 
-  @unittest.skip
   def test_aten_asinh_0(self):
     args = (torch.randn((10, 10)).to(torch.float32),)
     kwargs = dict()
@@ -665,7 +664,6 @@ def test_aten_asinh_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.asinh, args, kwargs)
 
-  @unittest.skip
   def test_aten_asinh_2(self):
    args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
    kwargs = dict()
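With the lowering fixed, the two @unittest.skip markers above come off and the float32 and int32 variants run again. For reference, a hedged standalone version of the int32 case (run_export_and_compare's actual logic lives in the test harness; the assertions below only pin down the eager reference semantics the backend must match):

    import torch

    def check_asinh_int32_reference():
        args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
        expected = torch.ops.aten.asinh(*args)
        # run_export_and_compare additionally round-trips the op through
        # torch.export and the XLA backend before comparing results.
        assert expected.dtype == torch.float32
        assert expected.shape == (10, 10)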
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -151,6 +151,10 @@ torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
 
 torch_xla::XlaOpVector Asinh::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+  if (xla::primitive_util::IsIntegralType(input_type)) {
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
+  }
   return ReturnOp(xla::Asinh(xla_input), loctx);
 }
 
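The added branch converts integral operands to F32 before calling xla::Asinh, which expects a floating-point (or complex) operand. In eager PyTorch terms the promotion looks roughly like this (a sketch; asinh_with_promotion is an illustrative name, not torch_xla code):

    import torch

    def asinh_with_promotion(x: torch.Tensor) -> torch.Tensor:
        # Mirrors the ConvertTo(..., F32) added in Asinh::Lower: integral
        # inputs are cast to float32, floating/complex inputs pass through.
        if not x.dtype.is_floating_point and not x.dtype.is_complex:
            x = x.to(torch.float32)
        return torch.asinh(x)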
6 changes: 5 additions & 1 deletion torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -232,7 +232,11 @@ xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
 }
 
 xla::Shape AsinhOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape AtanOutputShape(const torch::lazy::Value& input) {
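Shape inference has to agree with the lowering: if AsinhOutputShape kept the integral element type while the emitted HLO produced F32, the lazy tensor's metadata would disagree with the compiled graph. The expected promotion, using eager CPU as the reference semantics (a quick check, not part of the commit):

    import torch

    for dt in (torch.int32, torch.int64, torch.float32, torch.float64):
        y = torch.asinh(torch.ones(2, 2, dtype=dt))
        print(f"{dt} -> {y.dtype}")  # integral dtypes promote to torch.float32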
