forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
IndexingUtils.h
160 lines (145 loc) · 5.44 KB
/
IndexingUtils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#pragma once
#include <ATen/ExpandUtils.h>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
namespace at::native {
// Reports that dimension `maskIdx` of the boolean/byte mask `mask` does not
// match dimension `idx` of the indexed tensor `self`.
// The check condition is the literal `false`, so TORCH_CHECK_INDEX always
// throws and this function never returns (hence [[noreturn]]).
[[noreturn]]
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
  TORCH_CHECK_INDEX(
      false,
      "The shape of the mask ", mask.sizes(),
      " at index ", maskIdx,
      " does not match the shape of the indexed tensor ", self.sizes(),
      " at index ", idx);
}
// Rewrites mask indices (dtype kByte or kBool) into the equivalent integer
// index tensors; every other entry of `indices` is forwarded as-is.
//
//  * An absent index becomes an undefined Tensor placeholder.
//  * A k-dimensional mask must match, size for size, the k dimensions of
//    `self` starting at the current output position; otherwise
//    invalid_mask() throws. It is then replaced by the k columns of
//    `mask.nonzero()`, so it contributes k entries to the result.
//  * kByte masks additionally emit a deprecation warning.
static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
  std::vector<Tensor> expanded;
  for (const auto& maybe_index : indices) {
    if (!maybe_index.has_value()) {
      // Keep a placeholder so later entries stay aligned with dims of `self`.
      expanded.emplace_back();
      continue;
    }
    const auto& index = *maybe_index;
    const auto dtype = index.scalar_type();
    const bool is_mask = (dtype == kByte) || (dtype == kBool);
    if (!is_mask) {
      expanded.emplace_back(index);
      continue;
    }
    if (dtype == kByte) {
      TORCH_WARN("indexing with dtype torch.uint8 is now deprecated,"
                 " please use a dtype torch.bool instead.");
    }
    // The mask's d-th dimension lines up with dimension
    // (expanded.size() + d) of `self`.
    const auto mask_dims = index.dim();
    for (const auto d : c10::irange(mask_dims)) {
      const auto self_dim = static_cast<int64_t>(expanded.size() + d);
      if (index.size(d) != self.size(self_dim)) {
        invalid_mask(self, self_dim, index, d);
      }
    }
    // One output index tensor per mask dimension: column d of nonzero().
    const auto nonzero = index.nonzero();
    for (const auto d : c10::irange(mask_dims)) {
      expanded.emplace_back(nonzero.select(1, d));
    }
  }
  return expanded;
}
// Validates the dtype of every present, defined index tensor.
// Accepted dtypes are kLong, kByte and kBool, plus kInt when `allow_int` is
// true. Any other dtype raises via TORCH_CHECK_INDEX with a message listing
// the accepted kinds. Absent or undefined entries are skipped.
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
  for (const auto& maybe_index : indices) {
    if (!maybe_index.has_value() || !maybe_index->defined()) {
      continue;
    }
    const auto dtype = maybe_index->scalar_type();
    const bool always_ok = dtype == kLong || dtype == kByte || dtype == kBool;
    if (allow_int) {
      if (!always_ok && dtype != kInt) {
        TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
      }
    } else if (!always_ok) {
      TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
    }
  }
}
// Wraps each Tensor of `list` in an engaged optional, preserving order.
inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
  torch::List<c10::optional<Tensor>> wrapped;
  wrapped.reserve(list.size());
  for (const auto& tensor : list) {
    wrapped.push_back(c10::optional<Tensor>(tensor));
  }
  return wrapped;
}
// Converts IValues to optional tensors, preserving order: an IValue holding
// a tensor becomes an engaged optional, any other IValue becomes an empty one.
inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
  torch::List<c10::optional<Tensor>> wrapped;
  wrapped.reserve(list.size());
  for (const auto& value : list) {
    if (value.isTensor()) {
      wrapped.push_back(c10::optional<Tensor>(value.toTensor()));
    } else {
      wrapped.push_back(c10::optional<Tensor>());
    }
  }
  return wrapped;
}
// Returns true if all the defined (non-null) tensors in `tl` are adjacent,
// i.e. no undefined tensor sits between the first and the last defined one.
// A list containing no defined tensors (including an empty list) is
// trivially contiguous and yields true.
static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
  auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
  auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
  auto start = std::find_if(tl.begin(), tl.end(), isDefined);
  if (start == tl.end()) {
    // No defined tensors. Without this early return, stop.base() below would
    // equal tl.begin() and find_if would receive the invalid range
    // [tl.end(), tl.begin()), which is undefined behavior for a non-empty list.
    return true;
  }
  // stop.base() points one past the last defined tensor, so
  // [start, stop.base()) is a valid range spanning first..last defined.
  auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
  auto it = std::find_if(start, stop.base(), isNull);
  return it == stop.base();
}
// Transposes the tensor and indices together so that all the non-null indices
// index the first k dimensions of the tensor. Returns the transposed tensor
// and the reordered indices. For example:
//   transposeToFront(tensor, {nullptr, a, nullptr, b})
// returns
//   tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
//
// Precondition: `indices` has at least self.dim() entries (pad with undefined
// tensors for unindexed dimensions); only the first self.dim() are consulted.
static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
transposeToFront(const Tensor& self, TensorList indices) {
  const int64_t ndim = self.dim();
  // Every dimension i of `self` reads indices[i] below, so a shorter list
  // would be indexed out of range; fail loudly instead.
  TORCH_INTERNAL_ASSERT(
      indices.size() >= static_cast<size_t>(ndim),
      "transposeToFront: expected at least ", ndim,
      " indices (one per dimension of self), but got ", indices.size());
  std::vector<int64_t> dims;
  std::vector<Tensor> transposedIndices;
  dims.reserve(ndim);
  transposedIndices.reserve(ndim);
  // Indexed dimensions first, in their original relative order...
  for (const auto i : c10::irange(ndim)) {
    if (indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back(indices[i]);
    }
  }
  // ...followed by the unindexed ones, also in original order.
  for (const auto i : c10::irange(ndim)) {
    if (!indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back();
    }
  }
  return std::make_tuple(self.permute(dims), std::move(transposedIndices));
}
inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
std::vector<int64_t> dims;
std::vector<int64_t> invPerm;
std::vector<Tensor> transposedIndices;
dims.reserve(self.dim());
invPerm.resize(self.dim());
for (const auto i : c10::irange(self.dim())) {
if (indices[i].defined()) {
dims.push_back(i);
transposedIndices.emplace_back(indices[i]);
}
}
for (const auto i : c10::irange(self.dim())) {
if (!indices[i].defined()) {
dims.push_back(i);
transposedIndices.emplace_back();
}
}
for (const auto i : c10::irange(self.dim())) {
invPerm[dims[i]] = i;
}
return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
}
// Bundle of data describing an advanced-indexing operation on a tensor.
// Only the declaration is visible here; the constructor is defined elsewhere.
// NOTE(review): the member descriptions below are inferred from the member
// names and are NOT established by this header. Confirm them against the
// out-of-line AdvancedIndex::AdvancedIndex definition.
struct AdvancedIndex {
// Builds the description from the tensor `src` and its per-dimension `indices`.
AdvancedIndex(const Tensor& src, TensorList indices);
// Presumably the source tensor as prepared for indexing (possibly
// permuted/restrided) — TODO confirm.
Tensor src;
// Presumably the processed index tensors — TODO confirm.
std::vector<Tensor> indices;
// Presumably the sizes of the dimensions of `src` that are indexed — TODO confirm.
DimVector indexed_sizes;
// Presumably the strides of those same indexed dimensions — TODO confirm.
DimVector indexed_strides;
// Presumably the number of unindexed dimensions before the indexed ones — TODO confirm.
int64_t dims_before;
// Presumably the number of unindexed dimensions after the indexed ones — TODO confirm.
int64_t dims_after;
};
} //namespace at::native