model_training.py

import tensorflow as tf
import numpy as np
from model.dsp_model import network
from metrics import metrics
from random import shuffle


def mdr_model(train_in, train_out,
              test_in, test_out,
              train_profile, test_profile,
              configs):
    """Train the network on the training split, then evaluate it on the test split."""
    nnnet = network(configs)

    ############################## train
    loss_list = []
    with tf.Session() as sess:
        batchgen = Batch(train_in, train_out, train_profile)
        sess.run(tf.global_variables_initializer())
        for _ in range(configs["epoches"]):
            # each step draws one mini-batch and runs a single optimizer update
            x, y, prof_batch = batchgen.next(configs["batch_size"])
            fd = {nnnet["flow_input"]: x[:, :, :configs["flow_features"]],
                  nnnet["stock_input"]: x[:, :, configs["flow_features"]:],
                  nnnet["profile_input"]: prof_batch,
                  nnnet["decoder_targets"]: y}
            _, runloss, _ = sess.run([nnnet["opt"], nnnet["loss"],
                                      nnnet["train_res"]], fd)
            loss_list.append(runloss)

        ############################## test
        testdata = Batch(test_in, test_out, test_profile)
        ground_truth = []
        predictions = []
        for _ in range(test_in.shape[0] // configs["batch_size"]):
            x_test, y_test, prof_test = testdata.next(configs["batch_size"])
            fd = {nnnet["flow_input"]: x_test[:, :, :configs["flow_features"]],
                  nnnet["stock_input"]: x_test[:, :, configs["flow_features"]:],
                  nnnet["profile_input"]: prof_test,
                  nnnet["decoder_targets"]: y_test}
            batch_pred, _ = sess.run([nnnet["predictions"], nnnet["alphas"]], fd)
            ground_truth.append(y_test)
            predictions.append(batch_pred)
        ground_truth = np.concatenate(ground_truth)
        predictions = np.concatenate(predictions)
    return metrics(ground_truth, predictions)


class Batch:
    """Mini-batch generator that wraps around and reshuffles after each pass over the data."""

    def __init__(self, en_data, de_data, profile):
        self.en_data = en_data    # encoder inputs
        self.de_data = de_data    # decoder targets
        self.profile = profile    # per-sample profile features
        self.cursor = 0
        self.dlen = en_data.shape[0]
        self.id_list = list(range(self.dlen))

    def next(self, batch_size):
        if batch_size > self.dlen:
            raise ValueError("batch size larger than data length.")
        # wrap around and reshuffle once the remaining samples cannot fill a batch
        if (self.cursor + batch_size) > self.dlen:
            self.cursor = 0
            shuffle(self.id_list)
        idx = self.id_list[self.cursor:(self.cursor + batch_size)]
        self.cursor += batch_size
        return self.en_data[idx], self.de_data[idx], self.profile[idx]
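

# Illustrative usage sketch (not from the original file): the array shapes and
# config values below are assumptions chosen only to show how mdr_model and Batch
# are expected to be called; the real keys and shapes come from the project's
# config and from model.dsp_model.network.
if __name__ == "__main__":
    configs = {
        "epoches": 100,        # number of training steps (one mini-batch each)
        "batch_size": 32,
        "flow_features": 4,    # first 4 features feed flow_input, the rest stock_input
        # ... plus whatever keys network(configs) requires
    }
    # dummy data, assumed layout: (samples, time steps, features) for inputs,
    # (samples, horizon) for decoder targets, (samples, profile dim) for profiles
    train_in = np.random.rand(256, 20, 8).astype(np.float32)
    train_out = np.random.rand(256, 5).astype(np.float32)
    train_profile = np.random.rand(256, 10).astype(np.float32)
    test_in = np.random.rand(64, 20, 8).astype(np.float32)
    test_out = np.random.rand(64, 5).astype(np.float32)
    test_profile = np.random.rand(64, 10).astype(np.float32)
    print(mdr_model(train_in, train_out, test_in, test_out,
                    train_profile, test_profile, configs))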