-
Notifications
You must be signed in to change notification settings - Fork 18
/
expert_model.py
1693 lines (1419 loc) · 66.1 KB
/
expert_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# email: [email protected]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# import abc
import collections
import math
import time
import codecs
import sys
import argparse
import os
import random
import six
import json
import numpy as np
import tensorflow as tf
import opennmt as onmt
from tensorflow.python.ops import lookup_ops
from tensorflow.python.client import device_lib
import re
# Matches a single digit; create_vocabulary uses it to map every digit to "0"
# when normalize_digits=True.
_DIGIT_RE = re.compile(r"\d")
############################## Vocab Utils ##########################################
# Special tokens. check_vocab expects BLK, SOS, EOS to be the first three vocab
# entries; create_vocabulary appends UNK as the last entry.
BLK = "<blank>"
SOS = "<s>"
EOS = "</s>"
UNK = "<unk>"
# _get_embed_device places embedding matrices with more rows than this on the CPU.
VOCAB_SIZE_THRESHOLD_CPU = 50000
# Buffer size handed to Dataset.shuffle in get_iterator_exp.
SHUFFLE_BUFFER_SIZE = 1000000
def print_out(s, f=None, new_line=True):
    """Echo ``s`` to stdout (flushed) and optionally to a binary stream.

    Args:
      s: text or UTF-8 encoded bytes.
      f: optional stream accepting bytes; receives ``s`` encoded as UTF-8.
      new_line: append a newline on both outputs when true.
    """
    text = s.decode("utf-8") if isinstance(s, bytes) else s

    if f:
        f.write(text.encode("utf-8"))
        if new_line:
            f.write(b"\n")

    # On Python 2 the encoded value is already a str; on Python 3 decode back.
    encoded = text.encode("utf-8")
    printable = encoded if isinstance(encoded, str) else encoded.decode("utf-8")
    print(printable, end="", file=sys.stdout)

    if new_line:
        sys.stdout.write("\n")
    sys.stdout.flush()
def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab, vocab_size):
    """Create string->id lookup tables from one-token-per-line vocab files.

    A token's id is its line index in the file; tokens not found in the file
    map to ``vocab_size``. When ``share_vocab`` is true the source table is
    reused for the target and ``tgt_vocab_file`` is ignored.

    Returns:
      (src_vocab_table, tgt_vocab_table)
    """
    def _table_from(path):
        return lookup_ops.index_table_from_file(path, default_value=vocab_size)

    src_vocab_table = _table_from(src_vocab_file)
    tgt_vocab_table = src_vocab_table if share_vocab else _table_from(tgt_vocab_file)
    return src_vocab_table, tgt_vocab_table
def load_vocab(vocab_file):
    """Read a UTF-8 vocabulary file with one token per line.

    Returns:
      (vocab, vocab_size): the whitespace-stripped tokens in file order and
      the number of lines read.
    """
    utf8_reader = codecs.getreader("utf-8")
    with utf8_reader(tf.gfile.GFile(vocab_file, "rb")) as stream:
        vocab = [token.strip() for token in stream]
    return vocab, len(vocab)
def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None,
                eos=None, blk=None, data_file=None, max_vocabulary_size=50000):
    """Validate an existing vocab file, or build one from ``data_file``.

    If ``vocab_file`` exists and ``check_special_token`` is true, its first three
    entries must be (blk, sos, eos) (defaults BLK, SOS, EOS). Otherwise they are
    prepended and the result is written to ``out_dir/<basename of vocab_file>``.
    If ``vocab_file`` does not exist, create_vocabulary builds it from
    ``data_file``.

    Returns:
      (vocab_size, vocab_file): the size of the vocabulary and the path of the
      file that should be used from now on.
    """
    if not tf.gfile.Exists(vocab_file):
        vocab, vocab_file = create_vocabulary(vocab_file, data_file, max_vocabulary_size)
        return len(vocab), vocab_file

    print_out("# Vocab file %s exists" % vocab_file)
    vocab, vocab_size = load_vocab(vocab_file)
    if not check_special_token:
        return vocab_size, vocab_file

    blk = blk or BLK
    sos = sos or SOS
    eos = eos or EOS
    assert len(vocab) >= 3
    if (vocab[0], vocab[1], vocab[2]) == (blk, sos, eos):
        return vocab_size, vocab_file

    print_out("The first 3 vocab words [%s, %s, %s]"
              " are not [%s, %s, %s]" %
              (vocab[0], vocab[1], vocab[2], blk, sos, eos))
    vocab = [blk, sos, eos] + vocab
    vocab_size += 3
    fixed_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file))
    with codecs.getwriter("utf-8")(tf.gfile.GFile(fixed_vocab_file, "wb")) as writer:
        for token in vocab:
            writer.write("%s\n" % token)
    return vocab_size, fixed_vocab_file
def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, normalize_digits=False):
    """Build a frequency-sorted vocabulary from a whitespace-tokenized corpus.

    The vocabulary is [BLK, SOS, EOS] + corpus tokens by descending count
    (ties keep first-seen order) + [UNK]. If that is longer than
    ``max_vocabulary_size + 1`` entries, the least frequent corpus tokens are
    dropped so exactly ``max_vocabulary_size + 1`` entries remain and UNK stays
    last, at index ``max_vocabulary_size``. That is the id create_vocab_tables
    gives to unknown tokens when called with vocab_size == max_vocabulary_size.

    Args:
      vocabulary_path: output file, one token per line (UTF-8).
      data_path: corpus file read line by line.
      max_vocabulary_size: cap described above.
      normalize_digits: replace every digit with "0" before counting.

    Returns:
      (vocab_list, vocabulary_path)
    """
    print_out("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
    counts = collections.Counter()
    with tf.gfile.GFile(data_path, mode="rb") as f:
        for counter, line in enumerate(f, 1):
            if counter % 100000 == 0:
                print(" processing line %d" % counter)
            # Work on text, not bytes. With bytes tokens, Python 3 writes "b'tok'"
            # into the vocab file, and the str pattern _DIGIT_RE raises TypeError.
            line = tf.compat.as_text(line)
            for w in line.split():
                word = _DIGIT_RE.sub("0", w) if normalize_digits else w
                counts[word] += 1
    vocab_list = [BLK, SOS, EOS] + sorted(counts, key=counts.get, reverse=True) + [UNK]
    if len(vocab_list) > max_vocabulary_size + 1:
        # Keep UNK as the final entry; plain truncation used to cut it off.
        vocab_list = vocab_list[:max_vocabulary_size] + [UNK]
    with tf.gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
        for w in vocab_list:
            vocab_file.write("%s\n" % w)
    return vocab_list, vocabulary_path
############################## Iterator #######################################
class BatchedInput(collections.namedtuple("BatchedInput", "initializer qe_iterator")):
    """Pair of a dataset iterator (``qe_iterator``) and its ``initializer`` op."""
def get_infer_iterator_exp(src_dataset,
                           tgt_dataset,
                           src_vocab_table,
                           tgt_vocab_table,
                           batch_size,
                           sos,
                           eos,
                           src_max_len=None,
                           tgt_max_len=None):
    """Build an unshuffled, unbucketed iterator over (source, target) line pairs.

    Each pair is split on spaces and optionally truncated to ``src_max_len`` /
    ``tgt_max_len`` tokens. It is then mapped to int32 ids, and the target is
    wrapped as ``[sos] + target + [eos]``. Elements are
    ``(src_ids, tgt_ids, src_len, tgt_len)``, zero-padded into batches of
    ``batch_size``.

    Returns:
      A BatchedInput holding the initializable iterator and its initializer.
    """
    sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
    eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)

    def _split_words(src, tgt):
        return tf.string_split([src]).values, tf.string_split([tgt]).values

    def _encode(src, tgt):
        # Truncation happens before <s>/</s> are added, so a target can be up
        # to tgt_max_len + 2 ids long.
        if src_max_len:
            src = src[:src_max_len]
        if tgt_max_len:
            tgt = tgt[:tgt_max_len]
        src_ids = tf.cast(src_vocab_table.lookup(src), tf.int32)
        tgt_ids = tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)
        tgt_ids = tf.concat(([sos_id], tgt_ids, [eos_id]), 0)
        return src_ids, tgt_ids, tf.size(src_ids), tf.size(tgt_ids)

    pairs = tf.data.Dataset.zip((src_dataset, tgt_dataset))
    encoded = pairs.map(_split_words).map(_encode)
    batched = encoded.padded_batch(
        batch_size,
        padded_shapes=(tf.TensorShape([None]),  # src ids
                       tf.TensorShape([None]),  # tgt ids
                       tf.TensorShape([]),      # src length
                       tf.TensorShape([])),     # tgt length
        # Every component is padded with id/length 0.
        padding_values=(0, 0, 0, 0))
    batched_iter = batched.make_initializable_iterator()
    return BatchedInput(initializer=batched_iter.initializer,
                        qe_iterator=batched_iter)
def get_iterator_exp(src_dataset,
                     tgt_dataset,
                     src_vocab_table,
                     tgt_vocab_table,
                     batch_size,
                     sos,
                     eos,
                     random_seed,
                     bucket_width,
                     num_gpus,
                     src_max_len=None,
                     tgt_max_len=None,
                     num_parallel_calls=8,
                     output_buffer_size=None,
                     skip_count=None,
                     reshuffle_each_iteration=True,
                     mode=tf.contrib.learn.ModeKeys.TRAIN):
    """Build the shuffled (and, in TRAIN mode, length-bucketed) training iterator.

    Pipeline:
      1. zip the two line datasets;
      2. skip ``skip_count`` pairs;
      3. shuffle;
      4. split on spaces and drop pairs where either side is empty;
      5. optionally truncate;
      6. map to int32 ids;
      7. wrap the target as ``[sos] + tgt + [eos]``;
      8. attach lengths.

    ``batch_size`` is multiplied by ``num_gpus``. With ``bucket_width > 1`` in
    TRAIN mode, pairs are grouped by ``max(src_len, tgt_len) // bucket_width``.
    Each bucket's window size is ``batch_size // ((key + 1) * bucket_width)``,
    i.e. ``batch_size`` is used as a per-batch token budget (presumably). When
    ``num_gpus > 1`` the window size is rounded up to a multiple of num_gpus.

    Returns:
      A BatchedInput holding the initializable iterator and its initializer.
    """
    batch_size = batch_size * num_gpus

    def batching_func(x):
        return x.padded_batch(
            batch_size,
            padded_shapes=(
                tf.TensorShape([None]),  # src
                tf.TensorShape([None]),  # tgt
                tf.TensorShape([]),  # src_len
                tf.TensorShape([])),  # tgt_len
            padding_values=(
                0,  # src
                0,  # tgt
                0,  # src_len -- unused
                0))  # tgt_len -- unused

    if not output_buffer_size:
        output_buffer_size = 500000
    tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
    tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)
    qe_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
    # Skip before shuffling so a restarted job resumes after the pairs it already consumed.
    if skip_count is not None:
        qe_dataset = qe_dataset.skip(skip_count)
    qe_dataset = qe_dataset.shuffle(SHUFFLE_BUFFER_SIZE, random_seed, reshuffle_each_iteration)
    qe_dataset = qe_dataset.map(
        lambda src, tgt: (
            tf.string_split([src]).values,
            tf.string_split([tgt]).values),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
    qe_dataset = qe_dataset.filter(
        lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))
    if src_max_len:
        qe_dataset = qe_dataset.map(
            lambda src, tgt: (
                src[:src_max_len],
                tgt),
            num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
    if tgt_max_len:
        qe_dataset = qe_dataset.map(
            lambda src, tgt: (
                src,
                tgt[:tgt_max_len]),
            num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
    qe_dataset = qe_dataset.map(
        lambda src, tgt: (
            tf.cast(src_vocab_table.lookup(src), tf.int32),
            tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
    qe_dataset = qe_dataset.map(
        lambda src, tgt: (
            src,
            tf.concat(([tgt_sos_id], tgt, [tgt_eos_id]), 0)),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
    qe_dataset = qe_dataset.map(
        lambda src, tgt: (
            src,
            tgt,
            tf.size(src),
            tf.size(tgt)),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

    if bucket_width > 1 and mode == tf.contrib.learn.ModeKeys.TRAIN:
        def key_func(unused_1, unused_2, src_len, tgt_len):
            # Bucket by the longer side; bucket ids are not capped.
            bucket_id = tf.constant(0, dtype=tf.int32)
            bucket_id = tf.maximum(bucket_id, src_len // bucket_width)
            bucket_id = tf.maximum(bucket_id, tgt_len // bucket_width)
            return tf.to_int64(bucket_id)

        def reduce_func(unused_key, windowed_data):
            return batching_func(windowed_data)

        def window_size_func(key):
            # Bucket `key` holds pairs whose longer side is < (key + 1) * bucket_width;
            # the +1 also keeps bucket 0 from dividing by zero.
            key += 1
            size = batch_size // (key * bucket_width)
            if num_gpus > 1:
                # Round up to a multiple of num_gpus; unchanged if already one.
                size = (size + num_gpus - 1) // num_gpus * num_gpus
            return tf.to_int64(tf.maximum(size, num_gpus))

        # bucketing for qe data
        qe_batched_dataset = qe_dataset.apply(
            tf.data.experimental.group_by_window(
                key_func=key_func, reduce_func=reduce_func, window_size_func=window_size_func))
    else:
        qe_batched_dataset = batching_func(qe_dataset)
    qe_batched_iter = qe_batched_dataset.make_initializable_iterator()
    return BatchedInput(
        initializer=qe_batched_iter.initializer,
        qe_iterator=qe_batched_iter)
############################## Model helper ###################################
class TrainModel(collections.namedtuple(
        "TrainModel", "graph model iterator skip_count_placeholder")):
    """Everything create_train_model builds: graph, model, iterator, skip-count feed."""
def create_train_model(model_creator, hparams, scope=None):
    """Build a fresh training graph and model.

    Reads ``<train_prefix>.<src>`` and ``<train_prefix>.<tgt>`` as parallel
    text and builds the training iterator; how many pairs to skip is fed
    through ``skip_count_placeholder``. Then calls ``model_creator`` in TRAIN
    mode, all inside a new tf.Graph.

    Returns:
      A TrainModel.
    """
    train_src_path = "%s.%s" % (hparams.train_prefix, hparams.src)
    train_tgt_path = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    graph = tf.Graph()
    with graph.as_default(), tf.container(scope or "train"):
        src_table, tgt_table = create_vocab_tables(
            hparams.src_vocab_file, hparams.tgt_vocab_file,
            hparams.share_vocab, hparams.max_vocab_size)
        src_lines = tf.data.TextLineDataset(train_src_path)
        tgt_lines = tf.data.TextLineDataset(train_tgt_path)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
        iterator = get_iterator_exp(
            src_lines,
            tgt_lines,
            src_table,
            tgt_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            bucket_width=hparams.bucket_width,
            num_gpus=hparams.num_gpus,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder)
        # No explicit device here (None clears any device scope).
        with tf.device(None):
            model = model_creator(
                hparams,
                iterator=iterator,
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                scope=scope)
    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
class InferModel(collections.namedtuple(
        "InferModel",
        "graph model src_placeholder tgt_placeholder batch_size_placeholder iterator")):
    """Everything create_infer_model builds: graph, model, input feeds, iterator."""
def create_infer_model(model_creator, hparams, scope=None):
    """Build a fresh inference graph fed from string placeholders.

    Source and target sentences are fed through ``src_placeholder`` and
    ``tgt_placeholder``. The iterator batches with ``hparams.infer_batch_size``;
    NOTE(review): ``batch_size_placeholder`` is created and returned but not
    wired into the iterator.

    Returns:
      An InferModel.
    """
    graph = tf.Graph()
    src_vocab_path = hparams.src_vocab_file
    tgt_vocab_path = hparams.tgt_vocab_file
    with graph.as_default(), tf.container(scope or "infer"):
        src_table, tgt_table = create_vocab_tables(
            src_vocab_path, tgt_vocab_path, hparams.share_vocab, hparams.max_vocab_size)
        # id -> token; ids outside the vocab file become UNK.
        id_to_token = lookup_ops.index_to_string_table_from_file(tgt_vocab_path, default_value=UNK)
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
        iterator = get_infer_iterator_exp(
            tf.data.Dataset.from_tensor_slices(src_placeholder),
            tf.data.Dataset.from_tensor_slices(tgt_placeholder),
            src_table,
            tgt_table,
            hparams.infer_batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            reverse_target_vocab_table=id_to_token,
            scope=scope)
    return InferModel(graph=graph,
                      model=model,
                      src_placeholder=src_placeholder,
                      tgt_placeholder=tgt_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator)
def load_model(model, ckpt, session, name):
    """Restore ``model`` from checkpoint ``ckpt`` and initialize lookup tables.

    Returns:
      The same ``model`` object.
    """
    started = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    elapsed = time.time() - started
    print_out(" loaded %s model parameters from %s, time %.2fs" % (name, ckpt, elapsed))
    return model
def create_or_load_model(model, model_dir, session, name, hparams):
    """Restore the newest checkpoint in ``model_dir``, or initialize from scratch.

    ``hparams`` is accepted but not used.

    Returns:
      (model, global_step, skip_count) with the two counters read from ``session``.
    """
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if not latest_ckpt:
        started = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        print_out(" created %s model with fresh parameters, time %.2fs" % (name, time.time() - started))
    else:
        model = load_model(model, latest_ckpt, session, name)
    step = model.global_step.eval(session=session)
    consumed = model.skip_count.eval(session=session)
    return model, step, consumed
def gradient_clip(gradients, max_gradient_norm):
    """Clip ``gradients`` by their global norm.

    Returns:
      (clipped_gradients, summaries, gradient_norm), where gradient_norm is the
      global norm before clipping. summaries holds the scalar summaries
      "grad_norm" (pre-clip norm) and "clipped_gradient" (post-clip norm).
    """
    clipped, norm = tf.clip_by_global_norm(gradients, max_gradient_norm)
    summaries = [
        tf.summary.scalar("grad_norm", norm),
        tf.summary.scalar("clipped_gradient", tf.global_norm(clipped)),
    ]
    return clipped, summaries, norm
def avg_checkpoints(model_dir, num_last_checkpoints, global_step,
                    global_step_name):
    """Average the last ``num_last_checkpoints`` checkpoints found in ``model_dir``.

    Every variable except ``global_step_name`` is averaged element-wise.
    Variables whose name contains "skip_count" are an example counter and are
    not averaged. The value from the last checkpoint in the list is kept
    (presumably the newest; TF lists checkpoint paths oldest first).
    Dividing that int64 value in place would either fail or produce a
    meaningless fraction. The result is saved as
    ``<model_dir>/avg_checkpoints/qe.ckpt``, along with a global step variable
    set to ``global_step``.

    Returns:
      The averaged-checkpoint directory, or None when there is no checkpoint
      state or fewer than ``num_last_checkpoints`` checkpoints exist.
    """
    checkpoint_state = tf.train.get_checkpoint_state(model_dir)
    if not checkpoint_state:
        print_out("# No checkpoint file found in directory: %s" % model_dir)
        return None
    checkpoints = (checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])
    if len(checkpoints) < num_last_checkpoints:
        print_out("# Skipping averaging checkpoints because not enough checkpoints is avaliable.")
        return None
    avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
    if not tf.gfile.Exists(avg_model_dir):
        print_out("# Creating new directory %s for saving averaged checkpoints." % avg_model_dir)
        tf.gfile.MakeDirs(avg_model_dir)
    print_out("# Reading and averaging variables in checkpoints:")
    var_list = tf.contrib.framework.list_variables(checkpoints[0])
    var_values, var_dtypes = {}, {}
    for (name, shape) in var_list:
        if name != global_step_name:
            var_values[name] = np.zeros(shape)
    for checkpoint in checkpoints:
        print_out(" %s" % checkpoint)
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            if "skip_count" not in name:
                var_values[name] += tensor
            else:
                # Counters are copied, not summed; the last checkpoint wins.
                var_values[name] = tensor
    for name in var_values:
        # Only the summed variables are divided; the copied counter stays as-is.
        if "skip_count" not in name:
            var_values[name] /= len(checkpoints)
    with tf.Graph().as_default():
        tf_vars = [tf.get_variable(name, shape=var_values[name].shape, dtype=var_dtypes[name]) for name in var_values]
        placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
        assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
        # Created (though not referenced again) so the saver writes a global step.
        global_step_var = tf.Variable(global_step, name=global_step_name, trainable=False)
        saver = tf.train.Saver(tf.global_variables())
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # tf_vars was built by iterating var_values, so the orders line up.
            for p, assign_op, (name, value) in zip(placeholders, assign_ops, six.iteritems(var_values)):
                sess.run(assign_op, {p: value})
            saver.save(sess, os.path.join(avg_model_dir, "qe.ckpt"))
    return avg_model_dir
def shape_list(x):
    """Return the dimensions of ``x``: static ints where known, scalar tensors otherwise.

    If even the rank is unknown, the whole dynamic ``tf.shape(x)`` is returned.
    """
    x = tf.convert_to_tensor(x)
    if x.get_shape().dims is None:
        return tf.shape(x)
    static_dims = x.get_shape().as_list()
    dynamic_dims = tf.shape(x)
    return [dynamic_dims[i] if d is None else d for i, d in enumerate(static_dims)]
def approximate_split(x, num_splits, axis=0):
    """Split ``x`` into ``num_splits`` pieces of nearly equal size along ``axis``.

    If the first dimension of ``x`` is smaller than ``num_splits``, ``x`` is first
    tiled ``num_splits`` times along axis 0, so every piece is non-empty.

    Args:
      x: a Tensor
      num_splits: an integer
      axis: an integer.
    Returns:
      a list of num_splits Tensors; piece i has (size + i) // num_splits rows.
    """
    def _tile_first_axis(t, reps):
        multiples = [reps] + [1] * (len(t.get_shape()) - 1)
        return tf.tile(t, multiples)

    too_small = tf.shape(x)[0] < num_splits
    x = tf.cond(too_small, lambda: _tile_first_axis(x, num_splits), lambda: x)
    total = shape_list(x)[axis]
    piece_sizes = [tf.div(total + i, num_splits) for i in range(num_splits)]
    return tf.split(x, piece_sizes, axis=axis)
def _get_embed_device(vocab_size):
    """Return "/cpu:0" for vocabularies above VOCAB_SIZE_THRESHOLD_CPU, else "/gpu:0"."""
    return "/cpu:0" if vocab_size > VOCAB_SIZE_THRESHOLD_CPU else "/gpu:0"
def create_emb(embed_name, vocab_size, embed_size, dtype):
    """Create a [vocab_size, embed_size] variable on the device picked by _get_embed_device."""
    placement = _get_embed_device(vocab_size)
    with tf.device(placement):
        return tf.get_variable(embed_name, [vocab_size, embed_size], dtype)
def create_emb_for_encoder_and_decoder(share_vocab,
                                       src_vocab_size,
                                       tgt_vocab_size,
                                       src_embed_size,
                                       tgt_embed_size,
                                       dtype=tf.float32,
                                       scope=None):
    """Create the encoder and decoder embedding matrices.

    With ``share_vocab`` a single "embedding_share" variable is returned for
    both sides. Source and target vocab sizes must then match (ValueError
    otherwise), and so must the embed sizes (asserted). Otherwise separate
    matrices are created under the "encoder" and "decoder" sub-scopes.

    Returns:
      (embedding_encoder, embedding_decoder)
    """
    with tf.variable_scope(scope or "embeddings", dtype=dtype):
        if not share_vocab:
            with tf.variable_scope("encoder"):
                enc = create_emb("embedding_encoder", src_vocab_size, src_embed_size, dtype)
            with tf.variable_scope("decoder"):
                dec = create_emb("embedding_decoder", tgt_vocab_size, tgt_embed_size, dtype)
            return enc, dec

        if src_vocab_size != tgt_vocab_size:
            raise ValueError("Share embedding but different src/tgt vocab sizes"
                             " %d vs. %d" % (src_vocab_size, tgt_vocab_size))
        assert src_embed_size == tgt_embed_size
        print_out("# Use the same embedding for source and target")
        shared = create_emb("embedding_share", src_vocab_size, src_embed_size, dtype)
        return shared, shared
def shift_concat(decoder_outputs, fake_tensors, time_major=False):
    """Pair each position's previous forward step with its next backward step.

    With target input <s>, y1, ..., <e> and per-step outputs f_t (forward) and
    b_t (backward), the result at position t is [f_{t-1}, b_{t+1}]:
        [fake, b2], [f1, b3], ..., [f_{t-2}, b_t], [f_{t-1}, fake]
    f_T and b_1 are discarded.

    Args:
      decoder_outputs: (fw_outputs, bw_outputs), both batch-major unless time_major.
      fake_tensors: optional (fw_fake, bw_fake) single-step fillers (e.g. an
        initial state); zeros are used when None.
      time_major: whether axis 0 (instead of axis 1) is time.
    Returns:
      The shifted forward and backward outputs concatenated on the last axis.
    """
    fw_outputs, bw_outputs = decoder_outputs
    time_axis = 0 if time_major else 1

    def _time_slice(t, start, stop):
        return t[start:stop] if time_major else t[:, start:stop]

    if fake_tensors is None:
        fw_fill = tf.zeros_like(_time_slice(fw_outputs, 0, 1))
        bw_fill = tf.zeros_like(_time_slice(bw_outputs, 0, 1))
    else:
        fw_fill, bw_fill = fake_tensors
        fw_fill = tf.expand_dims(fw_fill, time_axis)
        bw_fill = tf.expand_dims(bw_fill, time_axis)

    shifted_fw = tf.concat([fw_fill, _time_slice(fw_outputs, None, -1)], axis=time_axis)
    shifted_bw = tf.concat([_time_slice(bw_outputs, 1, None), bw_fill], axis=time_axis)
    return tf.concat([shifted_fw, shifted_bw], axis=-1)
############################## Model Class ####################################
class BilingualExpert(object):
    """Bilingual expert model built from OpenNMT-tf self-attention components.

    A self-attention encoder reads the source. Two self-attention decoders read
    the target left-to-right (fw_decoder) and right-to-left (bw_decoder), both
    attending to the encoder outputs. For each target position the output layer
    sees:
      - the forward state of the previous position;
      - the backward state of the next position;
      - a projection of the neighbouring token embeddings (see shift_concat).

    In TRAIN mode the batch is split across all local GPUs and an update op is
    built. In INFER mode the argmax ids are mapped back to strings.
    """
    def __init__(self,
                 hparams,
                 mode,
                 iterator,
                 reverse_target_vocab_table=None,
                 scope=None):
        """Build the graph for ``mode``.

        Args:
          hparams: hyper-parameters; attributes read here and in the methods
            include:
              - devices/training: num_gpus, label_smoothing, random_seed,
                learning_rate, warmup_steps;
              - architecture: num_encoder_layers, num_decoder_layers,
                num_units, num_heads, ffn_inner_dim, dropout, embedding_size;
              - vocabulary: share_vocab, src_vocab_size, tgt_vocab_size;
              - checkpointing: num_keep_ckpts, avg_ckpts.
          mode: tf.contrib.learn.ModeKeys.TRAIN or INFER (the modes handled here).
          iterator: a BatchedInput whose qe_iterator yields
            (source, target, source_length, target_length) batches.
          reverse_target_vocab_table: id -> string table; used in INFER mode.
          scope: optional variable scope name.
        """
        # Names of the local GPU devices; TRAIN mode builds one replica per device.
        self.devices = [x.name for x in device_lib.list_local_devices() if x.device_type == "GPU"]
        # NOTE(review): this also runs in INFER mode, so inference needs exactly
        # hparams.num_gpus visible GPUs too.
        assert len(self.devices) == hparams.num_gpus
        self.iterator = iterator
        self.mode = mode
        self.label_smoothing = hparams.label_smoothing
        self.num_encoder_layers = hparams.num_encoder_layers
        self.num_decoder_layers = hparams.num_decoder_layers
        # Initializer
        # Default initializer for variables created under the current variable scope.
        initializer = tf.random_uniform_initializer(-0.1, 0.1, seed=hparams.random_seed)
        tf.get_variable_scope().set_initializer(initializer)
        # Embeddings
        self.embedding_encoder, self.embedding_decoder = create_emb_for_encoder_and_decoder(
            share_vocab=hparams.share_vocab,
            src_vocab_size=hparams.src_vocab_size,
            tgt_vocab_size=hparams.tgt_vocab_size,
            src_embed_size=hparams.embedding_size,
            tgt_embed_size=hparams.embedding_size,
            scope=scope)
        # Expert Model
        # encoder
        self.encoder = onmt.encoders.self_attention_encoder.SelfAttentionEncoder(
            hparams.num_encoder_layers,
            num_units=hparams.num_units,
            num_heads=hparams.num_heads,
            ffn_inner_dim=hparams.ffn_inner_dim,
            dropout=hparams.dropout,
            attention_dropout=hparams.dropout,
            relu_dropout=hparams.dropout,
            position_encoder=onmt.layers.position.SinusoidalPositionEncoder())
        # fw_decoder
        self.fw_decoder = onmt.decoders.self_attention_decoder.SelfAttentionDecoder(
            hparams.num_decoder_layers,
            num_units=hparams.num_units,
            num_heads=hparams.num_heads,
            ffn_inner_dim=hparams.ffn_inner_dim,
            dropout=hparams.dropout,
            attention_dropout=hparams.dropout,
            relu_dropout=hparams.dropout,
            position_encoder=onmt.layers.position.SinusoidalPositionEncoder())
        # bw_decoder
        self.bw_decoder = onmt.decoders.self_attention_decoder.SelfAttentionDecoder(
            hparams.num_decoder_layers,
            num_units=hparams.num_units,
            num_heads=hparams.num_heads,
            ffn_inner_dim=hparams.ffn_inner_dim,
            dropout=hparams.dropout,
            attention_dropout=hparams.dropout,
            relu_dropout=hparams.dropout,
            position_encoder=onmt.layers.position.SinusoidalPositionEncoder())
        self.global_step = tf.Variable(0, trainable=False, name="global_step")
        # Number of training examples consumed; advanced by update_sc, reset by reset_sc.
        self.skip_count = tf.Variable(0, trainable=False, name="skip_count", dtype=tf.int64)
        self.reset_sc = tf.assign(self.skip_count, 0)
        # data used
        self.source, self.target, self.src_sequence_length, self.tgt_sequence_length = self.iterator.qe_iterator.get_next()
        # build graph
        res = self.build_graph(hparams, scope=scope)
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.train_loss, _ = res
        elif self.mode == tf.contrib.learn.ModeKeys.INFER:
            _, sampled_ids = res
            self.sampled_words = reverse_target_vocab_table.lookup(tf.to_int64(sampled_ids))
        self.params = tf.trainable_variables()
        # optimization
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.learning_rate = tf.constant(hparams.learning_rate)
            # Must come after the line above: the warmup schedule scales self.learning_rate.
            self.learning_rate = self._get_learning_rate_warmup(hparams)
            self.optimizer = tf.contrib.opt.LazyAdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.998)
            # Gradients
            # NOTE(review): these gradients feed only the grad-norm summaries and
            # self.grad_norm; clipped_grads is never used. optimize_loss below
            # computes its own gradients (no clip_gradients is passed), so the
            # update is unclipped and gradients are built twice. Confirm whether
            # clipping at 5.0 was meant to be applied.
            gradients = tf.gradients(self.train_loss,
                                     self.params,
                                     colocate_gradients_with_ops=True)
            clipped_grads, grad_norm_summary, grad_norm = gradient_clip(gradients, 5.0)
            self.grad_norm = grad_norm
            self.update = tf.contrib.layers.optimize_loss(self.train_loss,
                                                          self.global_step,
                                                          learning_rate=None,
                                                          optimizer=self.optimizer,
                                                          variables=self.params,
                                                          colocate_gradients_with_ops=True)
            # also update skip count
            # src_sequence_length holds one entry per example, so its size is the batch size.
            self.update_sc = tf.assign_add(self.skip_count, tf.size(self.src_sequence_length, out_type=tf.int64), use_locking=True)
            # Summary
            self.train_summary = tf.summary.merge([
                tf.summary.scalar("lr", self.learning_rate),
                tf.summary.scalar("train_loss", self.train_loss)
            ] + grad_norm_summary)
        # Saver
        self.saver = tf.train.Saver(
            tf.global_variables(), max_to_keep=hparams.num_keep_ckpts)
        self.best_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        if hparams.avg_ckpts:
            self.avg_best_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        # Print trainable variables
        print_out("# Trainable variables")
        for param in self.params:
            print_out(" %s, %s, %s" % (param.name, str(param.get_shape()), param.op.device))

    def _get_learning_rate_warmup(self, hparams):
        """Return self.learning_rate scaled by a warmup / inverse-square-root factor.

        With s = global_step / 2 + 1 the factor is
        num_units**-0.5 * min(s * warmup_steps**-1.5, s**-0.5). It grows
        linearly in s until s == warmup_steps and then decays like s**-0.5.
        NOTE(review): why global_step is halved is not explained here.
        """
        warmup_steps = hparams.warmup_steps
        print_out(" learning_rate=%g, warmup_steps=%d" % (hparams.learning_rate, warmup_steps))
        step_num = tf.to_float(self.global_step) / 2. + 1
        inv_decay = hparams.num_units ** -0.5 * tf.minimum(step_num * warmup_steps ** -1.5, step_num ** -0.5)
        return inv_decay * self.learning_rate

    def build_graph(self, hparams, scope=None):
        """Build the model graph and return (loss, sampled_ids).

        TRAIN: the batch tensors are split into hparams.num_gpus shards with
        approximate_split. One replica is built per device in self.devices
        (variables reused after the first). loss is the sum of the per-shard
        cross-entropies divided by the sum of the per-shard normalizers;
        sampled_ids is None.
        INFER: a single replica over the full batch; loss is None.
        NOTE(review): any other mode leaves `loss` unbound (UnboundLocalError).
        """
        print_out("# creating %s graph ..." % self.mode)
        dtype = tf.float32
        with tf.variable_scope(scope or "transformerpredictor", dtype=dtype):
            if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
                sampled_ids = None
                source_shards = approximate_split(self.source, hparams.num_gpus)
                src_len_shards = approximate_split(self.src_sequence_length, hparams.num_gpus)
                target_shards = approximate_split(self.target, hparams.num_gpus)
                tgt_len_shards = approximate_split(self.tgt_sequence_length, hparams.num_gpus)
                loss_shards = []
                for i, device in enumerate(self.devices):
                    with tf.name_scope("parallel_{}".format(i)):
                        with tf.variable_scope(tf.get_variable_scope(), reuse=True if i > 0 else None):
                            with tf.device(device):
                                logits, _ = self._build_model(source_shards[i],
                                                              src_len_shards[i],
                                                              target_shards[i],
                                                              tgt_len_shards[i],
                                                              hparams)
                                loss_shards.append(self._compute_loss(logits,
                                                                      target_shards[i],
                                                                      tgt_len_shards[i]))
                # Transpose [(crossent, normalizer), ...] into ([crossent, ...], [normalizer, ...]).
                _loss = tuple(list(loss_shard) for loss_shard in zip(*loss_shards))
                loss = tf.add_n(_loss[0]) / tf.add_n(_loss[1])
            elif self.mode == tf.contrib.learn.ModeKeys.INFER:
                loss = None
                _, sampled_ids = self._build_model(self.source,
                                                   self.src_sequence_length,
                                                   self.target,
                                                   self.tgt_sequence_length,
                                                   hparams)
        return loss, sampled_ids

    def _build_model(self, src, src_seq_lens, tgt, tgt_seq_lens, hparams):
        """Run encoder, both decoders and the output projection on one batch.

        Args:
          src, src_seq_lens: source ids and lengths.
          tgt, tgt_seq_lens: target ids (wrapped in <s> ... </s> by the
            iterators) and lengths.
          hparams: hyper-parameters (tgt_vocab_size, num_units are read here).
        Returns:
          (logits, sampled_ids): logits over tgt_vocab_size for every target
          position, and their argmax ids with position 0 dropped.
        """
        with tf.variable_scope("encoder"):
            encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder, src)
            # NOTE(review): OpenNMT-tf encoders are presumably batch-major, i.e.
            # encoder_outputs is [batch_size, max_time, num_units]; confirm.
            encoder_outputs, _, encoder_sequence_length = self.encoder.encode(
                encoder_emb_inp,
                src_seq_lens,
                mode=self.mode)
        fw_target_input = tgt
        fw_decoder_emb_inp = tf.nn.embedding_lookup(self.embedding_decoder, fw_target_input)
        # Reverse only the valid part of each sequence so padding stays at the end.
        bw_decoder_emb_inp = tf.reverse_sequence(
            fw_decoder_emb_inp,
            tgt_seq_lens,
            batch_axis=0,
            seq_axis=1)
        # Decoder
        with tf.variable_scope("decoder"):
            self.output_layer = tf.layers.Dense(
                hparams.tgt_vocab_size, use_bias=False, name="output_projection")
            # Projects the concatenated neighbour embeddings (2 * embedding_size)
            # to 2 * num_units.
            self.emb_proj_layer = tf.layers.Dense(
                2 * hparams.num_units, use_bias=False, name="emb_proj_layer")
            with tf.variable_scope("fw_decoder"):
                # output_layer is the identity: raw decoder states are returned.
                fw_outputs, _, _ = self.fw_decoder.decode(
                    fw_decoder_emb_inp,
                    tgt_seq_lens,
                    vocab_size=hparams.tgt_vocab_size,
                    initial_state=None,  # unused arg for transformer
                    sampling_probability=None,  # unsupported arg for transformer
                    embedding=None,  # unused arg for transformer
                    output_layer=lambda x: x,
                    mode=self.mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length)
            with tf.variable_scope("bw_decoder"):
                bw_outputs, _, _ = self.bw_decoder.decode(
                    bw_decoder_emb_inp,
                    tgt_seq_lens,
                    vocab_size=hparams.tgt_vocab_size,
                    initial_state=None,  # unused arg for transformer
                    sampling_probability=None,  # unsupported arg for transformer
                    embedding=None,  # unused arg for transformer
                    output_layer=lambda x: x,
                    mode=self.mode,
                    memory=encoder_outputs,
                    memory_sequence_length=encoder_sequence_length)
            # Put the backward states back into left-to-right order.
            bw_outputs_rev = tf.reverse_sequence(
                bw_outputs,
                tgt_seq_lens,
                batch_axis=0,
                seq_axis=1)
            # Position t gets [fw state of t-1, bw state of t+1] ...
            shift_outputs = shift_concat(
                (fw_outputs, bw_outputs_rev),
                None)
            # ... and [embedding of token t-1, embedding of token t+1].
            shift_inputs = shift_concat(
                (fw_decoder_emb_inp, fw_decoder_emb_inp),
                None)
            shift_proj_inputs = self.emb_proj_layer(shift_inputs)
            _pre_qefv = tf.concat([shift_outputs, shift_proj_inputs], axis=-1)
            # _pre_qefv = shift_outputs + shift_proj_inputs
            # Notice, currently <s> is not to predict, but actually in our QE model, we can predict it.
            logits = self.output_layer(_pre_qefv)
            sampled_ids = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
        return logits, sampled_ids[:, 1:]

    def _compute_loss(self, logits, tgt, tgt_seq_lens):
        """Return (crossent, normalizer) from OpenNMT-tf's sequence cross-entropy.

        build_graph divides the summed crossent by the summed normalizer.
        NOTE(review): presumably crossent is a summed loss and normalizer a
        token count when average_in_time=True; confirm against the installed
        OpenNMT-tf version.
        """
        crossent, normalizer, _ = onmt.utils.losses.cross_entropy_sequence_loss(
            logits,
            tgt,
            tgt_seq_lens,
            self.label_smoothing,
            average_in_time=True,
            mode=self.mode)
        return crossent, normalizer

    def train(self, sess):
        """Run one training step.

        Returns:
          [update result, update_sc result, train_loss, train_summary,
           global_step, grad_norm, learning_rate] as fetched by sess.run.
        """
        assert self.mode == tf.contrib.learn.ModeKeys.TRAIN
        return sess.run([self.update,
                         self.update_sc,
                         self.train_loss,
                         self.train_summary,
                         self.global_step,
                         self.grad_norm,
                         self.learning_rate])

    def infer(self, sess):
        """Return the predicted target tokens (strings) for the next batch."""
        assert self.mode == tf.contrib.learn.ModeKeys.INFER
        return sess.run(self.sampled_words)
############################## Training #################################
def load_data(inference_input_file):
    """Return the lines of a UTF-8 file, without line terminators."""
    utf8_reader = codecs.getreader("utf-8")
    with utf8_reader(tf.gfile.GFile(inference_input_file, mode="rb")) as stream:
        content = stream.read()
    return content.splitlines()
def save_hparams(out_dir, hparams):
    """Write ``hparams`` as tab-indented JSON to ``<out_dir>/hparams``."""
    target_path = os.path.join(out_dir, "hparams")
    print_out(" saving hparams to %s" % target_path)
    serialized = hparams.to_json(indent="\t")
    with codecs.getwriter("utf-8")(tf.gfile.GFile(target_path, "wb")) as writer:
        writer.write(serialized)
def add_summary(summary_writer, global_step, tag, value):
    """Log one scalar ``value`` under ``tag`` at ``global_step``."""
    entry = tf.Summary.Value(tag=tag, simple_value=value)
    summary_writer.add_summary(tf.Summary(value=[entry]), global_step)
def evaluate(ref_file, pred_file, metric):
    """Score ``pred_file`` against ``ref_file`` with "bleu" or "word_accuracy".

    The metric name is case-insensitive.

    Raises:
      ValueError: for any other metric name.
    """
    metric_key = metric.lower()
    if metric_key == "bleu":
        return _bleu(ref_file, pred_file)
    if metric_key == "word_accuracy":
        return _word_accuracy(ref_file, pred_file)
    raise ValueError("Unknown metric %s" % metric)
def _clean(sentence, subword_option):
"""Clean and handle BPE or SPM outputs."""
sentence = sentence.strip()
# BPE
if subword_option == "bpe":
sentence = re.sub("@@ ", "", sentence)
# SPM
elif subword_option == "spm":
sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip()
return sentence
def _bleu(ref_file, trans_file, subword_option=None):
    """Return corpus BLEU (0-100, max order 4, no smoothing) of trans_file vs ref_file.

    Lines are cleaned with _clean (undoing BPE/SPM if requested) and tokenized
    on single spaces before scoring with compute_bleu.
    """
    max_order = 4
    smooth = False
    utf8_reader = codecs.getreader("utf-8")

    reference_text = []
    for reference_path in [ref_file]:
        with utf8_reader(tf.gfile.GFile(reference_path, "rb")) as fh:
            reference_text.append(fh.readlines())
    per_segment_references = [
        [_clean(reference, subword_option).split(" ") for reference in segment_refs]
        for segment_refs in zip(*reference_text)
    ]

    with utf8_reader(tf.gfile.GFile(trans_file, "rb")) as fh:
        translations = [_clean(line, subword_option).split(" ") for line in fh]

    # bleu_score, precisions, bp, ratio, translation_length, reference_length
    bleu_score, _, _, _, _, _ = compute_bleu(
        per_segment_references, translations, max_order, smooth)
    return 100 * bleu_score
def _get_ngrams(segment, max_order):
"""Extracts all n-grams upto a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams upto max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of lists of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
precisions and brevity penalty.
"""