forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
UnfoldBackwardKernel.cpp
152 lines (131 loc) · 4.74 KB
/
UnfoldBackwardKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/UnfoldBackward.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/irange.h>
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif
// Note on naming: it is unconventional.
// grad_in does not mean that it is a gradient wrt to input,
// grad_in/grad_out is just an input/output of unfold_backward kernel.
//
// unfold_backward, the algorithm.
//
// Consider out = in.unfold(dim, size, step), then
// out.shape[dim] == (in.shape[dim] - size) / step + 1,
// out.shape[-1] == size.
// out.dims() == in.dims() + 1
//
// unfold_backward receives grad_in and returns grad_out such that
// grad_in.shape == out.shape,
// grad_out.shape = in.shape.
//
// unfold_backward considers the following two cases:
// case1. step >= size.
// case2. step < size.
//
// case1. step >= size.
// In this case the iteration takes over grad_in and performs the following copy:
// grad_out[..., i_out_dim,...] = grad_in[..., i_in_dim,..., i_in_last_dim],
// where i_out_dim = i_in_dim * step + i_in_last_dim.
//
// case2. step < size.
// In this case the iteration takes over grad_out,
// where grad_out[...,i_out_dim,...] accumulates all values
// grad_in[...,i_in_dim,...,i_in_last_dim], where
// i_in_dim is in [left_idx_fold, right_idx_fold],
// i_in_last_dim = i_out_dim - i_in_dim * step,
// left_idx_fold = (i_out_dim - size) / step
// if i_out_dim in [left_idx_fold * step, left_idx_fold * step + size)
// else (i_out_dim - size) / step + 1,
// right_idx_fold = i_out_dim / step.
//
// Simply put, given i_out_dim, we find which folds of grad_in
// intersect with i_out_dim, these are precisely [left_idx_fold, right_idx_fold],
// and then the corresponding value of grad_in[...,i_in_dim,...,i_in_last_dim]
// gets added up to grad_out[...,i_out_dim,...].
namespace at::native {
namespace {
template <typename scalar_t>
void _unfold_backward_internal_kernel(
TensorIterator& iter,
int64_t size,
int64_t step,
int64_t grad_in_dim_stride,
int64_t grad_in_last_dim_stride,
int64_t grad_in_dim_size,
int64_t grad_out_dim_stride
) {
if (iter.numel() == 0) {
return;
}
auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
auto* RESTRICT grad_out_ptr = data[0];
auto* RESTRICT grad_in_ptr = data[1];
auto* RESTRICT idx_dim_ptr = data[2];
for (const auto elem C10_UNUSED : c10::irange(nelems)) {
auto* RESTRICT grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr);
auto* RESTRICT grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr);
auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr);
// left_fold potentially intersecting with idx_dim
// is either (idx_dim - size) / step or the next integer.
int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0;
if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) {
++left_fold_idx;
}
auto right_fold_idx = idx_dim / step;
right_fold_idx = (right_fold_idx >= grad_in_dim_size)
? (grad_in_dim_size - 1) : right_fold_idx;
for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) {
auto idx_last_dim = idx_dim - fold_idx * step;
*grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride
+ idx_last_dim * grad_in_last_dim_stride];
}
grad_out_ptr += strides[0];
grad_in_ptr += strides[1];
idx_dim_ptr += strides[2];
}
};
iter.for_each(loop);
}
void unfold_backward_cpu_kernel(
Tensor& grad_out,
const Tensor& grad_in,
int64_t dim,
int64_t size,
int64_t step
) {
dim = maybe_wrap_dim(dim, grad_out.dim());
// last dim stores the folds
auto last_dim = maybe_wrap_dim(-1, grad_in.dim());
auto grad_in_dim_stride = ensure_nonempty_stride(grad_in, dim);
auto grad_in_last_dim_stride = ensure_nonempty_stride(grad_in, last_dim);
auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);
auto grad_out_dim_stride = ensure_nonempty_stride(grad_out, dim);
TensorIterator iter = _make_unfold_backward_iter_over_grad_out(
grad_out, grad_in, dim, size, step);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
iter.dtype(),
"unfold_backward_cpu", [&] {
_unfold_backward_internal_kernel<scalar_t>(
iter,
size,
step,
grad_in_dim_stride,
grad_in_last_dim_stride,
grad_in_dim_size,
grad_out_dim_stride
);
}
);
}
}
REGISTER_DISPATCH(unfold_backward_stub, &unfold_backward_cpu_kernel);
} // namespace at::native