-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathutils.py
94 lines (80 loc) · 2.96 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
'''
Utilities
2020-11-17 first created
'''
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import os
import numpy as np
from sklearn.preprocessing import StandardScaler
from time import time, strftime, gmtime
import tensorflow as tf
tfk = tf.keras
tfkc = tfk.callbacks
class NBatchLogger(tfkc.Callback):
    '''A Logger that logs the average performance per `display` steps.

    See: https://gist.github.com/jaekookang/7e2ca4dc2b1ab10dbb80b9e65ca91179

    Args:
        n_display: print/write a summary line every `n_display` epochs.
        max_epoch: total number of epochs (used only to zero-pad the epoch counter).
        save_dir: if given, append log lines to a file in this directory.
        suffix: optional suffix for the log filename ('train_{suffix}.log').
        silent: if True, suppress printing to stdout (file logging still happens).

    Raises:
        FileNotFoundError: if `save_dir` is given but does not exist.
    '''

    def __init__(self, n_display, max_epoch, save_dir=None, suffix=None, silent=False):
        super().__init__()
        self.epoch = 0
        self.display = n_display
        self.max_epoch = max_epoch
        self.logs = {}
        self.save_dir = save_dir
        self.silent = silent
        self.fid = None  # always defined, so on_train_end can close it safely
        if self.save_dir is not None:
            # Raise explicitly instead of `assert` (asserts are stripped under -O).
            if not os.path.exists(self.save_dir):
                raise FileNotFoundError(f'Path:{self.save_dir} does not exist!')
            fname = 'train.log'
            if suffix is not None:
                fname = f'train_{suffix}.log'
            self.fid = open(os.path.join(save_dir, fname), 'w')
        self.t0 = time()

    def on_train_begin(self, logs=None):
        '''Write/print a start banner with the current UTC time.'''
        txt = f'=== Started at {self.get_time()} ==='
        self.write_log(txt)
        if not self.silent:
            print(txt)

    def on_epoch_end(self, epoch, logs=None):
        '''Emit one summary line every `display` epochs (and on epoch 1).'''
        logs = logs or {}
        self.epoch += 1
        # Zero-pad epoch numbers to the width of max_epoch (e.g. 007/100).
        width = len(str(self.max_epoch))
        # `or`, not bitwise `|`: these are boolean conditions.
        if (self.epoch % self.display == 0) or (self.epoch == 1):
            txt = ' {} | Epoch: {:0{}d}/{:0{}d} | '.format(
                self.get_time(), self.epoch, width, self.max_epoch, width)
            if not self.silent:
                print(txt, end='')
            # `.4f` (4 decimal places) — original `:4f` was a typo meaning width 4.
            metric_txt = ' '.join(f'{key}={val:.4f}' for key, val in logs.items())
            if not self.silent:
                print(metric_txt)
            self.write_log(txt + metric_txt)
        self.logs = logs

    def on_train_end(self, logs=None):
        '''Write/print total elapsed time and release the log file handle.'''
        txt = f'=== Time elapsed: {(time() - self.t0) / 60:.4f} min ==='
        if not self.silent:
            print(txt)
        self.write_log(txt)
        if self.fid is not None:
            self.fid.close()  # previously leaked: the file was never closed

    def get_time(self):
        '''Return the current UTC time as 'YYYY-MM-DD HHh:MMm:SSs'.'''
        return strftime('%Y-%m-%d %Hh:%Mm:%Ss', gmtime())

    def write_log(self, txt):
        '''Append `txt` to the log file (no-op when file logging is disabled).'''
        if self.fid is not None:
            self.fid.write(txt + '\n')
            self.fid.flush()  # flush per line so the log survives a crash
class UpdateLossFactor(tfkc.Callback):
    '''Anneal `model.loss_factor` over training.

    After each epoch sets `model.loss_factor` to
    min(1, 2 * 0.002 ** (1 - epoch / n_epochs)): ~0.004 at epoch 0,
    rising toward the cap of 1.0 as training approaches `n_epochs`.

    Args:
        n_epochs: total number of epochs the schedule is spread over.
    '''

    def __init__(self, n_epochs):
        super(UpdateLossFactor, self).__init__()
        self.n_epochs = n_epochs

    def on_epoch_end(self, epoch, logs=None):
        '''Update the model's loss factor according to the annealing schedule.'''
        self.model.loss_factor = min(
            1., 2. * 0.002**(1. - (float(epoch) / self.n_epochs)))