#include <THCUNN/THCUNN.h>
#include <THCUNN/common.h>
#include <THC/THCTensor.hpp>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/transform_reduce.h>
#if CUDA_VERSION >= 7000 || defined(__HIP_PLATFORM_HCC__)
#include <thrust/system/cuda/execution_policy.h>
#endif
#include <thrust/unique.h>
#include <TH/THHalf.h>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <THC/THCTensorSort.cuh>
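// A "warp" is 64 lanes wide on AMD GPUs (a wavefront) and 32 lanes wide on
// NVIDIA GPUs, so the constant is chosen per platform.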
#if defined(__HIP_PLATFORM_HCC__)
const int WARP_SIZE = 64;
#else
const int WARP_SIZE = 32;
#endif
const int MODE_SUM = 0;
const int MODE_MEAN = 1;
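// Forward kernel: for each bag, sums (or averages, for MODE_MEAN) the
// embedding rows of `weight` selected by `input`. `offsets[bag]` is the start
// of each bag within `input`; the kernel also fills `offset2bag` with the bag
// index of every input element (consumed by the backward pass) and, in mean
// mode, records per-bag counts in `bag_size`.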
template <typename Dtype, typename Acctype>
__global__ void cunn_LookupTableBag_updateOutputKernel(
  int64_t *input, int64_t *offsets, Dtype *weight, Dtype *output,
  int64_t *offset2bag, int64_t numIndices, int64_t numBags, int64_t stride, int mode,
  int64_t *bag_size) {
  // the strategy here is that each bag x feature is handled by a single thread
  int64_t chunksPerBag = THCCeilDiv(stride, (int64_t) blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;
  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < stride) {
      int64_t bag = chunk / chunksPerBag;
      Dtype* weightFeat = weight + featureDim;
      int64_t begin = offsets[bag] - TH_INDEX_BASE;
      int64_t end = (bag < numBags - 1) ? (offsets[bag + 1] - TH_INDEX_BASE) : numIndices;
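      // Example (with 0-based indexing): offsets = {0, 3, 5} and
      // numIndices == 8 give bags spanning [0,3), [3,5), and [5,8).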
      assert(end >= begin);
      Acctype weightFeatSum = ScalarConvert<float, Acctype>::to(0);
      int64_t bag_size_ = 0;
      for (int64_t emb = begin; emb < end; emb++) {
        // 64-bit arithmetic so large embedding tables cannot overflow the row offset
        const int64_t weightRow = (input[emb] - TH_INDEX_BASE) * stride;
        weightFeatSum += ScalarConvert<Dtype, Acctype>::to(weightFeat[weightRow]);
        bag_size_++;
        if (featureDim == 0) {
          offset2bag[emb] = bag + TH_INDEX_BASE;
        }
      }
      if (mode == MODE_MEAN) {
        weightFeatSum = weightFeatSum / ScalarConvert<int64_t, Acctype>::to(bag_size_);
        bag_size[bag] = bag_size_;
      }
      (void) MODE_SUM; // silence warnings about unused MODE_SUM
      output[bag * stride + featureDim] = ScalarConvert<Acctype, Dtype>::to(weightFeatSum);
    }
  }
}
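// A minimal host-side launch sketch for the kernel above (illustrative only;
// the real launch lives in the generic file included at the bottom). The
// block/grid sizes and the float-only instantiation are assumptions modeled
// on similar THCUNN kernels, not the actual wrapper.
inline void LookupTableBag_updateOutput_launchSketch(
    int64_t *input, int64_t *offsets, float *weight, float *output,
    int64_t *offset2bag, int64_t numIndices, int64_t numBags, int64_t stride,
    int mode, int64_t *bag_size, cudaStream_t stream) {
  dim3 block(32, 8);  // x: features within a chunk, y: chunks per block
  int grid = 1024;    // grid-stride loop in the kernel covers all chunks
  cunn_LookupTableBag_updateOutputKernel<float, float>
      <<<grid, block, 0, stream>>>(
          input, offsets, weight, output, offset2bag,
          numIndices, numBags, stride, mode, bag_size);
}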
// FIXME: removed the accGradParametersKernelByFeature case present in
// LookupTable. That kernel is faster at small sizes (< 768 indices), a regime
// that does not need LookupTableBag (LookupTable + Sum works fine), but it
// would still be nice not to be slow in that case.
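// Backward kernel: accumulates rows of gradOutput into gradWeight. `input` is
// expected to hold the indices sorted so that duplicates are adjacent; one
// warp then processes an entire run of equal indices, and `indices` maps each
// sorted element back to its original position (so offset2bag can recover the
// bag it came from). When non-null, `count` presumably holds the number of
// occurrences of each index, implementing scale-by-frequency.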
template <typename Dtype, typename Acctype>
__global__ void cunn_LookupTableBag_accGradParametersKernel(
  int64_t *input, int64_t *indices, Dtype *gradOutput, Dtype *gradWeight, int64_t *offset2bag,
  int64_t *count, Dtype defaultScale, ptrdiff_t numel, int64_t stride,
  int mode, int64_t *bag_size) {
  // The literal 4 matches the blockDim.y the host launcher is expected to use.
  int idx = blockIdx.x * 4 + threadIdx.y;
  // Each warp is responsible for an input into the LookupTable.
  // If the preceding input has the same value as this input, then the warp
  // exits immediately. The warp also processes subsequent inputs with the
  // same value.
  //
  // Input Warp
  // 1 <warp 1>
  // 1 <warp 1> (<warp 2> exits without doing any work)
  // 5 <warp 3>
  // 8 <warp 4>
  // Number of values processed by each thread (grain size)
  const int SZ = 4;
  if (idx < numel
      && (idx == 0 || input[idx] != input[idx - 1])) {
    do {
      const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
      // 64-bit row offsets so large tables / many bags cannot overflow an int
      const int64_t weightRow = (input[idx] - TH_INDEX_BASE) * stride;
      // Note: only this line changes from LookupTable_accGradParametersKernel
      const int origRow = ((int) indices[idx] - TH_INDEX_BASE);
      const int seq_number = offset2bag[origRow] - TH_INDEX_BASE;
      const int64_t gradOutputRow = ((int64_t) seq_number) * stride;
      const Acctype scale = count
          ? ScalarConvert<Dtype, Acctype>::to(defaultScale) / count[idx]
          : ScalarConvert<Dtype, Acctype>::to(defaultScale);
      Acctype gradient[SZ];
      Acctype weight[SZ];
      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        int featureDim = startFeature + ii * WARP_SIZE;
        if (featureDim < stride)
        {
          gradient[ii] = ScalarConvert<Dtype, Acctype>::to(gradOutput[gradOutputRow + featureDim]);
          if (mode == MODE_MEAN) {
            gradient[ii] /= bag_size[seq_number];
          }
          weight[ii] = ScalarConvert<Dtype, Acctype>::to(gradWeight[weightRow + featureDim]);
        }
      }
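      // Lanes with featureDim >= stride never initialized gradient[ii] and
      // weight[ii]; the multiply-add below may read garbage for them, but the
      // guarded store in the final loop never writes those lanes back.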
      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        weight[ii] += gradient[ii] * scale;
      }
      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        int featureDim = startFeature + ii * WARP_SIZE;
        if (featureDim < stride)
        {
          gradWeight[weightRow + featureDim] = ScalarConvert<Acctype, Dtype>::to(weight[ii]);
        }
      }
      idx++;
    } while (idx < numel && input[idx] == input[idx - 1]);
  }
}
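// The generic implementation below supplies the host-side wrappers that
// launch these kernels; THCGenerateFloatTypes.h instantiates them for
// float, double, and half.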
#include <THCUNN/generic/LookupTableBag.cu>
#include <THC/THCGenerateFloatTypes.h>