"""Implementation of a Restricted Boltzmann Machine"""
import torch
from utils import *
class RBM:
    """Implementation of a Restricted Boltzmann Machine.

    Note that this implementation does not use PyTorch's ``nn.Module``
    because we update the weights ourselves.
    """

    def __init__(self, visible_dim, hidden_dim, gaussian_hidden_distribution=False):
        """Initialize a Restricted Boltzmann Machine.

        Parameters
        ----------
        visible_dim: int
            number of dimensions in the visible (input) layer
        hidden_dim: int
            number of dimensions in the hidden layer
        gaussian_hidden_distribution: bool
            whether to use a Gaussian distribution for the values of the
            hidden units instead of a Bernoulli distribution
        """
        self.visible_dim = visible_dim
        self.hidden_dim = hidden_dim
        self.gaussian_hidden_distribution = gaussian_hidden_distribution

        # initialize parameters
        self.W = torch.randn(visible_dim, hidden_dim).to(DEVICE) * 0.1
        self.h_bias = torch.zeros(hidden_dim).to(DEVICE)  # v --> h
        self.v_bias = torch.zeros(visible_dim).to(DEVICE)  # h --> v

        # buffers for learning with momentum
        self.W_momentum = torch.zeros(visible_dim, hidden_dim).to(DEVICE)
        self.h_bias_momentum = torch.zeros(hidden_dim).to(DEVICE)  # v --> h
        self.v_bias_momentum = torch.zeros(visible_dim).to(DEVICE)  # h --> v
    def sample_h(self, v):
        """Get sampled hidden values and activation probabilities.

        Parameters
        ----------
        v: Tensor
            tensor of input from the visible layer
        """
        activation = torch.mm(v, self.W) + self.h_bias
        if self.gaussian_hidden_distribution:
            # Gaussian hidden units: mean is the activation, unit standard
            # deviation. torch.normal takes a Tensor mean with a scalar std;
            # the original integer-tensor std would raise a type error.
            return activation, torch.normal(activation, 1.0)
        else:
            p = torch.sigmoid(activation)
            return p, torch.bernoulli(p)
    def sample_v(self, h):
        """Get visible activation probabilities.

        Parameters
        ----------
        h: Tensor
            tensor of input from the hidden layer
        """
        activation = torch.mm(h, self.W.t()) + self.v_bias
        p = torch.sigmoid(activation)
        return p
    def update_weights(self, v0, vk, ph0, phk, lr, momentum_coef, weight_decay, batch_size):
        """Learning step: update parameters.

        Uses the contrastive divergence algorithm.

        Parameters
        ----------
        v0: Tensor
            initial visible state
        vk: Tensor
            final visible state after k Gibbs steps
        ph0: Tensor
            hidden activation probabilities for v0
        phk: Tensor
            hidden activation probabilities for vk
        lr: float
            learning rate
        momentum_coef: float
            coefficient to use for momentum
        weight_decay: float
            coefficient to use for weight decay
        batch_size: int
            size of each batch
        """
        # Accumulate the contrastive divergence gradient estimate,
        # <v h>_data - <v h>_model, into the momentum buffers.
        self.W_momentum *= momentum_coef
        self.W_momentum += torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)
        self.h_bias_momentum *= momentum_coef
        self.h_bias_momentum += torch.sum((ph0 - phk), 0)
        self.v_bias_momentum *= momentum_coef
        self.v_bias_momentum += torch.sum((v0 - vk), 0)

        # Parameter step, averaged over the batch.
        self.W += lr * self.W_momentum / batch_size
        self.h_bias += lr * self.h_bias_momentum / batch_size
        self.v_bias += lr * self.v_bias_momentum / batch_size

        self.W -= self.W * weight_decay  # L2 weight decay
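
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal CD-1 training
# loop on synthetic Bernoulli data. The shapes and hyperparameters below are
# illustrative assumptions, not values taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    rbm = RBM(visible_dim=784, hidden_dim=128)

    # Stand-in for a real dataset: 512 random binary vectors (assumption).
    data = torch.bernoulli(torch.rand(512, 784)).to(DEVICE)
    batch_size = 64

    for epoch in range(5):
        for i in range(0, data.shape[0], batch_size):
            v0 = data[i:i + batch_size]

            # Positive phase: hidden probabilities and samples for the data.
            ph0, h0 = rbm.sample_h(v0)

            # Negative phase: one Gibbs step (CD-1). sample_v returns
            # probabilities, which serve as the reconstruction vk here.
            vk = rbm.sample_v(h0)
            phk, _ = rbm.sample_h(vk)

            rbm.update_weights(v0, vk, ph0, phk, lr=0.1,
                               momentum_coef=0.5, weight_decay=2e-4,
                               batch_size=batch_size)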