CudaAdapter.lua
65 lines (56 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
require 'nn'
require 'strict'
--------------------------------------- CudaAdapter ---------------------------------------
-- Enables running CPU-only modules inside CUDA models, and vice versa.
--
-- CudaAdapter keeps its (public) output and gradInput tensors always in the model's type.
-- If the adaptee's tensors are of a different type, a conversion is performed. Note that
-- this conversion is of course expensive and causes a major delay (although it is
-- implemented efficiently).
local CudaAdapter, parent = torch.class('myrock.CudaAdapter', 'nn.Container')

--- Wrap a single module so it can run inside a model of a different tensor type.
-- @param module the adaptee (required); it is never type-converted by this wrapper
function CudaAdapter:__init(module)
   assert(module ~= nil)
   parent.__init(self)
   -- parent.__init created an empty self.modules; the adaptee becomes its sole entry
   table.insert(self.modules, module)
   -- staging buffers that hold input/gradOutput converted to the adaptee's type
   self.inputF = torch.Tensor()
   self.gradOutputF = torch.Tensor()
end
--- Forward pass. If the input's type differs from the adaptee's, the input is first
-- copied into self.inputF (adaptee's type) and the adaptee's output is copied back
-- into self.output (model's type); otherwise the adaptee's output is used directly.
function CudaAdapter:updateOutput(input)
   local inner = self.modules[1]
   if input:type() ~= inner.output:type() then
      -- type mismatch: stage a converted copy of the input, run, convert the result back
      self.inputF = self.inputF:typeAs(inner.output)
      self:convert(self.inputF, input)
      self:convert(self.output, inner:updateOutput(self.inputF))
   else
      self.output = inner:updateOutput(input)
   end
   return self.output
end
--- Backward pass. Mirrors updateOutput: gradOutput is staged in self.gradOutputF when
-- types differ and the adaptee's gradInput is converted back into self.gradInput.
-- NOTE: the mismatched branch reuses self.inputF as filled by the preceding
-- updateOutput call, so updateOutput must have been called first (standard nn contract).
function CudaAdapter:updateGradInput(input, gradOutput)
   local inner = self.modules[1]
   if input:type() ~= inner.output:type() then
      self.gradOutputF = self.gradOutputF:typeAs(inner.output)
      self:convert(self.gradOutputF, gradOutput)
      self:convert(self.gradInput, inner:updateGradInput(self.inputF, self.gradOutputF))
   else
      self.gradInput = inner:updateGradInput(input, gradOutput)
   end
   return self.gradInput
end
--- Copy x into out, resizing as needed (the copy performs any type conversion).
-- Handles a single tensor or a flat table of tensors; for tables, only keys already
-- present in out are copied — keys missing from out are silently skipped.
-- Raises an error for any other type of x.
function CudaAdapter:convert(out, x)
   assert(out ~= nil and x ~= nil)
   if torch.isTensor(x) then
      out:resize(x:size()):copy(x)
      return
   end
   if torch.type(x) == 'table' then
      for key, src in pairs(x) do
         local dst = out[key]
         if dst ~= nil then
            dst:resize(src:size()):copy(src)
         end
      end
      return
   end
   error('CudaAdapter: unknown type ' .. torch.type(x))
end
--- Change the type of the wrapper's public tensors only.
-- Deliberately does NOT forward the conversion to the wrapped module — that is the
-- whole point of the adapter: the adaptee keeps its own tensor type.
function CudaAdapter:type(type, tensorCache)
   for _, field in ipairs({'output', 'gradInput'}) do
      self[field] = torch.Tensor():type(type, tensorCache)
   end
   return self --the internal module has been spared of type conversion
end
--- Human-readable representation: the class name wrapping the adaptee's own tostring.
function CudaAdapter:__tostring__()
   return string.format('myrock.CudaAdapter {%s}', tostring(self.modules[1]))
end