-
Notifications
You must be signed in to change notification settings - Fork 1
/
svrtset.py
executable file
·156 lines (121 loc) · 5.43 KB
/
svrtset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# svrt is the ``Synthetic Visual Reasoning Test'', an image
# generator for evaluating classification performance of machine
# learning systems, humans and primates.
#
# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/
# Written by Francois Fleuret <[email protected]>
#
# This file is part of svrt.
#
# svrt is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# svrt is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with svrt. If not, see <http://www.gnu.org/licenses/>.
import torch
from math import sqrt
from torch import multiprocessing
from torch import Tensor
from torch.autograd import Variable
import svrt
# FIXME
import resource
######################################################################
def generate_one_batch(s):
    """Generate one batch of SVRT vignettes with random binary labels.

    Args:
        s: a triple ``(problem_number, batch_size, random_seed)``.  Packed
           into a single argument so the function can be used with
           ``multiprocessing.Pool.map`` (see the commented-out pool code
           in ``VignetteSet``).

    Returns:
        A ``[input, target]`` pair: ``input`` is a float tensor of shape
        ``(batch_size, 1, H, W)``, ``target`` a ``LongTensor`` of 0/1
        labels drawn with probability 0.5 each.
    """
    problem_number, batch_size, random_seed = s
    # int() guards against random_seed arriving as a 0-dim tensor
    # (e.g. an element indexed out of a LongTensor of seeds).
    svrt.seed(int(random_seed))
    target = torch.LongTensor(batch_size).bernoulli_(0.5)
    vignettes = svrt.generate_vignettes(problem_number, target)
    # Insert a singleton channel dimension: (N, H, W) -> (N, 1, H, W).
    vignettes = vignettes.float().view(vignettes.size(0), 1, vignettes.size(1), vignettes.size(2))
    return [ vignettes, target ]
class VignetteSet:
    """In-memory set of SVRT vignette batches.

    All batches are generated up front, globally standardized to zero
    mean and unit variance, and optionally moved to the GPU.
    """

    def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None):
        """Generate and normalize nb_samples vignettes for the given problem.

        Args:
            problem_number: which SVRT problem to generate.
            nb_samples: total number of vignettes; must be a multiple of
                batch_size.
            batch_size: number of vignettes per batch.
            cuda: if True, move the normalized batches to the GPU.
            logger: optional callable ``logger(total, done)`` invoked
                after each batch is generated, for progress reporting.

        Raises:
            ValueError: if nb_samples is not a multiple of batch_size.
        """
        if nb_samples % batch_size > 0:
            # A bare ``raise`` here would itself fail with
            # "No active exception to re-raise"; raise explicitly.
            raise ValueError('nb_samples must be a multiple of batch_size')

        self.cuda = cuda
        self.problem_number = problem_number
        self.batch_size = batch_size
        self.nb_samples = nb_samples
        self.nb_batches = self.nb_samples // self.batch_size

        # One independent seed per batch keeps generation reproducible
        # per-batch (and would allow parallel generation).
        seeds = torch.LongTensor(self.nb_batches).random_()
        mp_args = [ [ problem_number, batch_size, seeds[b] ] for b in range(self.nb_batches) ]

        self.data = []
        for b in range(self.nb_batches):
            self.data.append(generate_one_batch(mp_args[b]))
            if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)

        # Weird thing going on with the multi-processing, waiting for more info
        # pool = multiprocessing.Pool(multiprocessing.cpu_count())
        # self.data = pool.map(generate_one_batch, mp_args)

        # Global standardization: accumulate per-batch means of x and x^2,
        # then derive mean/std as E[x^2] - E[x]^2 over the whole set.
        acc = 0.0
        acc_sq = 0.0
        for b in range(self.nb_batches):
            batch_input = self.data[b][0]
            acc += batch_input.sum() / batch_input.numel()
            acc_sq += batch_input.pow(2).sum() / batch_input.numel()

        mean = acc / self.nb_batches
        std = sqrt(acc_sq / self.nb_batches - mean * mean)

        for b in range(self.nb_batches):
            # In-place normalization to avoid allocating a second copy.
            self.data[b][0].sub_(mean).div_(std)
            if cuda:
                self.data[b][0] = self.data[b][0].cuda()
                self.data[b][1] = self.data[b][1].cuda()

    def get_batch(self, b):
        """Return the ``[input, target]`` pair for batch index ``b``."""
        return self.data[b]
######################################################################
class CompressedVignetteSet:
    """Memory-frugal variant of ``VignetteSet``.

    Vignettes are generated once, their global mean/std recorded, and
    each batch stored as a compressed byte storage.  ``get_batch``
    decompresses and normalizes a batch on demand.
    """

    def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None):
        """Generate, measure and compress nb_samples vignettes.

        Args:
            problem_number: which SVRT problem to generate.
            nb_samples: total number of vignettes; must be a multiple of
                batch_size.
            batch_size: number of vignettes per batch.
            cuda: if True, ``get_batch`` moves decompressed batches to the GPU.
            logger: optional callable ``logger(total, done)`` for progress.

        Raises:
            ValueError: if nb_samples is not a multiple of batch_size.
            RuntimeError: if resident memory grows past a sanity limit
                during generation (suspected leak).
        """
        if nb_samples % batch_size > 0:
            # Was a print + bare ``raise``, which itself fails with
            # "No active exception to re-raise"; raise explicitly.
            raise ValueError('nb_samples must be a multiple of batch_size')

        self.cuda = cuda
        self.problem_number = problem_number
        self.batch_size = batch_size
        self.nb_samples = nb_samples
        self.nb_batches = self.nb_samples // self.batch_size

        self.targets = []
        self.input_storages = []

        # Running sums of per-batch E[x] and E[x^2], for global mean/std.
        acc = 0.0
        acc_sq = 0.0
        usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

        for b in range(self.nb_batches):
            target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
            vignettes = svrt.generate_vignettes(problem_number, target)

            # FIXME the reused float buffer should not be necessary but
            # there are weird memory leaks going on, which do not seem to
            # be my fault: allocate the float copy once and refill it
            # in-place on later iterations.
            if b == 0:
                vignettes_as_float = vignettes.float()
            else:
                vignettes_as_float.copy_(vignettes)

            acc += vignettes_as_float.sum() / vignettes.numel()
            acc_sq += vignettes_as_float.pow(2).sum() / vignettes.numel()

            self.targets.append(target)
            self.input_storages.append(svrt.compress(vignettes.storage()))
            if logger is not None: logger(self.nb_batches * self.batch_size, b * self.batch_size)

            # FIXME sanity check against the suspected leak: ru_maxrss is
            # in kilobytes on Linux, so 16e6 is roughly 16 GB.
            if resource.getrusage(resource.RUSAGE_SELF).ru_maxrss > 16e6:
                raise RuntimeError('Memory leak?!')

        # Report memory used by the compressed set (ru_maxrss is in KB).
        mem = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - usage) * 1024
        print('Using {:.02f}Gb total {:.02f}b / samples'
              .format(mem / (1024 * 1024 * 1024), mem / self.nb_samples))

        self.mean = acc / self.nb_batches
        self.std = sqrt(acc_sq / self.nb_batches - self.mean * self.mean)

    def get_batch(self, b):
        """Decompress, normalize and return batch ``b`` as ``(input, target)``."""
        batch_input = torch.ByteTensor(svrt.uncompress(self.input_storages[b])).float()
        # NOTE(review): 128x128 is assumed to be the fixed vignette size
        # produced by svrt.generate_vignettes — confirm against svrt.
        batch_input = batch_input.view(self.batch_size, 1, 128, 128).sub_(self.mean).div_(self.std)
        target = self.targets[b]
        if self.cuda:
            batch_input = batch_input.cuda()
            target = target.cuda()
        return batch_input, target
######################################################################