-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmodel_train.py
93 lines (73 loc) · 2.56 KB
/
model_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# encoding=utf-8
import os
import time
import re
import math
import numpy as np
import pandas
import tensorflow as tf
# tf.enable_eager_execution()
# this is my custom library file that contains some useful methods
# and some methods for creating different model architectures
import utils
import kanji_label_dict as kld
class MyClassback(tf.keras.callbacks.Callback):
    """Keras callback that checkpoints model weights during training.

    Weights are saved when:
      * training accuracy drops relative to the previous epoch (so the
        better recent weights are not lost),
      * accuracy reaches 0.99 (training is also stopped), or
      * every `save_iter` epochs, as a regular checkpoint.
    """

    def __init__(self, prefix: str, save_dir: str):
        # Chain up so base Callback state (e.g. the self.model hook) is initialized.
        super().__init__()
        # `prefix` is prepended to every saved weights filename.
        self.prefix = prefix
        # Directory the weight files are written into (created on demand).
        self.save_dir = save_dir
        # Accuracy of the previous epoch; None until the first epoch finishes.
        self.last_acc = None
        # Model weights will be saved every `save_iter` epochs.
        self.save_iter = 5

    def on_epoch_end(self, epoch, logs=None):
        # Fixed: `logs={}` was a mutable default argument.
        logs = logs or {}
        # TF 1.x logs the metric under 'acc'; newer Keras uses 'accuracy'
        # (this script compiles with metrics=['accuracy']). Accept either,
        # and bail out instead of crashing on a None comparison.
        acc = logs.get('acc')
        if acc is None:
            acc = logs.get('accuracy')
        if acc is None:
            return
        if self.last_acc is None:
            self.last_acc = acc
        else:
            if self.last_acc > acc:
                # Accuracy regressed: checkpoint so earlier progress survives.
                self.save_model_weights(epoch, acc)
            self.last_acc = acc
        if acc >= 0.99:
            # Good enough — stop training and keep these weights.
            self.model.stop_training = True
            self.save_model_weights(epoch, acc)
        elif epoch % self.save_iter == 0:
            self.save_model_weights(epoch, acc)

    def save_model_weights(self, epoch, acc):
        """Write the current model weights to an .h5 file under `save_dir`."""
        # exist_ok avoids the check-then-create race of the original
        # `if not exists: makedirs` pattern.
        os.makedirs(self.save_dir, exist_ok=True)
        model_weights_filename = f'{self.save_dir}/{self.prefix}_weights_epoch_{str(epoch).zfill(2)}_acc_{acc:.2f}_{utils.time_now()}.h5'
        self.model.save_weights(model_weights_filename)
if __name__ == "__main__":
    # Build and compile the kanji classification model.
    model = utils.kanji_model_v3()
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.summary()

    # The dataset held 256776 samples at last count; using the full size as
    # the shuffle buffer gives a uniformly shuffled stream.
    buffer_size = 256776
    # buffer_size = 256  # smaller buffer: faster but weaker shuffling
    batch_size = 256
    steps_per_epoch = math.ceil(buffer_size / batch_size)

    # Input pipeline: cache -> shuffle+repeat -> batch.
    tfrecord_filename = 'kanji_dataset.tfrecord'
    ds = utils.load_tfrecord(tfrecord_filename)
    ds = (
        ds.cache()
          .apply(tf.data.experimental.shuffle_and_repeat(buffer_size=buffer_size))
          .batch(batch_size)
    )

    # Periodic / accuracy-triggered weight checkpointing during fit().
    save_dir = 'kanji_model_v3'
    prefix = 'kanji_model_v3'
    callback = MyClassback(prefix=prefix, save_dir=save_dir)
    # callback.save_iter = 5
    histories = model.fit(
        ds,
        epochs=20,
        steps_per_epoch=steps_per_epoch,
        callbacks=[callback],
    )

    # Persist the full model (architecture + weights) once training ends.
    model_filename = f'{save_dir}/{prefix}_model_{utils.time_now()}.h5'
    model.save(model_filename)