forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGatedLinearUnit.cpp
75 lines (65 loc) · 2.8 KB
/
GatedLinearUnit.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/Activation.h>
#include <ATen/native/TensorIterator.h>
namespace at {
namespace native {
// Dispatch stub for the forward GLU kernel. glu_out() below invokes it with
// the device type of the TensorIterator it builds.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_stub);
// Dispatch stub for the fused backward GLU kernel. glu_backward_out() below
// invokes it with the device type of the TensorIterator it builds.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_backward_stub);
// Gated linear unit, out-variant.
//
// Splits `self` into two equal halves along `dim` and hands both halves to
// glu_stub through a binary TensorIterator whose output is `result`.
// (The kernel behind glu_stub is not defined in this file; presumably it
// computes first_half * sigmoid(second_half).)
//
// self:   input tensor; must have at least one dimension.
// dim:    dimension to split; passed through maybe_wrap_dim before use.
// result: output tensor; resized to the shape of `self` with the split
//         dimension halved, then written by the kernel.
// Returns `result`.
// Fails a TORCH_CHECK if `self` is 0-dimensional or if its size along `dim`
// is odd.
Tensor& glu_out(const Tensor& self, int64_t dim, Tensor &result) {
  // A 0-dimensional tensor reports "size" 1, which could never be halved
  // evenly, so it would be rejected further down anyway; checking here only
  // produces a clearer error message.
  TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors");

  auto split_dim = maybe_wrap_dim(dim, self.dim());
  const int64_t full_extent = self.size(split_dim);
  TORCH_CHECK(full_extent % 2 == 0, "Halving dimension must be even, but dimension ",
              split_dim, " is size ", full_extent);

  // The output keeps the input shape except along the split dimension,
  // where it is half as large.
  const int64_t half_extent = full_extent / 2;
  auto out_shape = self.sizes().vec();
  out_shape[split_dim] = half_extent;
  result.resize_(out_shape);

  // The two halves of the input along the split dimension.
  Tensor lower_half = self.narrow(split_dim, 0, half_extent);
  Tensor upper_half = self.narrow(split_dim, half_extent, half_extent);

  auto iter = TensorIterator::binary_op(result, lower_half, upper_half);
  glu_stub(iter.device_type(), iter);
  return result;
}
// Gated linear unit, out-of-place variant.
//
// Allocates an empty (size {0}) tensor with the same options as `self` and
// passes it to at::glu_out as the output argument, returning whatever that
// call returns.
Tensor glu(const Tensor& self, int64_t dim) {
  Tensor output = at::empty({0}, self.options());
  return at::glu_out(output, self, dim);
}
// Backward of the gated linear unit, out-variant.
//
// grad_output: gradient flowing in from the GLU output.
// input:       the tensor that was given to the forward pass; must have at
//              least one dimension.
// dim:         dimension that was split; passed through maybe_wrap_dim.
// grad_input:  output tensor; resized like `input` and filled in place.
// Returns `grad_input`.
// Fails a TORCH_CHECK if `input` is 0-dimensional or if its size along `dim`
// is odd.
//
// The three steps below must run in this exact order, because the first half
// of `grad_input` is used as scratch space for sigmoid(second half of input)
// before it receives its final value.
Tensor& glu_backward_out(const Tensor& grad_output, const Tensor& input, int64_t dim, Tensor& grad_input) {
  TORCH_CHECK(input.dim() > 0, "glu does not support 0-dimensional tensors");

  auto split_dim = maybe_wrap_dim(dim, input.dim());
  const int64_t full_extent = input.size(split_dim);
  TORCH_CHECK(full_extent % 2 == 0, "Halving dimension must be even, but dimension ",
              split_dim, " is size ", full_extent);

  grad_input.resize_as_(input);
  const int64_t half_extent = full_extent / 2;

  // Halves of the input and the matching halves of the gradient buffer.
  // Writes through the narrowed gradient tensors are what populate
  // `grad_input`.
  Tensor input_lower = input.narrow(split_dim, 0, half_extent);
  Tensor input_upper = input.narrow(split_dim, half_extent, half_extent);
  Tensor grad_lower = grad_input.narrow(split_dim, 0, half_extent);
  Tensor grad_upper = grad_input.narrow(split_dim, half_extent, half_extent);

  // Step 1: stash sigmoid(input_upper) in the lower half of the gradient.
  at::sigmoid_out(grad_lower, input_upper);

  // Step 2: one fused kernel produces the upper half of the gradient from
  // the stashed sigmoid, the lower half of the input, and grad_output.
  // Fusing these is done for performance. (The kernel itself is not defined
  // in this file.)
  auto iter = at::TensorIteratorConfig()
      .add_output(grad_upper)
      .add_input(grad_lower)
      .add_input(input_lower)
      .add_input(grad_output)
      .build();
  glu_backward_stub(iter.device_type(), iter);

  // Step 3: finish the lower half: sigmoid(input_upper) * grad_output.
  grad_lower.mul_(grad_output);
  return grad_input;
}
// Backward of the gated linear unit, out-of-place variant.
//
// Allocates an empty (size {0}) tensor with the same options as `input` and
// passes it to at::glu_backward_out as the output argument, returning
// whatever that call returns.
Tensor glu_backward(const Tensor& grad_output, const Tensor& input, int64_t dim) {
  Tensor grad_buffer = at::empty({0}, input.options());
  return at::glu_backward_out(grad_buffer, grad_output, input, dim);
}
} // at::native
} // at