From 60bc6ce52a58077ca799820a4f7e04d1d9075b12 Mon Sep 17 00:00:00 2001
From: James Cross
Date: Fri, 24 May 2019 18:11:24 -0700
Subject: [PATCH] hybrid: enable no bottleneck

Differential Revision: D15507269

fbshipit-source-id: c908fec701a5b20f5b550fa5b44184babd48c292
---
 pytorch_translate/hybrid_transformer_rnn.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/pytorch_translate/hybrid_transformer_rnn.py b/pytorch_translate/hybrid_transformer_rnn.py
index 8efdb8ac..d1c8c66b 100644
--- a/pytorch_translate/hybrid_transformer_rnn.py
+++ b/pytorch_translate/hybrid_transformer_rnn.py
@@ -137,6 +137,7 @@ def add_args(parser):
         )
         parser.add_argument(
             "--decoder-out-embed-dim",
+            default=None,
             type=int,
             metavar="N",
             help="decoder output embedding dimension",
@@ -220,7 +221,7 @@ def _init_dims(self, args, src_dict, dst_dict, embed_tokens):
         self.input_dim = self.lstm_units + self.attention_dim
 
         self.num_attention_heads = args.decoder_attention_heads
-        self.out_embed_dim = args.decoder_out_embed_dim
+        self.bottleneck_dim = args.decoder_out_embed_dim
 
     def _init_components(self, args, src_dict, dst_dict, embed_tokens):
         self.initial_rnn_layer = nn.LSTM(
@@ -249,9 +250,14 @@ def _init_components(self, args, src_dict, dst_dict, embed_tokens):
             nn.LSTM(input_size=self.input_dim, hidden_size=self.lstm_units)
         )
 
-        self.bottleneck_layer = fairseq_transformer.Linear(
-            self.input_dim, self.out_embed_dim
-        )
+        self.bottleneck_layer = None
+        if self.bottleneck_dim is not None:
+            self.out_embed_dim = self.bottleneck_dim
+            self.bottleneck_layer = fairseq_transformer.Linear(
+                self.input_dim, self.out_embed_dim
+            )
+        else:
+            self.out_embed_dim = self.input_dim
 
         self.embed_out = nn.Parameter(torch.Tensor(len(dst_dict), self.out_embed_dim))
         nn.init.normal_(self.embed_out, mean=0, std=self.out_embed_dim ** -0.5)
@@ -378,7 +384,8 @@ def forward(
             x = torch.cat([x, attention_out], dim=2)
 
         x = self._concat_latent_code(x, encoder_out)
-        x = self.bottleneck_layer(x)
+        if self.bottleneck_layer is not None:
+            x = self.bottleneck_layer(x)
 
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
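
---

Note for reviewers: the patch makes the decoder's output bottleneck optional. With `--decoder-out-embed-dim` left at its new default of `None`, no `bottleneck_layer` is created, `out_embed_dim` falls back to the full decoder `input_dim`, and `forward` skips the projection. The snippet below is a minimal self-contained sketch of that pattern under stated assumptions, not the patched module itself: `OutputProjection`, `input_dim`, `vocab_size`, and `bottleneck_dim` are illustrative names, and plain `nn.Linear` stands in for `fairseq_transformer.Linear`.

    # Sketch only: illustrates the optional-bottleneck pattern from the patch.
    # All names here (OutputProjection, input_dim, vocab_size, bottleneck_dim)
    # are hypothetical; nn.Linear stands in for fairseq_transformer.Linear.
    import torch
    import torch.nn as nn


    class OutputProjection(nn.Module):
        def __init__(self, input_dim, vocab_size, bottleneck_dim=None):
            super().__init__()
            self.bottleneck_layer = None
            if bottleneck_dim is not None:
                # Bottleneck enabled: decoder states are projected down, so
                # the output embedding matrix is vocab_size x bottleneck_dim.
                self.out_embed_dim = bottleneck_dim
                self.bottleneck_layer = nn.Linear(input_dim, bottleneck_dim)
            else:
                # No bottleneck: output embeddings use the full decoder
                # state dimension.
                self.out_embed_dim = input_dim
            self.embed_out = nn.Parameter(
                torch.Tensor(vocab_size, self.out_embed_dim)
            )
            nn.init.normal_(self.embed_out, mean=0, std=self.out_embed_dim ** -0.5)

        def forward(self, x):
            # x: T x B x input_dim
            if self.bottleneck_layer is not None:
                x = self.bottleneck_layer(x)
            # Project onto the output vocabulary: T x B x vocab_size.
            return torch.matmul(x, self.embed_out.t())


    # With bottleneck_dim=None (mirroring the new default), the embedding
    # matrix is larger but no extra projection runs in the forward pass:
    proj = OutputProjection(input_dim=512, vocab_size=32000)
    logits = proj(torch.randn(10, 4, 512))  # -> 10 x 4 x 32000

The trade-off mirrored here: a bottleneck shrinks the output embedding matrix (and softmax cost) at the price of an extra linear layer, while omitting it keeps the full decoder dimension for the output projection.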