# train.py (forked from jonnyli1125/gector-ja)
import argparse
import os
import json
import tensorflow as tf
from tensorflow import keras
import numpy as np
from transformers import AdamWeightDecay
from sklearn.metrics import classification_report
from model import GEC
from utils.helpers import read_dataset, WeightedSCCE

AUTO = tf.data.AUTOTUNE


def train(corpora_dir, output_weights_path, vocab_dir, transforms_file,
pretrained_weights_path, batch_size, n_epochs, dev_ratio, dataset_len,
dataset_ratio, bert_trainable, learning_rate, class_weight_path,
filename='edit_tagged_sentences.tfrec.gz'):
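    """Train the GEC sequence tagger on TFRecord corpora under corpora_dir.

    Note: class_weight_path is accepted for CLI compatibility but is not
    consumed by this training loop (see the note after this function).
    """
    # Try to attach to a Colab TPU; fall back to CPU/GPU strategies below.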
try:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
print('TPUs: ', tf.config.list_logical_devices('TPU'))
    except (ValueError, KeyError):
tpu = None
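    # Gather every TFRecord shard named `filename` under corpora_dir.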
files = [os.path.join(root, filename)
for root, dirs, files in tf.io.gfile.walk(corpora_dir)
if filename in files]
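    # Decode the shards and shuffle at the example level; a larger buffer
    # improves mixing at the cost of memory.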
dataset = read_dataset(files).shuffle(buffer_size=1024)
    if dataset_len:
        dataset_card = tf.data.experimental.assert_cardinality(dataset_len)
        dataset = dataset.apply(dataset_card)
        if 0 < dataset_ratio < 1:
            dataset_len = int(dataset_len * dataset_ratio)
            dataset = dataset.take(dataset_len)
    else:
        # Without --dataset_len the TFRecord cardinality is unknown, so count
        # the examples with one full pass (slow for large corpora).
        dataset_len = int(dataset.reduce(0, lambda x, _: x + 1).numpy())
        dataset = dataset.apply(
            tf.data.experimental.assert_cardinality(dataset_len))
print(dataset, dataset.cardinality().numpy())
print('Loaded dataset')
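    # Hold out the first dev_len examples for validation; train on the rest.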
dev_len = int(dataset_len * dev_ratio)
train_set = dataset.skip(dev_len).prefetch(AUTO)
dev_set = dataset.take(dev_len).prefetch(AUTO)
print(train_set.cardinality().numpy(), dev_set.cardinality().numpy())
print(f'Using {dev_ratio} of dataset for dev set')
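    # Batch after splitting so the dev boundary falls on whole examples.
    # Note: TPUs generally want static shapes, so drop_remainder=True may be
    # needed here when the TPU path is taken (an assumption, not in the repo).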
train_set = train_set.batch(batch_size, num_parallel_calls=AUTO)
dev_set = dev_set.batch(batch_size, num_parallel_calls=AUTO)
if tpu:
strategy = tf.distribute.TPUStrategy(tpu)
else:
strategy = tf.distribute.MultiWorkerMirroredStrategy()
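    # Build the model and callbacks under the strategy scope so variables are
    # placed on the distributed devices.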
with strategy.scope():
gec = GEC(vocab_path=vocab_dir, verb_adj_forms_path=transforms_file,
pretrained_weights_path=pretrained_weights_path,
bert_trainable=bert_trainable, learning_rate=learning_rate)
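        # Keep the weights with the best dev tag accuracy; stop early if the
        # training loss plateaus for 3 epochs.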
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=output_weights_path + '_checkpoint',
save_weights_only=True,
monitor='val_labels_probs_sparse_categorical_accuracy',
mode='max',
save_best_only=True)
early_stopping_callback = keras.callbacks.EarlyStopping(
monitor='loss', patience=3)
gec.model.fit(train_set, epochs=n_epochs, validation_data=dev_set,
callbacks=[model_checkpoint_callback, early_stopping_callback])
gec.model.save_weights(output_weights_path)
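
# Note: class_weight_path is accepted but never consumed above. A minimal
# sketch (an assumption, motivated by the otherwise-unused json and
# WeightedSCCE imports) of how per-class weights could be passed to Keras;
# the exact wiring depends on the GEC model's output structure:
#
#   if class_weight_path:
#       with open(class_weight_path) as f:
#           class_weight = {int(k): float(v) for k, v in json.load(f).items()}
#       gec.model.fit(train_set, epochs=n_epochs, validation_data=dev_set,
#                     class_weight=class_weight, callbacks=[...])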


def main(args):
train(args.corpora_dir, args.output_weights_path, args.vocab_dir,
args.transforms_file, args.pretrained_weights_path, args.batch_size,
args.n_epochs, args.dev_ratio, args.dataset_len, args.dataset_ratio,
args.bert_trainable, args.learning_rate, args.class_weight_path)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--corpora_dir',
help='Path to dataset folder',
required=True)
parser.add_argument('-o', '--output_weights_path',
help='Path to save model weights to',
required=True)
parser.add_argument('-v', '--vocab_dir',
help='Path to output vocab folder',
default='./data/output_vocab')
parser.add_argument('-t', '--transforms_file',
help='Path to verb/adj transforms file',
default='./data/transform.txt')
parser.add_argument('-p', '--pretrained_weights_path',
help='Path to pretrained model weights')
parser.add_argument('-b', '--batch_size', type=int,
help='Number of samples per batch',
default=32)
parser.add_argument('-e', '--n_epochs', type=int,
help='Number of epochs',
default=10)
    parser.add_argument('-d', '--dev_ratio', type=float,
                        help='Fraction of the dataset to hold out as the dev set',
                        default=0.01)
    parser.add_argument('-l', '--dataset_len', type=int,
                        help='Number of examples in the dataset; skips a slow '
                             'counting pass if provided')
    parser.add_argument('-r', '--dataset_ratio', type=float,
                        help='Fraction of the dataset to use',
                        default=1.0)
parser.add_argument('-bt', '--bert_trainable',
help='Enable training for BERT encoder layers',
action='store_true')
parser.add_argument('-lr', '--learning_rate', type=float,
help='Learning rate',
default=1e-5)
    parser.add_argument('-cw', '--class_weight_path',
                        help='Path to class weight file (currently unused; '
                             'see the note after train())')
args = parser.parse_args()
main(args)
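
# Example invocation (paths and sizes are illustrative, not from the repo):
#   python train.py \
#       -c ./data/corpora \
#       -o ./models/gector-ja-weights \
#       -l 1000000 -b 64 -e 10 -bt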