-
Notifications
You must be signed in to change notification settings - Fork 0
/
samplewithpredictions.lua
43 lines (35 loc) · 976 Bytes
/
samplewithpredictions.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
require 'torch'
require 'nn'
require 'LanguageModel'
-- Load a trained LanguageModel checkpoint, optionally move it to a GPU
-- (CUDA or OpenCL), draw one sample from it and print that sample to stdout.
--
-- The whole parsed option table is passed to model:sample(); presumably
-- LanguageModel:sample reads -length, -start_text, -sample and -temperature
-- from it (NOTE(review): confirm against LanguageModel.lua).
local cmd = torch.CmdLine()
cmd:option('-checkpoint', 'cv/checkpoint_4000.t7', 'checkpoint file to load (must contain a "model" field)')
cmd:option('-length', 2000, 'sample length (forwarded to model:sample)')
cmd:option('-start_text', '', 'text to prime the model with (forwarded to model:sample)')
cmd:option('-sample', 1, 'sampling flag (forwarded to model:sample)')
cmd:option('-temperature', 1, 'sampling temperature (forwarded to model:sample)')
cmd:option('-gpu', 0, 'zero-based GPU index; negative runs on the CPU')
cmd:option('-gpu_backend', 'cuda', 'GPU backend: cuda or opencl')
cmd:option('-verbose', 0, 'set to 1 to print the model and the device in use')
local opt = cmd:parse(arg)

local checkpoint = torch.load(opt.checkpoint)
local model = checkpoint.model
-- Fail with a clear message instead of an opaque "attempt to index nil"
-- on the first method call below.
assert(model ~= nil,
  string.format('checkpoint "%s" has no "model" field', opt.checkpoint))

-- Stdout carries the sampled text, so the module dump is diagnostic output
-- and is only shown on request.
if opt.verbose == 1 then print(model) end

local msg
if opt.gpu >= 0 and opt.gpu_backend == 'cuda' then
  require 'cutorch'
  require 'cunn'
  -- -gpu is zero-based on the command line; cutorch device ids are 1-based.
  cutorch.setDevice(opt.gpu + 1)
  model:cuda()
  msg = string.format('Running with CUDA on GPU %d', opt.gpu)
elseif opt.gpu >= 0 and opt.gpu_backend == 'opencl' then
  require 'cltorch'
  require 'clnn'
  -- Select the requested device, as the CUDA path does, so that -gpu is
  -- honoured and the message below is truthful (cltorch ids are 1-based).
  cltorch.setDevice(opt.gpu + 1)
  model:cl()
  msg = string.format('Running with OpenCL on GPU %d', opt.gpu)
else
  msg = 'Running in CPU mode'
end
if opt.verbose == 1 then print(msg) end

-- Switch modules such as dropout to inference behaviour before sampling.
model:evaluate()
local sample = model:sample(opt)
print(sample)