forked from ghchen18/cdalign
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmultilingual_transformer.py
196 lines (169 loc) · 8.74 KB
/
multilingual_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from fairseq import utils
from fairseq.models import (
FairseqMultiModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
base_architecture,
Embedding,
TransformerModel,
TransformerEncoder,
TransformerDecoder,
)
@register_model('multilingual_transformer')
class MultilingualTransformerModel(FairseqMultiModel):
    """Train Transformer models for multiple language pairs simultaneously.

    Requires `--task multilingual_translation`.

    We inherit all arguments from TransformerModel and assume that all language
    pairs use a single Transformer architecture. In addition, we provide several
    options that are specific to the multilingual setting.

    Args:
        --share-encoder-embeddings: share encoder embeddings across all source languages
        --share-decoder-embeddings: share decoder embeddings across all target languages
        --share-encoders: share all encoder params (incl. embeddings) across all source languages
        --share-decoders: share all decoder params (incl. embeddings) across all target languages
    """

    def __init__(self, encoders, decoders):
        """Wrap the per-language-pair encoders and decoders.

        Args:
            encoders: mapping from language pair (``'src-tgt'``) to its encoder.
            decoders: mapping from language pair (``'src-tgt'``) to its decoder.
        """
        super().__init__(encoders, decoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Start from the regular Transformer options, then add the
        # multilingual parameter-sharing switches.
        TransformerModel.add_args(parser)
        parser.add_argument('--share-encoder-embeddings', action='store_true',
                            help='share encoder embeddings across languages')
        parser.add_argument('--share-decoder-embeddings', action='store_true',
                            help='share decoder embeddings across languages')
        parser.add_argument('--share-encoders', action='store_true',
                            help='share encoders across languages')
        parser.add_argument('--share-decoders', action='store_true',
                            help='share decoders across languages')

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: parsed arguments; missing architecture options are filled in
                by ``base_multilingual_architecture``.
            task: the ``MultilingualTranslationTask`` providing ``dicts``,
                ``langs`` and ``model_lang_pairs``.

        Returns:
            A ``MultilingualTransformerModel`` with one encoder and one decoder
            entry per language pair in ``task.model_lang_pairs``.

        Raises:
            TypeError: if ``task`` is not a ``MultilingualTranslationTask``.
            ValueError: if ``--share-all-embeddings`` is combined with
                mismatching embedding dimensions or a distinct
                ``--decoder-embed-path``.
        """
        # NOTE(review): kept as a function-level import as in the original;
        # presumably this avoids a circular import -- confirm before hoisting.
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(task, MultilingualTranslationTask):
            raise TypeError(
                'multilingual_transformer requires a MultilingualTranslationTask, '
                'got {}'.format(type(task).__name__)
            )

        # make sure all arguments are present in older models
        base_multilingual_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        src_langs = [lang_pair.split('-')[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split('-')[1] for lang_pair in task.model_lang_pairs]

        # Sharing a whole encoder/decoder implies sharing its embeddings.
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            """Create an embedding table for ``dictionary``, optionally preloaded from ``path``."""
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            # One table serves both sides, and the decoder output projection too.
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = (
                    FairseqMultiModel.build_shared_embeddings(
                        dicts=task.dicts,
                        langs=src_langs,
                        embed_dim=args.encoder_embed_dim,
                        build_embedding=build_embedding,
                        pretrained_embed_path=args.encoder_embed_path,
                    )
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = (
                    FairseqMultiModel.build_shared_embeddings(
                        dicts=task.dicts,
                        langs=tgt_langs,
                        embed_dim=args.decoder_embed_dim,
                        build_embedding=build_embedding,
                        pretrained_embed_path=args.decoder_embed_path,
                    )
                )

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            """Return the encoder for ``lang``, building and caching it on first use."""
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path
                    )
                lang_encoders[lang] = TransformerEncoder(args, task.dicts[lang], encoder_embed_tokens)
            return lang_encoders[lang]

        def get_decoder(lang):
            """Return the decoder for ``lang``, building and caching it on first use."""
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang], args.decoder_embed_dim, args.decoder_embed_path
                    )
                lang_decoders[lang] = TransformerDecoder(args, task.dicts[lang], decoder_embed_tokens)
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        # When sharing, the module built for the first source/target language
        # is reused for every language pair.
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = shared_encoder if shared_encoder is not None else get_encoder(src)
            decoders[lang_pair] = shared_decoder if shared_decoder is not None else get_decoder(tgt)

        return MultilingualTransformerModel(encoders, decoders)

    def load_state_dict(self, state_dict, strict=True):
        """Load parameters, skipping language pairs this model does not contain.

        Every key is expected to look like ``models.<lang_pair>.<...>``. Entries
        whose ``<lang_pair>`` is not in ``self.models`` are dropped before the
        remaining entries are handed to the parent implementation.

        Args:
            state_dict: mapping from parameter name to value; it is not
                modified, a shallow copy is filtered instead.
            strict: forwarded unchanged to the parent ``load_state_dict``.

        Returns:
            Whatever the parent ``load_state_dict`` returns.

        Raises:
            ValueError: if a key does not start with ``'models.'``.
        """
        state_dict_subset = state_dict.copy()
        # Iterate over the original while deleting from the copy, so the
        # mapping being iterated is never mutated.
        for key in state_dict:
            # Explicit check instead of ``assert`` so that malformed
            # checkpoints are still rejected under ``python -O``.
            if not key.startswith('models.'):
                raise ValueError(
                    "unexpected key in state_dict (expected 'models.' prefix): {}".format(key)
                )
            lang_pair = key.split('.')[1]
            if lang_pair not in self.models:
                del state_dict_subset[key]
        return super().load_state_dict(state_dict_subset, strict=strict)
@register_model_architecture('multilingual_transformer', 'multilingual_transformer')
def base_multilingual_architecture(args):
    """Fill in defaults for the base multilingual Transformer architecture.

    Applies ``base_architecture`` first, then makes sure each multilingual
    sharing flag exists on ``args``. A flag that is already set keeps its
    value; a missing one defaults to ``False``.
    """
    base_architecture(args)
    sharing_flags = (
        'share_encoder_embeddings',
        'share_decoder_embeddings',
        'share_encoders',
        'share_decoders',
    )
    for flag in sharing_flags:
        setattr(args, flag, getattr(args, flag, False))
@register_model_architecture('multilingual_transformer', 'multilingual_transformer_iwslt_de_en')
def multilingual_transformer_iwslt_de_en(args):
    """Fill in defaults for the ``multilingual_transformer_iwslt_de_en`` architecture.

    Each size option below is only set when ``args`` does not already define
    it; the remaining defaults then come from
    ``base_multilingual_architecture``.
    """
    size_defaults = (
        ('encoder_embed_dim', 512),
        ('encoder_ffn_embed_dim', 1024),
        ('encoder_attention_heads', 4),
        ('encoder_layers', 6),
        ('decoder_embed_dim', 512),
        ('decoder_ffn_embed_dim', 1024),
        ('decoder_attention_heads', 4),
        ('decoder_layers', 6),
    )
    for option, fallback in size_defaults:
        setattr(args, option, getattr(args, option, fallback))
    base_multilingual_architecture(args)