-
Notifications
You must be signed in to change notification settings - Fork 82
/
ReinforceGamma.lua
129 lines (109 loc) · 4.11 KB
/
ReinforceGamma.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
------------------------------------------------------------------------
--[[ ReinforceGamma ]]--
-- Ref A. http://incompleteideas.net/sutton/williams-92.pdf
-- Inputs are shape (k) and scale (theta) of multivariate Gamma distribution.
-- Outputs are samples drawn from these distributions.
-- Scale is provided as constructor argument.
-- Uses the REINFORCE algorithm (ref. A sec 6. p.237-239) which is
-- implemented through the nn.Module:reinforce(r,b) interface.
-- gradOutputs are ignored (REINFORCE algorithm).
------------------------------------------------------------------------
local ReinforceGamma, parent = torch.class("nn.ReinforceGamma", "nn.Reinforce")

--- Constructor.
-- @param scale      fixed scale (theta) of the Gamma distribution; when nil,
--                   the scale is expected as the second entry of a table input
-- @param stochastic forwarded to nn.Reinforce (sample during evaluation too)
function ReinforceGamma:__init(scale, stochastic)
   -- randomkit supplies the gamma sampler   : luarocks install randomkit
   -- cephes supplies digamma for gradients  : luarocks install cephes
   require('randomkit')
   require('cephes')
   parent.__init(self, stochastic)
   self.scale = scale
   if not self.scale then
      -- scale arrives as part of the input {shape, scale}, so a gradient
      -- tensor is needed for each of the two inputs
      self.gradInput = {torch.Tensor(), torch.Tensor()}
   end
end
--- Forward pass.
-- While training (or when stochastic), draws samples from
-- Gamma(shape, scale); otherwise returns the deterministic value
-- shape*scale elementwise.
-- @param input shape tensor, or table {shape, scale} when no constructor scale
-- @return self.output, same size as shape
function ReinforceGamma:updateOutput(input)
   local shape, scale = input, self.scale
   if torch.type(input) == 'table' then
      -- input is {shape, scale}
      assert(#input == 2)
      shape, scale = unpack(input)
   end
   assert(scale)

   self.output:resizeAs(shape)

   if torch.type(scale) == 'number' then
      -- broadcast the scalar scale to the shape tensor's size
      scale = shape.new():resizeAs(shape):fill(scale)
   elseif torch.isTensor(scale) then
      if scale:dim() == shape:dim() then
         assert(scale:isSameSizeAs(shape))
      else
         -- scale is missing the batch dim: view then expand it to match shape
         assert(scale:dim()+1 == shape:dim())
         self._scale = self._scale or scale.new()
         self._scale:view(scale,1,table.unpack(scale:size():totable()))
         self.__scale = self.__scale or scale.new()
         self.__scale:expandAs(self._scale, shape)
         scale = self.__scale
      end
   else
      -- BUG FIX: message previously said "shape", but this branch rejects
      -- an unsupported *scale* argument (shape is never type-checked here)
      error"unsupported scale type"
   end

   if self.stochastic or self.train ~= false then
      -- sample x ~ Gamma(shape, scale) via randomkit (CPU, float)
      self.output:copy(randomkit.gamma(shape:squeeze():float(),scale:squeeze():float()))
   else
      -- deterministic evaluation: shape*scale is the distribution *mean*
      -- NOTE(review): the original comment called this the MAP estimate,
      -- but the Gamma mode is (shape-1)*scale; the code computes the mean
      self.output:copy(shape):cmul(scale)
   end
   return self.output
end
--- Backward pass (REINFORCE estimator); gradOutput is ignored.
-- f     : Gamma probability density function
-- g     : digamma function (derivative of ln(Gamma))
-- x     : the sampled values (self.output)
-- shape : shape parameter (k) of the gamma dist
-- scale : scale parameter (theta) of the gamma dist
function ReinforceGamma:updateGradInput(input, gradOutput)
   local shape, scale = input, self.scale
   local gradShape, gradScale = self.gradInput, nil
   if torch.type(input) == 'table' then
      shape, scale = unpack(input)
      gradShape, gradScale = unpack(self.gradInput)
   end
   assert(scale)

   -- Derivative of log gamma w.r.t. shape :
   -- d ln(f(x,shape,scale))
   -- ---------------------- = ln(x) - g(shape) - ln(scale)
   --        d shape
   gradShape:resizeAs(shape)

   if torch.type(scale) == 'number' then
      scale = shape.new():resizeAs(shape):fill(scale)
   elseif scale:dim() ~= shape:dim() then
      -- BUG FIX: original read `if not scale:dim() == shape:dim() then`,
      -- which parses as `(not scale:dim()) == shape:dim()` and is always
      -- false, leaving this branch dead. Rebind the local to the expanded
      -- scale cached by updateOutput so the elementwise ops below operate
      -- on a tensor shaped like `shape` (avoids mutating the caller's
      -- scale tensor, which the original `scale:copy(...)` would have done).
      assert(self.__scale, "call updateOutput before updateGradInput")
      scale = self.__scale
   end

   gradShape:copy(cephes.digamma(shape:float()))
   gradShape:mul(-1)

   self._logOutput = self._logOutput or self.output.new()
   self._logOutput:log( self.output )
   self._logScale = self._logScale or scale.new()
   self._logScale:log( scale )

   gradShape:add( self._logOutput )
   gradShape:add(-1, self._logScale )

   -- multiply by variance reduced reward
   gradShape:cmul(self:rewardAs(shape) )
   -- multiply by -1 ( gradient descent on shape )
   gradShape:mul(-1)

   -- Derivative of log Gamma w.r.t. scale :
   -- d ln(f(x,shape,scale))      x       shape
   -- ---------------------- = ------- - -------
   --        d scale           scale^2    scale
   if gradScale then
      gradScale:resizeAs(scale)
      gradScale:copy( torch.cdiv(self.output, torch.pow(scale,2)) )
      gradScale:add(-1, torch.cdiv(shape, scale) )
      -- multiply by variance reduced reward, then negate for descent
      gradScale:cmul( self:rewardAs(scale) )
      gradScale:mul(-1)
   end

   return self.gradInput
end
--- Converts the module to a different tensor type.
-- Clears the cached log buffers so they are lazily recreated with the
-- new type on the next backward pass.
function ReinforceGamma:type(type, cache)
   self._logOutput, self._logScale = nil, nil
   return parent.type(self, type, cache)
end