From abb34a3e5c22c6da94964dc7ee1e950e00b0d983 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 6 Nov 2023 16:58:14 -0300 Subject: [PATCH] Remove `_unsafe_index` implementation. (#5769) --- codegen/xla_native_functions.yaml | 1 - test/cpp/test_aten_xla_tensor_1.cpp | 28 ---------------------------- torch_xla/csrc/aten_xla_type.cpp | 7 ------- 3 files changed, 36 deletions(-) diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index e89e7ddd65e..bd57e18b1a3 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -130,7 +130,6 @@ supported: - _to_cpu - _to_copy - _unsafe_view - - _unsafe_index.Tensor - adaptive_max_pool2d - adaptive_max_pool2d_backward - add.Scalar diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index d5cfb55bbd7..a5db45185f7 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -2043,34 +2043,6 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) { ExpectCounterChanged("xla::scatter_reduce", cpp_test::GetIgnoredCounters()); } -TEST_F(AtenXlaTensorTest, TestUnsafeIndex) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); - for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { - torch::List> indices{ - torch::tensor({0, 1}, torch::TensorOptions(index_scalar_type)), - torch::tensor({2, 3}, torch::TensorOptions(index_scalar_type))}; - torch::Tensor c0 = torch::_unsafe_index(a, indices); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::List> xla_indices{ - CopyToDevice(*indices.get(0), device), - CopyToDevice(*indices.get(1), device)}; - torch::Tensor xla_c0 = torch::_unsafe_index(xla_a, xla_indices); - AllEqual(c0, xla_c0); - }); - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_unsafe_index", cpp_test::GetIgnoredCounters()); -} - TEST_F(AtenXlaTensorTest, TestIndexSelect) { for (torch::ScalarType scalar_type : {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 49780fc8811..fa2197093dc 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -626,13 +626,6 @@ at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self, return view_copy_symint(self, c10::fromIntArrayRefSlow(size)); } -at::Tensor XLANativeFunctions::_unsafe_index( - const at::Tensor& self, - const c10::List>& indices) { - TORCH_LAZY_FN_COUNTER("xla::"); - return index(self, indices); -} - at::Tensor XLANativeFunctions::add(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {