-
Notifications
You must be signed in to change notification settings - Fork 9
/
raccoon_extensions.py
172 lines (133 loc) · 6.07 KB
/
raccoon_extensions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import theano
import theano.tensor as T
import numpy as np
from raccoon import Extension
from raccoon.extensions import Saver, ValidationMonitor
from data import char2int
from utilities import plot_seq_pt, plot_generated_sequences
floatX = theano.config.floatX
class Sampler(Extension):
    """Extension that draws unconditional samples and plots them.

    Whenever triggered, it calls the compiled prediction function with
    zeroed initial point and hidden-state matrices and writes the plotted
    result to ``folder_path``.
    """
    def __init__(self, name_extension, freq, folder_path, file_name,
                 fun_pred, n_hidden, apply_at_the_start=True,
                 apply_at_the_end=True, n_samples=4):
        super(Sampler, self).__init__(name_extension, freq,
                                      apply_at_the_end=apply_at_the_end,
                                      apply_at_the_start=apply_at_the_start)
        # Destination of the generated plots.
        self.folder_path = folder_path
        self.file_name = file_name
        # Compiled sampling function: (points, hidden) -> generated sequence.
        self.fun_pred = fun_pred
        self.n_samples = n_samples
        self.n_hidden = n_hidden

    def execute_virtual(self, batch_id, epoch_id=None):
        """Sample from zeroed initial states and plot the sequence."""
        zero_points = np.zeros((self.n_samples, 3), floatX)
        zero_hidden = np.zeros((self.n_samples, self.n_hidden), floatX)
        generated = self.fun_pred(zero_points, zero_hidden)

        # Prefix the output file with the batch id so successive
        # invocations do not overwrite each other.
        out_name = '{}_'.format(batch_id) + self.file_name
        plot_seq_pt(generated,
                    folder_path=self.folder_path,
                    file_name=out_name)
        return ['executed']
class SamplerCond(Extension):
    """Extension that generates sequences conditioned on text strings.

    Encodes ``sample_strings`` with ``dict_char2int``, runs the compiled
    sampling function from zeroed initial states, and plots the generated
    sequences together with the attention quantities (alpha, kapa, phi,
    omega).
    """
    def __init__(self, name_extension, freq, folder_path, file_name,
                 model, f_sampling, sample_strings, dict_char2int,
                 bias_value=0.5,
                 apply_at_the_start=True, apply_at_the_end=True):
        super(SamplerCond, self).__init__(name_extension, freq,
                                          apply_at_the_end=apply_at_the_end,
                                          apply_at_the_start=apply_at_the_start)
        self.folder_path = folder_path
        self.file_name = file_name
        # Each conditioning string gets a trailing space before encoding.
        self.sample_strings = [s + ' ' for s in sample_strings]
        self.dict_char2int = dict_char2int
        self.f_sampling = f_sampling
        self.bias_value = bias_value

        # Zeroed initial states, one row per conditioning string.
        n_samples = len(sample_strings)
        self.pt_ini_mat = np.zeros((n_samples, 3), floatX)
        self.h_ini_mat = np.zeros((n_samples, model.n_hidden), floatX)
        self.k_ini_mat = np.zeros((n_samples, model.n_mixt_attention), floatX)
        self.w_ini_mat = np.zeros((n_samples, model.n_chars), floatX)

    def execute_virtual(self, batch_id, epoch_id=None):
        """Sample conditioned sequences and plot them with attention maps."""
        cond, cond_mask = char2int(self.sample_strings, self.dict_char2int)
        pt_gen, a_gen, k_gen, p_gen, w_gen, mask_gen = self.f_sampling(
            self.pt_ini_mat, cond, cond_mask,
            self.h_ini_mat, self.k_ini_mat, self.w_ini_mat, self.bias_value)

        # Swap the last two axes of phi before plotting — presumably so the
        # character axis is laid out the way the plotting helper expects.
        p_gen = np.swapaxes(p_gen, 1, 2)
        attention_mats = [(a_gen, 'alpha'), (k_gen, 'kapa'), (p_gen, 'phi'),
                          (w_gen, 'omega')]
        plot_generated_sequences(
            pt_gen, attention_mats,
            mask_gen, folder_path=self.folder_path,
            file_name='{}_'.format(batch_id) + self.file_name)
        return ['executed']
class SamplingFunctionSaver(Saver):
    """Saver that pickles the model, its sampling function and the
    char-to-int mapping.

    NOTE(review): the "save only on best validation value" gating in
    ``condition`` is currently disabled (it returns True unconditionally),
    so the object is saved every time the extension is triggered.
    """
    def __init__(self, monitor, var, freq, folder_path, file_name,
                 model, f_sampling, dict_char2int, **kwargs):
        # Use super() for consistency with the other extensions in this file.
        super(SamplingFunctionSaver, self).__init__(
            'Sampling function saver', freq, folder_path,
            file_name, apply_at_the_end=False, **kwargs)
        self.val_monitor = monitor
        # Index of the variable to check in the monitoring extension
        self.var_idx = monitor.output_links[var][0]
        # Best (lowest) validation value seen so far; only relevant if the
        # disabled gating in condition() is ever restored.
        self.best_value = np.inf
        self.model = model
        self.f_sampling = f_sampling
        self.dict_char2int = dict_char2int

    def condition(self, batch_id, epoch_id):
        # Disabled best-value check — save unconditionally. To restore:
        # save only when the monitored value improves on self.best_value.
        return True

    def compute_object(self):
        """Return the tuple of objects to pickle and a log message."""
        return (self.model, self.f_sampling, self.dict_char2int), \
            ['extension executed']

    def finish(self, batch_id, epoch_id):
        # Fixed parameter-name typo (was `bath_id`).
        # Returning -1 signals that nothing is saved at the end of training.
        return -1, ['not executed at the end']
class ValMonitorHandwriting(ValidationMonitor):
    """
    Validation monitor that evaluates on an external fuel stream while
    managing the model's shared initial states (h, k, w): it saves their
    current values, evaluates with freshly reset states, and restores the
    saved values afterwards so training state is unaffected.
    """
    def __init__(self, name_extension, freq, inputs, monitored_variables,
                 stream, updates, model, h_ini, k_ini, w_ini, batch_size,
                 **kwargs):
        ValidationMonitor.__init__(self, name_extension, freq, inputs,
                                   monitored_variables, stream,
                                   updates=updates, **kwargs)
        self.stream = stream
        self.model = model
        self.h_ini = h_ini
        self.k_ini = k_ini
        self.w_ini = w_ini
        # Shared variables saved/restored around each validation pass.
        self.var = [h_ini, k_ini, w_ini]
        self.batch_size = batch_size

    def compute_metrics(self):
        """Accumulate metrics over the whole validation stream.

        Returns the per-output metric values averaged by their counters.
        """
        # Snapshot the shared states so they can be restored after validation.
        saved_values = [shared.get_value() for shared in self.var]
        self.model.reset_shared_init_states(
            self.h_ini, self.k_ini, self.w_ini, self.batch_size)

        totals = np.zeros(self.n_outputs, dtype=floatX)
        counts = np.zeros(self.n_outputs, dtype=floatX)
        for inputs, signal in self.stream():
            batch_metrics, batch_counts = \
                self.compute_metrics_minibatch(*inputs)
            totals += batch_metrics
            counts += batch_counts
            # A truthy signal — presumably a sequence/epoch boundary in the
            # stream (TODO confirm) — triggers a state reset.
            if signal:
                self.model.reset_shared_init_states(
                    self.h_ini, self.k_ini, self.w_ini, self.batch_size)
        totals /= counts

        # Restore the states captured before validation.
        for value, shared in zip(saved_values, self.var):
            shared.set_value(value)

        return totals