Commit dc57a74
[Cherry-pick] Update test_spmd_debugging.py to avoid having the test code test itself, and promote int to float for the tanh operation (#6263) (#6166) (#6329)
ManfeiBai authored Jan 20, 2024
1 parent f1d17d3 commit dc57a74
Showing 5 changed files with 27 additions and 4 deletions.
14 changes: 14 additions & 0 deletions test/cpp/test_aten_xla_tensor_2.cpp
@@ -2425,6 +2425,20 @@ TEST_F(AtenXlaTensorTest, TestTanh) {
   });
 }
 
+// In torch, tanh works with integer inputs. The same should be true for
+// torch_xla
+TEST_F(AtenXlaTensorTest, TestTanhWithInt) {
+  torch::Tensor a = torch::rand({2, 2});
+  torch::Tensor b = torch::tanh(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::tanh(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::tanh", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestClampMinMax) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Scalar min_val(0.311);
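The comment on the new test refers to PyTorch's eager behavior: torch.tanh accepts integer inputs and promotes them to a floating-point result, and torch_xla is expected to match. A minimal sketch of that eager-mode behavior (plain PyTorch, no XLA required):

    import torch

    a = torch.randint(0, 10, (2, 2), dtype=torch.int32)
    b = torch.tanh(a)   # the integer input is accepted and promoted
    print(b.dtype)      # torch.float32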
5 changes: 3 additions & 2 deletions test/spmd/test_spmd_debugging.py
@@ -563,8 +563,9 @@ def test_multi_host_replicated_tpu(self):
     col.append(
         rich.padding.Padding(
             rich.align.Align(
-                'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
-            (1, 1, 1, 1),
+                xr.device_type() + ' [0, 1, 2, 3, 4, 5, 6, 7]',
+                "center",
+                vertical="middle"), (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)
     fake_console = rich.console.Console()
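The point of this change is that the expected table no longer hard-codes a 'TPU' label, so the comparison works on whatever device type the suite runs on. A minimal sketch of how the label is built, assuming torch_xla.runtime is imported as xr as in the test file:

    import torch_xla.runtime as xr

    # 'TPU [0, 1, 2, 3, 4, 5, 6, 7]' on TPU; 'CPU [...]' or 'CUDA [...]' elsewhere
    label = xr.device_type() + ' [0, 1, 2, 3, 4, 5, 6, 7]'
    print(label)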
1 change: 0 additions & 1 deletion test/test_core_aten_ops.py
@@ -4413,7 +4413,6 @@ def test_aten_tanh_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs)
 
-  @unittest.skip
   def test_aten_tanh_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
5 changes: 5 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -718,6 +718,11 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {

 torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
+    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
+                          /*device=*/nullptr);
+  }
   return ReturnOp(xla::Tanh(xla_input), loctx);
 }
 
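Seen from Python, this lowering change means tanh on an integer XLA tensor is computed in F32, so it should agree with eager PyTorch. A minimal sketch of that check, assuming an XLA device is available via torch_xla.core.xla_model:

    import torch
    import torch_xla.core.xla_model as xm

    a = torch.randint(0, 10, (2, 2), dtype=torch.int32)
    expected = torch.tanh(a)                    # eager PyTorch promotes to float32
    result = torch.tanh(a.to(xm.xla_device()))  # XLA: operand converted to F32 in the lowering
    print(torch.allclose(expected, result.cpu(), rtol=1e-3, atol=1e-5))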
6 changes: 5 additions & 1 deletion torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -803,7 +803,11 @@ xla::Shape TakeOutputShape(const torch::lazy::Value& input,
 }
 
 xla::Shape TanhOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape TrilOutputShape(const torch::lazy::Value& input) {
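The shape function has to agree with the lowering above: if it kept the integral element type while the lowering produced F32, the lazy tensor's reported dtype and the compiled HLO would diverge. A minimal sketch of checking the reported dtype, under the same assumptions as the previous example:

    import torch
    import torch_xla.core.xla_model as xm

    a = torch.randint(0, 10, (2, 2), dtype=torch.int32).to(xm.xla_device())
    b = torch.tanh(a)
    print(b.dtype)        # torch.float32 reported before execution, from TanhOutputShape
    xm.mark_step()        # materialize; the computed values are float32 as well, from Tanh::Lower
    print(b.cpu().dtype)  # torch.float32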
