From 31105b43821967235adb77ff468b73aaf1cb3c23 Mon Sep 17 00:00:00 2001 From: rpsilva-aws Date: Fri, 6 Dec 2024 00:29:14 +0000 Subject: [PATCH] Support S32 index operations for copy and fill --- test/cpp/test_aten_xla_tensor_1.cpp | 754 +++++++++++++--------------- torch_xla/csrc/ops/index_ops.cpp | 22 +- 2 files changed, 374 insertions(+), 402 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index d204b344808..33b3be50f2d 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -3,6 +3,7 @@ #include +#include "absl/strings/str_cat.h" #include "test/cpp/cpp_test_util.h" #include "test/cpp/torch_xla_test.h" #include "torch_xla/csrc/aten_xla_bridge.h" @@ -2131,58 +2132,6 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) { ExpectCounterChanged("xla::scatter_reduce", cpp_test::GetIgnoredCounters()); } -TEST_F(AtenXlaTensorTest, TestIndexSelect) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); - for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { - torch::Tensor b = - torch::empty({2}, torch::TensorOptions(index_scalar_type)); - b[0] = 0; - b[1] = 2; - torch::Tensor c0 = torch::index_select(a, 0, b); - torch::Tensor c1 = torch::index_select(a, 1, b); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_b = CopyToDevice(b, device); - torch::Tensor xla_c0 = torch::index_select(xla_a, 0, xla_b); - torch::Tensor xla_c1 = torch::index_select(xla_a, 1, xla_b); - AllEqual(c0, xla_c0); - AllEqual(c1, xla_c1); - }); - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_select", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestIndexSelectRank0) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); - torch::Tensor b = - torch::scalar_tensor(2, torch::TensorOptions(torch::kLong)); - torch::Tensor c0 = torch::index_select(a, 0, b); - torch::Tensor c1 = torch::index_select(a, 1, b); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_b = CopyToDevice(b, device); - torch::Tensor xla_c0 = torch::index_select(xla_a, 0, xla_b); - torch::Tensor xla_c1 = torch::index_select(xla_a, 1, xla_b); - AllEqual(c0, xla_c0); - AllEqual(c1, xla_c1); - }); - } -} - TEST_F(AtenXlaTensorTest, TestInverse) { // TODO: Renable after the LAPACK dependency issue is resolved. 
GTEST_SKIP();
@@ -3010,204 +2959,271 @@ TEST_F(AtenXlaTensorTest, TestMaskIndexPut) {
   }
 }
 
-TEST_F(AtenXlaTensorTest, TestIndexPutImpl) {
+class IndexOpsAtenXlaTensorTest
+    : public AtenXlaTensorTest,
+      public ::testing::WithParamInterface<
+          std::tuple<torch::ScalarType, torch::ScalarType>> {
+ protected:
+  torch::ScalarType GetIndexType() const { return std::get<0>(GetParam()); }
+  torch::ScalarType GetValueType() const { return std::get<1>(GetParam()); }
+};
+
+TEST_P(IndexOpsAtenXlaTensorTest, TestIndexSelect) {
+  torch::ScalarType scalar_type = GetValueType();
+  torch::Tensor a =
+      isFloatingType(scalar_type)
+          ? torch::rand({3, 4}, torch::TensorOptions(scalar_type))
+          : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type));
+  torch::Tensor b = torch::empty({2}, torch::TensorOptions(GetIndexType()));
+  b[0] = 0;
+  b[1] = 2;
+  torch::Tensor c0 = torch::index_select(a, 0, b);
+  torch::Tensor c1 = torch::index_select(a, 1, b);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = CopyToDevice(b, device);
+    torch::Tensor xla_c0 = torch::index_select(xla_a, 0, xla_b);
+    torch::Tensor xla_c1 = torch::index_select(xla_a, 1, xla_b);
+    AllEqual(c0, xla_c0);
+    AllEqual(c1, xla_c1);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::index_select", cpp_test::GetIgnoredCounters());
+}
+
+TEST_P(IndexOpsAtenXlaTensorTest, TestIndexSelectRank0) {
+  torch::ScalarType scalar_type = GetValueType();
+  torch::Tensor a =
+      isFloatingType(scalar_type)
+          ? torch::rand({3, 4}, torch::TensorOptions(scalar_type))
+          : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type));
+  torch::Tensor b =
+      torch::scalar_tensor(2, torch::TensorOptions(GetIndexType()));
+  torch::Tensor c0 = torch::index_select(a, 0, b);
+  torch::Tensor c1 = torch::index_select(a, 1, b);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = CopyToDevice(b, device);
+    torch::Tensor xla_c0 = torch::index_select(xla_a, 0, xla_b);
+    torch::Tensor xla_c1 = torch::index_select(xla_a, 1, xla_b);
+    AllEqual(c0, xla_c0);
+    AllEqual(c1, xla_c1);
+  });
+}
+
+TEST_P(IndexOpsAtenXlaTensorTest, TestIndexPutImpl) {
   torch::Tensor indices =
-      torch::randint(-3, 3, {2, 4, 3}, torch::TensorOptions(torch::kLong));
-  for (torch::ScalarType scalar_type :
-       {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt,
-        torch::kLong}) {
-    torch::Tensor values =
-        torch::ones({3, 5, 6, 7}, torch::TensorOptions(scalar_type));
-    for (bool accumulate : {false, true}) {
-      ForEachDevice([&](const torch::Device& device) {
-        torch::Tensor params =
-            isFloatingType(scalar_type)
-                ? 
torch::rand({4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type)) - : torch::randint(100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type)); - torch::Tensor xla_params = CopyToDevice(params.clone(), device); - torch::Tensor result = torch::_index_put_impl_( - params, {indices}, values, accumulate, /*unsafe=*/true); - torch::Tensor xla_indices = CopyToDevice(indices, device); - torch::Tensor xla_values = CopyToDevice(values, device); - torch::Tensor xla_result = torch::_index_put_impl_( - xla_params, {xla_indices}, xla_values, accumulate, /*unsafe=*/true); - AllEqual(result, xla_result); - AllEqual(params, xla_params); - }); + torch::randint(-3, 3, {2, 4, 3}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor values = + torch::ones({3, 5, 6, 7}, torch::TensorOptions(scalar_type)); + for (bool accumulate : {false, true}) { + ForEachDevice([&](const torch::Device& device) { + torch::Tensor params = + isFloatingType(scalar_type) + ? torch::rand({4, 3, 5, 6, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type)); + torch::Tensor xla_params = CopyToDevice(params.clone(), device); + torch::Tensor result = torch::_index_put_impl_( + params, {indices}, values, accumulate, /*unsafe=*/true); + torch::Tensor xla_indices = CopyToDevice(indices, device); + torch::Tensor xla_values = CopyToDevice(values, device); + torch::Tensor xla_result = torch::_index_put_impl_( + xla_params, {xla_indices}, xla_values, accumulate, /*unsafe=*/true); + AllEqual(result, xla_result); + AllEqual(params, xla_params); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_index_put_impl_", - cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_index_put_impl_", + cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillWithScalar) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalar) { torch::Tensor index = - torch::tensor({0, 2}, torch::TensorOptions(torch::kLong)); + torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); torch::Scalar value = 42; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor result = torch::index_fill(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_result = - torch::index_fill(xla_base, dim, xla_index, value); - AllEqual(result, xla_result); - }); + torch::Tensor base = + isFloatingType(scalar_type) + ? 
torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor result = torch::index_fill(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_result = + torch::index_fill(xla_base, dim, xla_index, value); + AllEqual(result, xla_result); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillWithScalarInPlace) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalarInPlace) { torch::Tensor index = - torch::tensor({0, 2}, torch::TensorOptions(torch::kLong)); + torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); torch::Scalar value = 42; int rank = 3; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - for (int dim = -rank; dim < rank; ++dim) { - ForEachDevice([&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, - torch::TensorOptions(scalar_type)); - torch::Tensor xla_base = CopyToDevice(base.clone(), device); - torch::Tensor result = base.index_fill_(dim, index, value); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_result = xla_base.index_fill_(dim, xla_index, value); - AllEqual(result, xla_result); - AllEqual(base, xla_base); - }); + for (int dim = -rank; dim < rank; ++dim) { + ForEachDevice([&](const torch::Device& device) { + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, + torch::TensorOptions(scalar_type)); + torch::Tensor xla_base = CopyToDevice(base.clone(), device); + torch::Tensor result = base.index_fill_(dim, index, value); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_result = xla_base.index_fill_(dim, xla_index, value); + AllEqual(result, xla_result); + AllEqual(base, xla_base); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillWithTensor) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensor) { torch::Tensor index = - torch::tensor({0, 2}, torch::TensorOptions(torch::kLong)); - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? 
torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); - torch::Tensor value = - torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor result = torch::index_fill(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_fill(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); + torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); + torch::Tensor value = + torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor result = torch::index_fill(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_fill(xla_base, dim, xla_index, xla_value); + AllEqual(result, xla_result); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillWithTensorInPlace) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensorInPlace) { torch::Tensor index = - torch::tensor({0, 2}, torch::TensorOptions(torch::kLong)); - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor value = - torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); - int rank = 3; - for (int dim = -rank; dim < rank; ++dim) { - ForEachDevice([&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, - torch::TensorOptions(scalar_type)); - torch::Tensor xla_base = CopyToDevice(base.clone(), device); - torch::Tensor result = base.index_fill_(dim, index, value); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - xla_base.index_fill_(dim, xla_index, xla_value); - AllEqual(result, xla_result); - AllEqual(base, xla_base); - }); + torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor value = + torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); + int rank = 3; + for (int dim = -rank; dim < rank; ++dim) { + ForEachDevice([&](const torch::Device& device) { + torch::Tensor base = + isFloatingType(scalar_type) + ? 
torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, + torch::TensorOptions(scalar_type)); + torch::Tensor xla_base = CopyToDevice(base.clone(), device); + torch::Tensor result = base.index_fill_(dim, index, value); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + xla_base.index_fill_(dim, xla_index, xla_value); + AllEqual(result, xla_result); + AllEqual(base, xla_base); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillRank0) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillRank0) { torch::Tensor index = - torch::scalar_tensor(2, torch::TensorOptions(torch::kLong)); - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); - torch::Tensor value = - torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor result = torch::index_fill(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_fill(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); + torch::scalar_tensor(2, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); + torch::Tensor value = + torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor result = torch::index_fill(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_fill(xla_base, dim, xla_index, xla_value); + AllEqual(result, xla_result); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexAdd) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexAdd) { int index_size = 10; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? 
torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor index = torch::randint(0, base.size(dim), {index_size}, + torch::TensorOptions(GetIndexType())); + std::vector value_sizes(base.sizes().begin(), base.sizes().end()); + int canonical_dim = dim < 0 ? dim + rank : dim; + value_sizes[canonical_dim] = index_size; + torch::Tensor value = isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { + ? torch::rand(value_sizes, torch::TensorOptions(scalar_type)) + : torch::randint(100, value_sizes, + torch::TensorOptions(scalar_type)); + torch::Tensor result = torch::index_add(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_add(xla_base, dim, xla_index, xla_value); + AllClose(result, xla_result); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); + } +} + +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexAddInPlace) { + int index_size = 10; + int rank = 3; + std::vector alphas{0.0, 1.0, 2.0}; + torch::ScalarType scalar_type = GetValueType(); + for (int dim = -rank; dim < rank; ++dim) { + for (double alpha : alphas) { + ForEachDevice([&](const torch::Device& device) { + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(50, {5, 3, 7}, + torch::TensorOptions(scalar_type)); torch::Tensor index = torch::randint(0, base.size(dim), {index_size}, - torch::TensorOptions(index_scalar_type)); + torch::TensorOptions(GetIndexType())); std::vector value_sizes(base.sizes().begin(), base.sizes().end()); int canonical_dim = dim < 0 ? dim + rank : dim; @@ -3215,43 +3231,101 @@ TEST_F(AtenXlaTensorTest, TestIndexAdd) { torch::Tensor value = isFloatingType(scalar_type) ? 
torch::rand(value_sizes, torch::TensorOptions(scalar_type)) - : torch::randint(100, value_sizes, + : torch::randint(50, value_sizes, torch::TensorOptions(scalar_type)); - torch::Tensor result = torch::index_add(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_add(xla_base, dim, xla_index, xla_value); - AllClose(result, xla_result); - }); - } + torch::Tensor xla_base = CopyToDevice(base.clone(), device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + xla_base.index_add_(dim, xla_index, xla_value, alpha); + torch::Tensor result = base.index_add_(dim, index, value, alpha); + AllClose(result, xla_result); + AllClose(base, xla_base); + }); } } ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); } -TEST_F(AtenXlaTensorTest, TestIndexAddInPlace) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexAddRank0) { + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor index = torch::randint(0, base.size(dim), at::IntArrayRef{}, + torch::TensorOptions(GetIndexType())); + std::vector value_sizes(base.sizes().begin(), base.sizes().end()); + int canonical_dim = dim < 0 ? dim + rank : dim; + value_sizes[canonical_dim] = 1; + torch::Tensor value = + isFloatingType(scalar_type) + ? torch::rand(value_sizes, torch::TensorOptions(scalar_type)) + : torch::randint(100, value_sizes, + torch::TensorOptions(scalar_type)); + torch::Tensor result = torch::index_add(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_add(xla_base, dim, xla_index, xla_value); + AllEqual(result, xla_result); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); + } +} + +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexCopy) { + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor index = + torch::randperm(base.size(dim), torch::TensorOptions(GetIndexType())); + torch::Tensor value = + isFloatingType(scalar_type) + ? 
torch::rand(base.sizes(), torch::TensorOptions(scalar_type)) + : torch::randint(100, base.sizes(), + torch::TensorOptions(scalar_type)); + torch::Tensor result = torch::index_copy(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_copy(xla_base, dim, xla_index, xla_value); + AllEqual(result, xla_result); + }); + } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_copy", cpp_test::GetIgnoredCounters()); +} + +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexCopyInPlace) { int index_size = 10; int rank = 3; - std::vector alphas{0.0, 1.0, 2.0}; - - for (torch::ScalarType scalar_type : - {torch::kByte, torch::kFloat, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - for (int dim = -rank; dim < rank; ++dim) { - for (double alpha : alphas) { - ForEachDevice([&](const torch::Device& device) { + torch::ScalarType scalar_type = GetValueType(); + for (int dim = -rank; dim < rank; ++dim) { + ForEachDevice( + {XlaDeviceType::CPU, XlaDeviceType::TPU}, + [&](const torch::Device& device) { torch::Tensor base = isFloatingType(scalar_type) ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(50, {5, 3, 7}, + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); torch::Tensor index = torch::randint(0, base.size(dim), {index_size}, - torch::TensorOptions(torch::kLong)); + torch::TensorOptions(GetIndexType())); std::vector value_sizes(base.sizes().begin(), base.sizes().end()); int canonical_dim = dim < 0 ? dim + rank : dim; @@ -3259,174 +3333,70 @@ TEST_F(AtenXlaTensorTest, TestIndexAddInPlace) { torch::Tensor value = isFloatingType(scalar_type) ? torch::rand(value_sizes, torch::TensorOptions(scalar_type)) - : torch::randint(50, value_sizes, + : torch::randint(100, value_sizes, torch::TensorOptions(scalar_type)); torch::Tensor xla_base = CopyToDevice(base.clone(), device); + torch::Tensor result = base.index_copy(dim, index, value); torch::Tensor xla_index = CopyToDevice(index, device); torch::Tensor xla_value = CopyToDevice(value, device); torch::Tensor xla_result = - xla_base.index_add_(dim, xla_index, xla_value, alpha); - torch::Tensor result = base.index_add_(dim, index, value, alpha); - AllClose(result, xla_result); - AllClose(base, xla_base); - }); - } - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestIndexAddRank0) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor index = torch::randint(0, base.size(dim), at::IntArrayRef{}, - torch::TensorOptions(torch::kLong)); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); - int canonical_dim = dim < 0 ? dim + rank : dim; - value_sizes[canonical_dim] = 1; - torch::Tensor value = - isFloatingType(scalar_type) - ? 
torch::rand(value_sizes, torch::TensorOptions(scalar_type)) - : torch::randint(100, value_sizes, - torch::TensorOptions(scalar_type)); - torch::Tensor result = torch::index_add(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_add(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); + xla_base.index_copy(dim, xla_index, xla_value); + AllEqual(result, xla_result); + AllEqual(base, xla_base); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); - } - } -} - -TEST_F(AtenXlaTensorTest, TestIndexCopy) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor index = - torch::randperm(base.size(dim), torch::TensorOptions(torch::kLong)); - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand(base.sizes(), torch::TensorOptions(scalar_type)) - : torch::randint(100, base.sizes(), - torch::TensorOptions(scalar_type)); - torch::Tensor result = torch::index_copy(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_copy(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_copy", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestIndexCopyInPlace) { - int index_size = 10; - int rank = 3; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - for (int dim = -rank; dim < rank; ++dim) { - ForEachDevice( - {XlaDeviceType::CPU, XlaDeviceType::TPU}, - [&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, - torch::TensorOptions(scalar_type)); - torch::Tensor index = - torch::randint(0, base.size(dim), {index_size}, - torch::TensorOptions(torch::kLong)); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); - int canonical_dim = dim < 0 ? dim + rank : dim; - value_sizes[canonical_dim] = index_size; - torch::Tensor value = - isFloatingType(scalar_type) - ? 
torch::rand(value_sizes, - torch::TensorOptions(scalar_type)) - : torch::randint(100, value_sizes, - torch::TensorOptions(scalar_type)); - torch::Tensor xla_base = CopyToDevice(base.clone(), device); - torch::Tensor result = base.index_copy(dim, index, value); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - xla_base.index_copy(dim, xla_index, xla_value); - AllEqual(result, xla_result); - AllEqual(base, xla_base); - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_copy", - cpp_test::GetIgnoredCounters()); - }); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_copy", + cpp_test::GetIgnoredCounters()); + }); } } -TEST_F(AtenXlaTensorTest, TestIndexCopyRank0) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexCopyRank0) { + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor index = torch::randint(0, base.size(dim), at::IntArrayRef{}, + torch::TensorOptions(GetIndexType())); + std::vector value_sizes(base.sizes().begin(), base.sizes().end()); + int canonical_dim = dim < 0 ? dim + rank : dim; + value_sizes[canonical_dim] = 1; + torch::Tensor value = isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor index = torch::randint(0, base.size(dim), at::IntArrayRef{}, - torch::TensorOptions(torch::kLong)); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); - int canonical_dim = dim < 0 ? dim + rank : dim; - value_sizes[canonical_dim] = 1; - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand(value_sizes, torch::TensorOptions(scalar_type)) - : torch::randint(100, value_sizes, - torch::TensorOptions(scalar_type)); - torch::Tensor result = torch::index_copy(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_copy(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); + ? 
torch::rand(value_sizes, torch::TensorOptions(scalar_type))
+            : torch::randint(100, value_sizes,
+                             torch::TensorOptions(scalar_type));
+    torch::Tensor result = torch::index_copy(base, dim, index, value);
+    ForEachDevice([&](const torch::Device& device) {
+      torch::Tensor xla_base = CopyToDevice(base, device);
+      torch::Tensor xla_index = CopyToDevice(index, device);
+      torch::Tensor xla_value = CopyToDevice(value, device);
+      torch::Tensor xla_result =
+          torch::index_copy(xla_base, dim, xla_index, xla_value);
+      AllEqual(result, xla_result);
+    });
 
-      ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-      ExpectCounterChanged("xla::index_copy", cpp_test::GetIgnoredCounters());
-    }
-  }
-}
+    ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+    ExpectCounterChanged("xla::index_copy", cpp_test::GetIgnoredCounters());
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    IndexOpsAtenXlaTensor, IndexOpsAtenXlaTensorTest,
+    ::testing::Combine(::testing::Values(torch::kLong, torch::kInt),
+                       ::testing::Values(torch::kFloat, torch::kByte,
+                                         torch::kChar, torch::kShort,
+                                         torch::kInt, torch::kLong)),
+    [](const testing::TestParamInfo<IndexOpsAtenXlaTensorTest::ParamType>&
+           info) {
+      const auto& params = info.param;
+      return absl::StrCat("IndexType", torch::toString(std::get<0>(params)),
+                          "_", "ValueType",
+                          torch::toString(std::get<1>(params)));
+    });
 
 TEST_F(AtenXlaTensorTest, TestRelu) {
   torch::Tensor input =
diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp
index 1472f05b78a..ade91cb2efb 100644
--- a/torch_xla/csrc/ops/index_ops.cpp
+++ b/torch_xla/csrc/ops/index_ops.cpp
@@ -365,8 +365,10 @@ torch::lazy::Value IndexPutByTensors(
 torch::lazy::NodePtr IndexFill(const XLATensorPtr& base, int64_t dim,
                                const XLATensorPtr& index,
                                const at::Scalar& value) {
-  XLA_CHECK_EQ(index->dtype(), at::ScalarType::Long)
-      << "Fill index is expected to be of scalar type Long, but it is "
+  XLA_CHECK(index->dtype() == at::ScalarType::Long ||
+            index->dtype() == at::ScalarType::Int)
+      << "Fill index is expected to be of scalar type Long or scalar type Int, "
+         "but it is "
       << index->dtype();
   XLA_CHECK_LE(index->shape().get().rank(), 1)
       << "Fill index is supposed to be a vector";
@@ -379,10 +381,12 @@
 torch::lazy::NodePtr IndexFill(const XLATensorPtr& base, int64_t dim,
                                const XLATensorPtr& index,
                                const XLATensorPtr& value) {
-  XLA_CHECK_EQ(index->dtype(), at::ScalarType::Long)
-      << "Fill index is expected to be of scalar type Long, but it is "
-      << index->dtype();
-  XLA_CHECK_LE(index->shape().get().rank(), 1)
+  XLA_CHECK(index->dtype() == at::ScalarType::Long ||
+            index->dtype() == at::ScalarType::Int)
+      << "Fill index is expected to be of scalar type Long or scalar type Int, "
+         "but it is "
+      << index->dtype();
+  XLA_CHECK_LE(index->shape().get().rank(), 1)
       << "Fill index is supposed to be a vector";
   XLA_CHECK_EQ(value->shape().get().rank(), 0)
       << "Fill only supports a 0-dimensional value tensor";
@@ -407,10 +411,12 @@ torch::lazy::Value IndexAdd(const XLATensorPtr& base, int64_t dim,
 torch::lazy::Value IndexCopy(const XLATensorPtr& base, int64_t dim,
                              const XLATensorPtr& index,
                              const XLATensorPtr& source) {
-  XLA_CHECK_EQ(index->dtype(), at::ScalarType::Long)
-      << "Copy index is expected to be of scalar type Long, but it is "
-      << index->dtype();
-  XLA_CHECK_LE(index->shape().get().rank(), 1)
+  XLA_CHECK(index->dtype() == at::ScalarType::Long ||
+            index->dtype() == at::ScalarType::Int)
+      << "Copy index is expected to be of scalar type Long or scalar type Int, "
"but it is " XLA_CHECK_LE(index->shape().get().rank(), 1) << "Copy index is supposed to be a vector"; return IndexCopyOp(base->GetIrValue(), dim, index->GetIrValue(), source->GetIrValue());