forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Dropout.cpp
118 lines (96 loc) · 3.81 KB
/
Dropout.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
namespace at { namespace native {
namespace {
// Allocates the noise tensor used by the feature-dropout variants.
//
// The result keeps the first two extents of `input` and uses extent 1 for
// every remaining dimension, so one noise value is shared across all trailing
// positions when it is later multiplied into `input`.
// The returned tensor comes from at::empty with `input`'s options; the caller
// is expected to fill it (see the bernoulli_ call in _dropout_impl).
//
// Raises (via AT_CHECK) when `input` has fewer than 2 dimensions.
Tensor make_feature_noise(const Tensor& input) {
  const auto dims = input.sizes();
  AT_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input");
  // Begin with an all-ones shape of the right rank, then restore the two
  // leading extents from the input.
  std::vector<int64_t> noise_shape(input.dim(), 1);
  noise_shape[0] = dims[0];
  noise_shape[1] = dims[1];
  return at::empty(noise_shape, input.options());
}
bool is_fused_kernel_acceptable(const Tensor& input, double p) {
return input.is_cuda() && p > 0 && p < 1;
}
// In-place flavour of multiply: scales `input` by `noise` with mul_ and hands
// back the reference that mul_ returns.
//
// NB: plain (non-template) overloads would also work, but then the choice
// between the in-place and out-of-place versions would hinge solely on the
// constness of the `input` reference. The explicit `inplace` template argument
// plus the static_assert turns a mismatch into a compile error.
template <bool inplace>
Tensor& multiply(Tensor& input, const Tensor& noise) {
  static_assert(inplace, "Wrong multiply overload triggered in Dropout.cpp");
  Tensor& scaled = input.mul_(noise);
  return scaled;
}
// Out-of-place flavour of multiply: returns input.mul(noise) as a new Tensor
// value. The static_assert rejects instantiation with inplace == true, so
// this overload can only be used where an out-of-place result is requested.
template <bool inplace>
Tensor multiply(const Tensor& input, const Tensor& noise) {
  static_assert(!inplace, "Wrong multiply overload triggered in Dropout.cpp");
  Tensor product = input.mul(noise);
  return product;
}
// Shared implementation behind every dropout variant in this file.
//
// Template parameters:
//   feature_dropout - draw the noise with make_feature_noise (shape
//                     [size0, size1, 1, ...]) instead of one value per element.
//   alpha_dropout   - apply the affine "alpha" correction (scale `a`, shift
//                     `b`) instead of the plain 1 / (1 - p) rescale.
//   inplace         - modify `input` and return a reference to it; otherwise
//                     return a Tensor value.
//   T               - Tensor or const Tensor, deduced from the caller.
//
// Arguments:
//   input - tensor to apply dropout to.
//   p     - drop probability; must lie in [0, 1] (checked with AT_CHECK).
//   train - when false, `input` is returned untouched.
//
// Returns `input` itself when p == 0 or !train (no copy is made of the data),
// `input` multiplied by a zero scalar tensor when p == 1, and otherwise
// `input` multiplied by the sampled noise (plus the shift for alpha dropout).
template <bool feature_dropout, bool alpha_dropout, bool inplace, typename T>
typename std::conditional<inplace, Tensor&, Tensor>::type
_dropout_impl(T& input, double p, bool train) {
  AT_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p);

  // Nothing to drop: pass the input straight through.
  if (p == 0 || !train) {
    return input;
  }

  // Everything is dropped: multiply by a zero scalar rather than dividing by
  // (1 - p) == 0 further down.
  if (p == 1) {
    return multiply<inplace>(input, at::zeros({}, input.options()));
  }

  const double keep_prob = 1 - p;

  // Bernoulli keep-mask: 1 with probability keep_prob, 0 otherwise.
  Tensor mask = feature_dropout ? make_feature_noise(input) : at::empty_like(input);
  mask.bernoulli_(keep_prob);

  if (!alpha_dropout) {
    // Standard dropout: rescale the kept entries by 1 / keep_prob.
    mask.div_(keep_prob);
    return multiply<inplace>(input, mask);
  }

  // Alpha dropout: output = input * (a * mask) + shift, where
  //   shift = alpha * a * ((mask - 1) + p).
  // NOTE(review): `alpha` is presumably the SELU negative saturation
  // magnitude -- confirm against the reference formulation.
  constexpr double alpha = 1.7580993408473766;
  const double a = 1. / std::sqrt((alpha * alpha * p + 1) * keep_prob);
  // The shift must be derived from the raw 0/1 mask, i.e. before the mask is
  // scaled by `a` below.
  Tensor shift = mask.add(-1).mul_(alpha * a).add_(alpha * a * p);
  mask.mul_(a);
  return multiply<inplace>(input, mask).add_(shift);
}
// Stamps out a forwarding function template called ALIAS_NAME that pins the
// first two template arguments of _dropout_impl (feature / alpha mode) and
// leaves only `inplace` for the caller to choose, e.g. `_dropout<true>(...)`.
// Arguments are perfectly forwarded; the return type is Tensor& for the
// in-place form and Tensor otherwise.
// (Keep comments outside the macro body: a `//` comment on a continued line
// would swallow the lines spliced after it.)
#define ALIAS_SPECIALIZATION(ALIAS_NAME, FEATURE_MODE, ALPHA_MODE) \
  template <bool inplace, typename... ArgTypes> \
  typename std::conditional<inplace, Tensor&, Tensor>::type \
  ALIAS_NAME(ArgTypes&&... fwd_args) { \
    return _dropout_impl<FEATURE_MODE, ALPHA_MODE, inplace>( \
        std::forward<ArgTypes>(fwd_args)...); \
  }

//                   alias name              feature alpha
ALIAS_SPECIALIZATION(_dropout,               false,  false)
ALIAS_SPECIALIZATION(_feature_dropout,       true,   false)
ALIAS_SPECIALIZATION(_alpha_dropout,         false,  true)
ALIAS_SPECIALIZATION(_feature_alpha_dropout, true,   true)
} // anonymous namespace
// Out-of-place dropout.
//
// When `train` is set and is_fused_kernel_acceptable(input, p) holds, the
// result is the first element of the tuple returned by at::_fused_dropout,
// which is called with the keep probability (1 - p). In every other case the
// generic out-of-place implementation is used.
Tensor dropout(const Tensor& input, double p, bool train) {
  // `train` is tested first so the eligibility check is skipped outside of
  // training mode.
  const bool use_fused = train && is_fused_kernel_acceptable(input, p);
  if (!use_fused) {
    return _dropout<false>(input, p, train);
  }
  return std::get<0>(at::_fused_dropout(input, 1 - p));
}
// In-place dropout: forwards to the generic implementation with
// inplace == true and returns the reference it yields. Unlike dropout(), this
// entry point never dispatches to at::_fused_dropout.
Tensor& dropout_(Tensor& input, double p, bool train) {
  constexpr bool kInplace = true;
  return _dropout<kInplace>(input, p, train);
}
// Out-of-place feature dropout: forwards to the feature-noise specialization
// (noise shaped by make_feature_noise) with inplace == false.
Tensor feature_dropout(const Tensor& input, double p, bool train) {
  constexpr bool kInplace = false;
  return _feature_dropout<kInplace>(input, p, train);
}
// In-place feature dropout: forwards to the feature-noise specialization with
// inplace == true and returns the reference it yields.
Tensor& feature_dropout_(Tensor& input, double p, bool train) {
  constexpr bool kInplace = true;
  return _feature_dropout<kInplace>(input, p, train);
}
// Out-of-place alpha dropout: forwards to the alpha specialization (scale and
// shift correction, per-element noise) with inplace == false.
Tensor alpha_dropout(const Tensor& input, double p, bool train) {
  constexpr bool kInplace = false;
  return _alpha_dropout<kInplace>(input, p, train);
}
// In-place alpha dropout: forwards to the alpha specialization with
// inplace == true and returns the reference it yields.
Tensor& alpha_dropout_(Tensor& input, double p, bool train) {
  constexpr bool kInplace = true;
  return _alpha_dropout<kInplace>(input, p, train);
}
// Out-of-place feature alpha dropout: forwards to the specialization that
// combines feature-shaped noise with the alpha correction, inplace == false.
Tensor feature_alpha_dropout(const Tensor& input, double p, bool train) {
  constexpr bool kInplace = false;
  return _feature_alpha_dropout<kInplace>(input, p, train);
}
// In-place feature alpha dropout: forwards to the specialization that
// combines feature-shaped noise with the alpha correction, inplace == true,
// and returns the reference it yields.
Tensor& feature_alpha_dropout_(Tensor& input, double p, bool train) {
  constexpr bool kInplace = true;
  return _feature_alpha_dropout<kInplace>(input, p, train);
}
}} // namespace at::native