forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialCrossMapLRN.cu
127 lines (122 loc) · 4.8 KB
/
SpatialCrossMapLRN.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <THCUNN/THCUNN.h>
#include <TH/THHalf.h>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <THC/THCTensor.hpp>
#include <THC/THCStorage.hpp>
#include <THCUNN/common.h>
#include <c10/macros/Macros.h>
// Computes the cross-channel LRN normalization term
//
//   scale[n, c, h, w] = k + alpha_over_size * sum_{c'} in[n, c', h, w]^2
//
// where c' runs over the window [c - pre_pad, c + post_pad] clipped to
// [0, channels), with pre_pad = (size - 1) / 2 and
// post_pad = size - pre_pad - 1.
//
// Each loop index handles one (n, h, w) position and slides the window along
// the channel axis with a running sum, so `nthreads` is expected to equal
// num * height * width (TODO confirm against the launch site, which is not
// in this file). Indexing assumes `in` and `scale` share a dense layout with
// strides (channels * height * width, height * width, width, 1).
//
// Parameters:
//   nthreads         number of (n, h, w) positions to process
//   in               input values
//   num              not read by the kernel body
//   channels/height/width  extents used for index arithmetic
//   size             number of channels in the normalization window
//   alpha_over_size  multiplier applied to the windowed sum of squares
//   k                additive constant
//   scale            output buffer, same layout as `in`
template <typename Dtype, typename Acctype>
__global__ void
#if __CUDA_ARCH__ >= 320 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
#endif
LRNFillScale(const int nthreads, const Dtype* const in,
    const int num, const int channels, const int height,
    const int width, const int size, const Dtype alpha_over_size,
    const Dtype k, Dtype* const scale) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // find out the local offset: element (n, 0, h, w)
    const int w = index % width;
    const int h = (index / width) % height;
    const int n = index / width / height;
    const int offset = (n * channels * height + h) * width + w;
    // distance between consecutive channels at a fixed (n, h, w)
    const int step = height * width;
    const Dtype* const in_off = in + offset;
    Dtype* const scale_off = scale + offset;
    // `head` is the leading (highest) channel of the sliding window
    int head = 0;
    const int pre_pad = (size - 1) / 2;
    const int post_pad = size - pre_pad - 1;
    Acctype accum_scale = Acctype(0);
    // fill the scale at [n, :, h, w]
    // accumulate values: prime the window, no output channel is complete yet
    while (head < post_pad && head < channels) {
      accum_scale += in_off[head * step] * in_off[head * step];
      ++head;
    }
    // both add and subtract
    // (this loop only runs when the priming loop reached head == post_pad,
    // so head - post_pad is never negative here)
    while (head < channels) {
      accum_scale += in_off[head * step] * in_off[head * step];
      if (head - size >= 0) {
        accum_scale -= in_off[(head - size) * step]
                       * in_off[(head - size) * step];
      }
      scale_off[(head - post_pad) * step] =
          ScalarConvert<Acctype, Dtype>::to(k + accum_scale * alpha_over_size);
      ++head;
    }
    // subtract only: the window has run past the last channel
    while (head < channels + post_pad) {
      if (head - size >= 0) {
        accum_scale -= in_off[(head - size) * step]
                       * in_off[(head - size) * step];
      }
      // When channels < post_pad the priming loop stopped at head == channels,
      // so head - post_pad starts out negative. Skip those iterations instead
      // of writing before the start of this (n, h, w) column.
      if (head >= post_pad) {
        scale_off[(head - post_pad) * step] =
            ScalarConvert<Acctype, Dtype>::to(k + accum_scale * alpha_over_size);
      }
      ++head;
    }
  }
}
// Element-wise LRN forward output. For every i in [0, nthreads):
//
//   out[i] = in[i] * pow(scale[i], negative_beta)
//
// `in`, `scale` and `out` are indexed with the same flat index, so they are
// expected to share one layout.
template <typename Dtype>
__global__ void LRNComputeOutput(const int nthreads, const Dtype* in,
    const Dtype* scale, const Dtype negative_beta, Dtype* out) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    const Dtype value = in[index];
    const Dtype norm = scale[index];
    // Keep the multiply and pow in a single expression so the result is
    // rounded to Dtype only once, on the store.
    out[index] = value * pow(norm, negative_beta);
  }
}
// Computes the cross-channel LRN input gradient
//
//   bottom_diff[n, c, h, w] =
//       top_diff[n, c, h, w] * pow(scale[n, c, h, w], negative_beta)
//       - cache_ratio * bottom_data[n, c, h, w]
//         * sum_{c'} top_diff[n, c', h, w] * top_data[n, c', h, w]
//                    / scale[n, c', h, w]
//
// where c' runs over the window [c - pre_pad, c + post_pad] clipped to
// [0, channels), with pre_pad = size - (size + 1) / 2 and
// post_pad = size - pre_pad - 1.
//
// Each loop index handles one (n, h, w) position and slides the window along
// the channel axis with a running sum, so `nthreads` is expected to equal
// num * height * width (TODO confirm against the launch site, which is not
// in this file). All five buffers are indexed with the same offsets and are
// assumed to share a dense layout with strides
// (channels * height * width, height * width, width, 1).
//
// Parameters:
//   nthreads       number of (n, h, w) positions to process
//   bottom_data    forward input
//   top_data       forward output
//   scale          normalization term, read at the same offsets as the data
//   top_diff       gradient with respect to the forward output
//   num            not read by the kernel body
//   channels/height/width  extents used for index arithmetic
//   size           number of channels in the normalization window
//   negative_beta  exponent applied to `scale`
//   cache_ratio    multiplier applied to the windowed ratio sum
//   bottom_diff    output buffer for the input gradient
template <typename Dtype, typename Acctype>
__global__ void LRNComputeDiff(const int nthreads,
    const Dtype* const bottom_data, const Dtype* const top_data,
    const Dtype* const scale, const Dtype* const top_diff,
    const int num, const int channels, const int height,
    const int width, const int size, const Dtype negative_beta,
    const Dtype cache_ratio, Dtype* const bottom_diff) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // find out the local offset: element (n, 0, h, w)
    const int w = index % width;
    const int h = (index / width) % height;
    const int n = index / width / height;
    const int offset = (n * channels * height + h) * width + w;
    // distance between consecutive channels at a fixed (n, h, w)
    const int step = height * width;
    const Dtype* const bottom_off = bottom_data + offset;
    const Dtype* const top_off = top_data + offset;
    const Dtype* const scale_off = scale + offset;
    const Dtype* const top_diff_off = top_diff + offset;
    Dtype* const bottom_diff_off = bottom_diff + offset;
    // `head` is the leading (highest) channel of the sliding window
    int head = 0;
    const int pre_pad = size - (size + 1) / 2;
    const int post_pad = size - pre_pad - 1;
    Acctype accum_ratio = Acctype(0);
    // accumulate values: prime the window, no output channel is complete yet
    while (head < post_pad && head < channels) {
      accum_ratio += top_diff_off[head * step] * top_off[head * step] /
          scale_off[head * step];
      ++head;
    }
    // both add and subtract
    // (this loop only runs when the priming loop reached head == post_pad,
    // so head - post_pad is never negative here)
    while (head < channels) {
      accum_ratio += top_diff_off[head * step] * top_off[head * step] /
          scale_off[head * step];
      if (head - size >= 0) {
        accum_ratio -= top_diff_off[(head - size) * step] *
            top_off[(head - size) * step] / scale_off[(head - size) * step];
      }
      bottom_diff_off[(head - post_pad) * step] =
          ScalarConvert<Acctype, Dtype>::to(top_diff_off[(head - post_pad) * step]
            * pow(scale_off[(head - post_pad) * step], negative_beta)
          - cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio);
      ++head;
    }
    // subtract only: the window has run past the last channel
    while (head < channels + post_pad) {
      if (head - size >= 0) {
        accum_ratio -= top_diff_off[(head - size) * step] *
            top_off[(head - size) * step] / scale_off[(head - size) * step];
      }
      // When channels < post_pad the priming loop stopped at head == channels,
      // so head - post_pad starts out negative. Skip those iterations instead
      // of reading and writing before the start of this (n, h, w) column.
      if (head >= post_pad) {
        bottom_diff_off[(head - post_pad) * step] =
            ScalarConvert<Acctype, Dtype>::to(top_diff_off[(head - post_pad) * step]
              * pow(scale_off[(head - post_pad) * step], negative_beta)
            - cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio);
      }
      ++head;
    }
  }
}
#include <THCUNN/generic/SpatialCrossMapLRN.cu>
#include <THC/THCGenerateFloatTypes.h>