IndexKernel.cpp
#include <ATen/native/Indexing.h>

#include <cmath>
#include <iostream>

#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec256/vec256.h>

namespace at { namespace native {
namespace {

using namespace vec256;
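
// Indexer turns the per-element position supplied by TensorIterator into a
// byte offset into the indexed tensor: it reads one int64 index from each
// index tensor, bounds-checks it, wraps negative values, and combines it
// with the original size/stride of the corresponding dimension.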
struct Indexer {
  Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
          IntArrayRef original_sizes, IntArrayRef original_strides)
    : num_indexers(num_indexers)
    , indexers(indexers)
    , indexer_strides(indexer_strides)
    , original_strides(original_strides.data())
    , original_sizes(original_sizes.data()) {
    AT_ASSERT(original_strides.size() == num_indexers);
    AT_ASSERT(original_sizes.size() == num_indexers);
  }

  int64_t num_indexers;
  char** indexers;
  const int64_t* indexer_strides;
  const int64_t* original_strides;
  const int64_t* original_sizes;

  int64_t get(int64_t idx) {
    int64_t offset = 0;
    for (int j = 0; j < num_indexers; j++) {
      int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
      int64_t size = original_sizes[j];
      if (value < -size || value >= size) {
        AT_INDEX_ERROR("index ", value, " is out of bounds for dimension ", j, " with size ", size);
      }
      if (value < 0) {
        value += size;
      }
      offset += value * original_strides[j];
    }
    return offset;
  }
};
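
// True when every index tensor has stride 0 over this loop, i.e. all
// elements in the chunk read the same index value.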
static bool is_constant_index(int ntensor, const int64_t* strides) {
  AT_ASSERT(ntensor >= 3);
  for (int arg = 2; arg < ntensor; arg++) {
    if (strides[arg] != 0) {
      return false;
    }
  }
  return true;
}
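
// Shared driver for the gather/scatter kernels below: iterates with
// TensorIterator, computes each element's byte offset via Indexer, and calls
// f(dst, src, offset). With serial_execution=true the loop runs on a single
// thread, which the accumulating index_put path relies on for correctness.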
template <typename scalar_t, typename func_t>
void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride,
                      const func_t& f, bool serial_execution=false)
{
  auto loop = [&](int ntensor, char** data, const int64_t* strides, int64_t n) {
    auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
    char* dst = data[0];
    char* src = data[1];
    if (is_constant_index(ntensor, strides)) {
      // specialization for when every element uses the same index
      int64_t offset = indexer.get(0);
      // NOTE: both branches below currently run the same scalar loop; the
      // contiguous-stride check does not change behavior.
      if (strides[0] == sizeof(scalar_t) && strides[1] == sizeof(scalar_t)) {
        for (int64_t i = 0; i < n; i++) {
          f(dst + strides[0] * i, src + strides[1] * i, offset);
        }
      } else {
        for (int64_t i = 0; i < n; i++) {
          f(dst + strides[0] * i, src + strides[1] * i, offset);
        }
      }
    } else {
      for (int64_t i = 0; i < n; i++) {
        int64_t offset = indexer.get(i);
        f(dst + strides[0] * i, src + strides[1] * i, offset);
      }
    }
  };
  if (serial_execution) {
    iter.serial_for_each(loop, {0, iter.numel()});
  } else {
    iter.for_each(loop);
  }
}
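
// index: for every output element, reads src at the gathered byte offset and
// writes it to dst (the offset addresses the source, i.e. the indexed tensor).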
void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
  AT_DISPATCH_ALL_TYPES(iter.type(0), "index", [&] {
    cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
      *(scalar_t*)dst = *(scalar_t*)(src + offset);
    });
  });
}
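
// index_put: scatters src values into dst at the gathered offsets; here the
// offset addresses the destination (the indexed tensor) rather than the source.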
void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
  // NOTE: duplicate indices are only supported if accumulate is true.
  AT_DISPATCH_ALL_TYPES(iter.type(0), "index_put", [&] {
    if (accumulate) {
      // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
      // this needs to be thread-safe.
      cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
        *(scalar_t*)(dst + offset) += *(scalar_t*)src;
      }, /*serial_execution=*/true);
    } else {
      cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
        *(scalar_t*)(dst + offset) = *(scalar_t*)src;
      });
    }
  });
}
} // anonymous namespace
REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
}} // namespace at::native