diff --git a/test/test_operations.py b/test/test_operations.py
index 5890513efe0..f95c387f435 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2021,6 +2021,54 @@ def foo(t0, t1):
     self.assertEqual(r, Xr.cpu())
 
+  def test_index_zero_tensor_by_zero_tensor(self):
+
+    # Test if simple one-tensor indexing works.
+    # Should return a non-permuted tensor.
+    def f1(x, i):
+      return x[i]
+
+    # Test if scattered two-tensor indexing works.
+    # Should return a permuted tensor, with indexed dimensions first.
+    def f2(x, i0, i1):
+      return x[:, i0, :, i1]
+
+    cases = {
+        f1: [
+            ((0,), (0,)),
+            ((0, 10), (0, 5, 5)),
+            ((0, 3, 3), (5, 5, 0)),
+        ],
+        f2: [
+            ((10, 0, 10, 10), (5, 0, 5), (5, 1, 1)),
+            ((0, 0, 10, 0), (5, 5, 0), (5, 5, 1)),
+        ]
+    }
+
+    def make_tensor(shape):
+      return torch.rand(shape)
+
+    def make_index(shape):
+      return torch.randint(0, 100, shape, dtype=torch.long)
+
+    def test(f, xshape, ishapes):
+      x = make_tensor(xshape)
+      ilist = [make_index(s) for s in ishapes]
+
+      Xx = x.to(xm.xla_device())
+      Xilist = [i.to(xm.xla_device()) for i in ilist]
+
+      out = f(x, *ilist)
+      Xout = f(Xx, *Xilist)
+
+      self.assertEqual(out, Xout.cpu())
+
+    for xshape, ishape in cases[f1]:
+      test(f1, xshape, (ishape,))
+
+    for xshape, i0shape, i1shape in cases[f2]:
+      test(f2, xshape, (i0shape, i1shape))
+
 
 class MNISTComparator(nn.Module):
diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp
index ddbf0c677aa..ccc3d090a56 100644
--- a/torch_xla/csrc/ops/index_ops.cpp
+++ b/torch_xla/csrc/ops/index_ops.cpp
@@ -277,12 +277,59 @@ torch::lazy::Value EnsureRank1(const torch::lazy::Value& index) {
                      : index;
 }
 
+bool HasZeroElementIndex(absl::Span<const XLATensorPtr> indices) {
+  return std::any_of(indices.begin(), indices.end(),
+                     [](const XLATensorPtr& index) {
+                       return xla::ShapeUtil::ElementsIn(*index->shape()) == 0;
+                     });
+}
+
+XLATensorPtr GetZeroElementTensor(const XLATensorPtr& base,
+                                  absl::Span<const XLATensorPtr> indices,
+                                  int64_t start_dim) {
+  // Returns a 0-element tensor described by the indexing.
+  //
+  // At this point, we know that we are indexing 'base' with 0-element
+  // tensors, i.e. one of its dimensions has size 0. Therefore, we
+  // need to return a 0-element tensor of the appropriate size.
+  //
+  // This function computes the output size and calls 'full' to create the
+  // desired 0-element tensor.
+  std::vector<int64_t> dimensions;
+
+  // In the beginning, we add all dimensions that come before the ones that
+  // correspond to the indices.
+  absl::Span<const int64_t> base_dimensions = base->shape().get().dimensions();
+  dimensions.insert(dimensions.end(), base_dimensions.begin(),
+                    base_dimensions.begin() + start_dim);
+
+  // Then, we add the dimensions of the first index. Notice that, at this
+  // point, all indices are already broadcasted, i.e. have the same size.
+  // So, we grab the first one for convenience.
+  for (auto dim : indices.front()->shape().get().dimensions()) {
+    dimensions.push_back(dim);
+  }
+
+  // Finally, add the remaining dimensions that weren't indexed.
+  dimensions.insert(dimensions.end(),
+                    base_dimensions.begin() + start_dim + indices.size(),
+                    base_dimensions.end());
+
+  return tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype());
+}
+
 XLATensorPtr IndexByTensors(const XLATensorPtr& base,
                             absl::Span<const XLATensorPtr> indices,
                             int64_t start_dim) {
   if (indices.empty()) {
     return base;
   }
+  // Check whether we are trying to index with a 0-element tensor.
+  // If so, there's no need to compute anything. We simply return
+  // a 0-element tensor.
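+  //
+  // For example (shapes taken from the new test above): indexing a base of
+  // shape [0, 3, 3] with a single index tensor of shape [5, 5, 0] at
+  // start_dim=0 short-circuits here and returns an empty tensor of shape
+  // [5, 5, 0, 3, 3].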
+  if (HasZeroElementIndex(indices)) {
+    return GetZeroElementTensor(base, indices, start_dim);
+  }
   auto canonical_indices = WrapIndicesOnce(base, indices, start_dim);
   int64_t indices_rank = canonical_indices.front()->shape().get().rank();
   // Stack the indices to allow the whole multi-indexing to be dispatched with a