forked from Moodstocks/gtsrb
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.lua
72 lines (62 loc) · 2.42 KB
/
dataset.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
----------------------------------------------------------------------
-- Use this script to define the training and validation datasets.
--
-- A dataset is an object which implements the operator dataset[index]
-- and the method dataset:size(). The size() method returns the number
-- of examples, and dataset[i] must return the i-th example.
--
-- An example must be an object that implements the operator
-- example[field], where field is either 1 (the input features)
-- or 2 (the corresponding label, which is passed to the criterion).
--
-- Hugo Duthil
----------------------------------------------------------------------
-- Each dataset is a table of 2-field elements:
-- + [1] : a 3D tensor {channel, x, y} representing a 32x32 YUV image
-- + [2] : image label
require 'torch'
require 'paths'
-- Resolve the dataset file paths. Every path in `params` is interpreted
-- relative to the directory that contains this script.
-- `params` is not defined in this file; it must be a global set by the
-- caller before this script runs, so fail early with a clear message.
assert(params, "dataset.lua: global 'params' must be defined before loading this script")
local script_dir = paths.dirname(paths.thisfile()).."/"

-- NOTE(review): the names below are intentionally left global (as in the
-- original) because other scripts may read them; confirm before localizing.
train_file = script_dir..params.train_set       -- path to the (non-preprocessed) training set
test_file = script_dir..params.test_set         -- path to the (non-preprocessed) test set
pp_train_file = script_dir..params.pp_train_set -- path to the preprocessed training set
pp_test_file = script_dir..params.pp_test_set   -- path to the preprocessed test set
use_pp_sets = params.use_pp_sets                -- load already preprocessed sets

-- Set the default type of Tensor to float
torch.setdefaulttensortype('torch.FloatTensor')

-- Choose which pair of files the loaders below read: the preprocessed
-- sets when params.use_pp_sets is truthy, otherwise the original sets.
if use_pp_sets then
  tr_file, ts_file = pp_train_file, pp_test_file
else
  tr_file, ts_file = train_file, test_file
end
-- Load the training set from `tr_file` into the global `train_set`.
-- Loading is skipped when `train_set` is already non-nil, presumably because
-- this script was run earlier in the same Lua session.
-- NOTE(review): a cached `train_set` is reused even if `tr_file` now points
-- to a different file (e.g. `use_pp_sets` was toggled); confirm this is intended.
if paths.filep(tr_file) then
  if not train_set then
    print("\nLoading training set")
    train_set = torch.load(tr_file)
    -- Per the header comment, the loaded set is a table of 2-field examples,
    -- so `#` gives the example count. Use `self` rather than the global name
    -- so size() keeps working on this object even if the global `train_set`
    -- is later rebound or cleared.
    function train_set:size() return #self end
    print(tr_file)
    print("Training set loaded")
  else
    print("\nTraining set already loaded")
  end
else
  print("\nNo training set found")
end
-- Load the test set from `ts_file` into the global `test_set`.
-- Loading is skipped when `test_set` is already non-nil, presumably because
-- this script was run earlier in the same Lua session.
-- NOTE(review): a cached `test_set` is reused even if `ts_file` now points
-- to a different file (e.g. `use_pp_sets` was toggled); confirm this is intended.
if paths.filep(ts_file) then
  if not test_set then
    print("\nLoading test set")
    test_set = torch.load(ts_file)
    -- Per the header comment, the loaded set is a table of 2-field examples,
    -- so `#` gives the example count. Use `self` rather than the global name
    -- so size() keeps working on this object even if the global `test_set`
    -- is later rebound or cleared.
    function test_set:size() return #self end
    print(ts_file)
    print("Test set loaded")
  else
    print("\nTest set already loaded")
  end
else
  print("\nNo test set found")
end