layers.py
import math
import torch
from torch import nn
from torch import exp, log


class _GumbelLayer(nn.Module):
    '''
    Base class for all Gumbel-stochastic layers.
    '''
    def __init__(self, inner, N, r, device, norm):
        '''
        inner:
            pytorch module instance used as the deterministic inner layer
        N:
            Gumbel-softmax annealing interval: number of step() calls
            between temperature drops (see _tau and the paper)
        r:
            Gumbel-softmax annealing rate (see _tau and the paper)
        device:
            cpu or gpu. Needed for the Gumbel sample; would be nice to
            make obsolete
        norm:
            batch norm object, or None for no normalization
        '''
        super(_GumbelLayer, self).__init__()
        self.inner = inner
        self.device = device
        self.norm = norm if norm else nn.Identity()
        # temperature annealing parameters
        self.N = N
        self.r = r
        self.time_step = 0

    def step(self):
        # advance the annealing clock; call once per training step
        self.time_step += 1

    def _tau(self):
        # staircase schedule: tau = max(0.5, exp(-r * floor(t / N))), so the
        # temperature decays every N steps and is floored at 0.5
        return max(.5, math.exp(-self.r*math.floor(self.time_step/self.N)))

    def forward(self, x):
        l = self.inner(x)
        l = self.norm(l)
        # Cast p to double so that the gumbel_softmax computation is stable,
        # and clamp away from {0, 1} so log(p) and log(1-p) stay finite
        delta = 1e-5
        p = torch.clamp(torch.sigmoid(l).double(), min=delta, max=1-delta)
        o = self.sample(p)
        # Cast the output back to float for the next layer's input
        return o.float()

    def sample(self, p):
        if self.training:
            # sample from the relaxed Bernoulli (binary concrete) distribution
            return self._gumbel_softmax(p)
        else:
            # at eval time, draw hard {0, 1} Bernoulli samples instead
            return torch.bernoulli(p).to(self.device)

    def _gumbel_softmax(self, p):
        # two-class Gumbel softmax over the "on" probability p and the
        # "off" probability 1 - p, at the current annealed temperature
        temp = self._tau()
        y1 = exp((log(p) + self._sample_gumbel_dist(p.shape)) / temp)
        sum_all = y1 + exp((log(1-p) + self._sample_gumbel_dist(p.shape)) / temp)
        return y1 / sum_all

    def _sample_gumbel_dist(self, input_size):
        # Gumbel(0, 1) samples via inverse transform: -log(-log(U)) with
        # U ~ Uniform(0, 1); eps guards against log(0) when U draws exactly 0
        eps = 1e-20
        u = torch.rand(input_size, device=self.device)
        return -log(-log(u + eps) + eps)


class Linear(_GumbelLayer):
    def __init__(self, input_dim, output_dim, device, norm, N=500, r=1e-5):
        inner = nn.Linear(input_dim, output_dim, bias=False)
        norm_obj = nn.BatchNorm1d(output_dim) if norm else None
        super(Linear, self).__init__(inner, N, r, device, norm_obj)


class Conv2d(_GumbelLayer):
    def __init__(self, inc, outc, kernel, device, norm, N=500, r=1e-5, **kwargs):
        inner = nn.Conv2d(inc, outc, kernel, **kwargs)
        norm_obj = nn.BatchNorm2d(outc) if norm else None
        super(Conv2d, self).__init__(inner, N, r, device, norm_obj)
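

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). It builds a
# two-layer Gumbel-stochastic MLP and runs one training step. The dimensions,
# loss, optimizer, and dummy data below are assumptions for demonstration
# only; the one piece of real API it exercises is that step() must be called
# manually after each optimizer update to advance the temperature schedule.
if __name__ == '__main__':
    device = torch.device('cpu')
    layer1 = Linear(784, 256, device, norm=True)    # wraps BatchNorm1d(256)
    layer2 = Linear(256, 10, device, norm=False)
    model = nn.Sequential(layer1, layer2).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.randn(32, 784)                         # dummy input batch
    target = torch.randint(0, 2, (32, 10)).float()   # dummy binary targets

    model.train()
    out = model(x)            # relaxed Bernoulli samples in (0, 1)
    loss = nn.functional.binary_cross_entropy(out, target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    layer1.step()             # advance each layer's annealing schedule
    layer2.step()

    model.eval()              # eval mode draws hard {0, 1} samples instead
    with torch.no_grad():
        hard_out = model(x)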