index: fix index of 0-element tensor by 0-element tensor. (#7113)
ysiraichi authored May 28, 2024
1 parent fd4900c commit be3b08e
Showing 2 changed files with 95 additions and 0 deletions.
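For context, this is the behavior being fixed: indexing a 0-element tensor with a 0-element long tensor already works on CPU, and this commit makes the XLA lowering agree. A minimal sketch (the shapes mirror the simplest case in the new test; the exact pre-fix failure mode on XLA is not reproduced here):

import torch
import torch_xla.core.xla_model as xm

x = torch.rand(0)
i = torch.randint(0, 100, (0,), dtype=torch.long)
print(x[i].shape)  # torch.Size([0]) on CPU

# With this fix, the same indexing works on an XLA device as well.
Xx = x.to(xm.xla_device())
Xi = i.to(xm.xla_device())
print(Xx[Xi].cpu().shape)  # torch.Size([0])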
48 changes: 48 additions & 0 deletions test/test_operations.py
@@ -2021,6 +2021,54 @@ def foo(t0, t1):

    self.assertEqual(r, Xr.cpu())

  def test_index_zero_tensor_by_zero_tensor(self):

    # Test if simple one-tensor indexing works.
    # Should return a non-permuted tensor.
    def f1(x, i):
      return x[i]

    # Test if scattered two-tensor indexing works.
    # Should return a permuted tensor, with indexed dimensions first.
    def f2(x, i0, i1):
      return x[:, i0, :, i1]

    cases = {
        f1: [
            ((0,), (0,)),
            ((0, 10), (0, 5, 5)),
            ((0, 3, 3), (5, 5, 0)),
        ],
        f2: [
            ((10, 0, 10, 10), (5, 0, 5), (5, 1, 1)),
            ((0, 0, 10, 0), (5, 5, 0), (5, 5, 1)),
        ]
    }

    def make_tensor(shape):
      return torch.rand(shape)

    def make_index(shape):
      return torch.randint(0, 100, shape, dtype=torch.long)

    def test(f, xshape, ishapes):
      x = make_tensor(xshape)
      ilist = [make_index(s) for s in ishapes]

      Xx = x.to(xm.xla_device())
      Xilist = [i.to(xm.xla_device()) for i in ilist]

      out = f(x, *ilist)
      Xout = f(Xx, *Xilist)

      self.assertEqual(out, Xout.cpu())

    for xshape, ishape in cases[f1]:
      test(f1, xshape, (ishape,))

    for xshape, i0shape, i1shape in cases[f2]:
      test(f2, xshape, (i0shape, i1shape))


class MNISTComparator(nn.Module):

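To make the scattered case (f2) concrete, here is a short CPU-only sketch using the first f2 entry from cases above; the shape arithmetic follows standard PyTorch advanced-indexing rules:

import torch

x = torch.rand(10, 0, 10, 10)
i0 = torch.randint(0, 100, (5, 0, 5), dtype=torch.long)  # indexes dim 1 (size 0)
i1 = torch.randint(0, 100, (5, 1, 1), dtype=torch.long)  # indexes dim 3

# The advanced indices are non-adjacent, so the broadcast index shape
# (5, 0, 5) moves to the front, followed by the untouched dims 0 and 2.
# No index value is ever dereferenced, because the result is empty.
out = x[:, i0, :, i1]
print(out.shape)  # torch.Size([5, 0, 5, 10, 10])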
47 changes: 47 additions & 0 deletions torch_xla/csrc/ops/index_ops.cpp
@@ -277,12 +277,59 @@ torch::lazy::Value EnsureRank1(const torch::lazy::Value& index) {
                         : index;
}

bool HasZeroElementIndex(absl::Span<const XLATensorPtr> indices) {
  return std::any_of(indices.begin(), indices.end(),
                     [](const XLATensorPtr& index) {
                       return xla::ShapeUtil::ElementsIn(*index->shape()) == 0;
                     });
}

XLATensorPtr GetZeroElementTensor(const XLATensorPtr& base,
                                  absl::Span<const XLATensorPtr> indices,
                                  int64_t start_dim) {
  // Returns a 0-element tensor described by the indexing.
  //
  // At this point, we know that we are indexing 'base' with 0-element
  // index tensors, i.e. at least one dimension of some index has size 0.
  // Therefore, we need to return a 0-element tensor of the appropriate
  // size.
  //
  // This function computes the output size and calls 'full' to create the
  // desired 0-element tensor.
  std::vector<int64_t> dimensions;

  // First, add all dimensions that come before the ones that correspond
  // to the indices.
  absl::Span<const int64_t> base_dimensions = base->shape().get().dimensions();
  dimensions.insert(dimensions.end(), base_dimensions.begin(),
                    base_dimensions.begin() + start_dim);

  // Then, add the dimensions of the first index. Notice that, at this
  // point, all indices have already been broadcast, i.e. they share the
  // same shape. So, we grab the first one for convenience.
  for (auto dim : indices.front()->shape().get().dimensions()) {
    dimensions.push_back(dim);
  }

  // Finally, add the remaining dimensions that weren't indexed.
  dimensions.insert(dimensions.end(),
                    base_dimensions.begin() + start_dim + indices.size(),
                    base_dimensions.end());

  return tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype());
}

XLATensorPtr IndexByTensors(const XLATensorPtr& base,
                            absl::Span<const XLATensorPtr> indices,
                            int64_t start_dim) {
  if (indices.empty()) {
    return base;
  }
  // Check whether we are trying to index with a 0-element tensor.
  // If so, there's no need to compute anything; we simply return
  // a 0-element tensor.
  if (HasZeroElementIndex(indices)) {
    return GetZeroElementTensor(base, indices, start_dim);
  }
  auto canonical_indices = WrapIndicesOnce(base, indices, start_dim);
  int64_t indices_rank = canonical_indices.front()->shape().get().rank();
  // Stack the indices to allow the whole multi-indexing to be dispatched with a
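The heart of GetZeroElementTensor is the output-shape computation. Below is a hypothetical Python model of that logic; the helper name is illustrative, and the start_dim == 0 scenario assumes the scattered indices have already been transposed to the front, as the f2 test comment describes:

def zero_element_output_shape(base_shape, index_shape, num_indices, start_dim):
  # Dimensions before the indexed block are kept as-is.
  out = list(base_shape[:start_dim])
  # All indices share one broadcast shape; it replaces the indexed dims.
  out += list(index_shape)
  # Dimensions after the indexed block are kept as-is.
  out += list(base_shape[start_dim + num_indices:])
  return out

# f2's first case: after moving indexed dims 1 and 3 to the front, the
# base shape is (0, 10, 10, 10), the two indices broadcast to (5, 0, 5),
# and start_dim is 0:
print(zero_element_output_shape((0, 10, 10, 10), (5, 0, 5), 2, 0))
# [5, 0, 5, 10, 10]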
