train.py (forked from protonx-tf-03-projects/GRU)

import os
from argparse import ArgumentParser
import tensorflow as tf
from model.gru_rnn import GRU_RNN
from model.lstm_rnn import LSTM_RNN
from model.tanh_rnn import Tanh_RNN
from data import Dataset

# Silence TensorFlow INFO and WARNING messages
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
if __name__ == "__main__":
tf.keras.backend.clear_session()
parser = ArgumentParser()
parser.add_argument("--logdir", default="logs")
home_dir = os.getcwd()
# Arguments users used when running command lines
parser.add_argument(
"--model-folder", default='{}/tmp/model/'.format(home_dir), type=str)
parser.add_argument(
"--checkpoint-folder", default='{}/tmp/checkpoints/'.format(home_dir), type=str)
parser.add_argument(
"--vocab-folder", default='{}/tmp/saved_vocab/'.format(home_dir), type=str)
parser.add_argument("--data-path", default='data/IMDB_Dataset.csv', type=str)
parser.add_argument("--data-name", default='review', type=str)
parser.add_argument("--label-name", default='sentiment', type=str)
parser.add_argument(
"--data-classes", default={'negative': 0, 'positive': 1}, type=set)
parser.add_argument("--num-class", default=2, type=int)
parser.add_argument("--model", default='gru', type=str)
parser.add_argument("--units", default=128, type=int)
parser.add_argument("--embedding-size", default=128, type=int)
parser.add_argument("--vocab-size", default=10000, type=int)
parser.add_argument("--max-length", default=256, type=int)
parser.add_argument("--learning-rate", default=0.0008, type=float)
parser.add_argument("--optimizer", default='rmsprop', type=str)
parser.add_argument("--test-size", default=0.2, type=float)
parser.add_argument("--batch-size", default=32, type=int)
parser.add_argument("--buffer-size", default=128, type=int)
parser.add_argument("--epochs", default=20, type=int)
args = parser.parse_args()
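    # Example invocation (assuming the default IMDB CSV layout shipped with
    # the repo; all flags are optional and fall back to the defaults above):
    #   python train.py --model lstm --units 64 --epochs 10 --batch-size 32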

    # Project description
    print(' ')
    print('---------------------Welcome to GRU Team | TF03 | ProtonX-------------------')
    print('Github: joeeislovely | anhdungpro97 | ttduongtran')
    print('---------------------------------------------------------------------')
    print(f'Training {args.model.upper()} model with hyper-params:')
    print('===========================')

    # Print all parsed arguments
    for i, arg in enumerate(vars(args)):
        print('{}. {}: {}'.format(i, arg, vars(args)[arg]))
    print('===========================')

    # Prepare dataset
    dataset = Dataset(args.data_path, args.vocab_size,
                      args.data_classes, args.vocab_folder)
    train_ds, val_ds = dataset.build_dataset(
        args.max_length, args.test_size, args.buffer_size, args.batch_size, args.data_name, args.label_name)

    sentences_tokenizer = dataset.sentences_tokenizer
    # +1 reserves index 0, which the Keras tokenizer keeps for padding
    sentences_tokenizer_size = len(sentences_tokenizer.word_counts) + 1

    # Initializing variables
    input_length = args.max_length

    # Initializing model (any --model value other than 'lstm' or 'tanh'
    # falls back to the GRU model)
    if args.model == 'lstm':
        model = LSTM_RNN(args.units, args.embedding_size,
                         sentences_tokenizer_size, input_length, num_class=args.num_class)
    elif args.model == 'tanh':
        model = Tanh_RNN(args.units, args.embedding_size,
                         sentences_tokenizer_size, input_length, num_class=args.num_class)
    else:
        model = GRU_RNN(args.units, args.embedding_size,
                        sentences_tokenizer_size, input_length, num_class=args.num_class)

    # Set up loss function
    losses = tf.keras.losses.CategoricalCrossentropy(
        name="categorical_crossentropy")

    # Optimizer definition (any value other than 'rmsprop' falls back to Adam)
    if args.optimizer == 'rmsprop':
        optimizer = tf.keras.optimizers.RMSprop(
            learning_rate=args.learning_rate, name='rmsprop')
    else:
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=args.learning_rate, name='adam')

    # Compile optimizer and loss function into model
    metrics = ['accuracy', 'mse']
    model.compile(optimizer=optimizer,
                  loss=losses, metrics=metrics)
    # model.summary()

    # Callbacks: early stopping
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)

    # Callbacks: checkpoint training
    # Include the epoch in the checkpoint file name
    checkpoint_path = "tmp/checkpoints/cp-{epoch:04d}.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)

    # Save weights every 5 epochs
    # (`period` is deprecated in newer TF releases in favor of `save_freq`)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path, verbose=1, save_weights_only=True, period=5)

    # Training model (the datasets are already batched, so batch_size is not
    # passed to fit(); Keras rejects it for tf.data.Dataset inputs)
    model.fit(train_ds, epochs=args.epochs,
              validation_data=val_ds,
              verbose=1,
              callbacks=[checkpoint, early_stopping])
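    # To resume from the most recent checkpoint in a later run (sketch:
    # rebuild the same model first, then restore its weights):
    #   latest = tf.train.latest_checkpoint(checkpoint_dir)
    #   model.load_weights(latest)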

    # Saving model (Keras in TF 2.x does not treat the .h5py extension as
    # HDF5, so this writes a TensorFlow SavedModel directory)
    model.save(f"{args.model_folder}/{args.model}.h5py")

    # Evaluation (optional)
    # print('==============Evaluate=============')
    # model.evaluate(val_ds, batch_size=128)
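    # To reload the trained model later (sketch; the path mirrors the
    # save call above):
    #   loaded = tf.keras.models.load_model(f"{args.model_folder}/{args.model}.h5py")
    #   loaded.evaluate(val_ds)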