From a041e1ae9cd5d69af993f5da6561223ad407f5da Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Fri, 18 Dec 2020 07:40:57 -0800
Subject: [PATCH] Fix parameter (#3045)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?

`src_lengths` is not a required parameter in `TransformerEncoder`; it is only a dummy variable. Further changes may be needed to fix the same issue in classes such as `Transformer`, `FairseqEncoderDecoderModel`, `BARTModel`, etc.

## Did you have fun?
Make sure you had fun coding

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3045

Reviewed By: ngoyal2707

Differential Revision: D25632992

Pulled By: myleott

fbshipit-source-id: 762d595144b611e1a6c236248d7001049afed0ab
---
 fairseq/models/transformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py
index 9655578e52..fa4c29855b 100644
--- a/fairseq/models/transformer.py
+++ b/fairseq/models/transformer.py
@@ -402,7 +402,7 @@ def forward_embedding(
     def forward(
         self,
         src_tokens,
-        src_lengths,
+        src_lengths: Optional[torch.Tensor] = None,
         return_all_hiddens: bool = False,
         token_embeddings: Optional[torch.Tensor] = None,
     ):
@@ -418,7 +418,7 @@ def forward(
                 default `None` will recompute embeddings

         Returns:
-            namedtuple:
+            dict:
                 - **encoder_out** (Tensor): the last encoder layer's output
                   of shape `(src_len, batch, embed_dim)`
                 - **encoder_padding_mask** (ByteTensor): the positions of
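
A minimal usage sketch of what this change enables, assuming an already-constructed `TransformerEncoder` instance named `encoder` (its construction, which requires model args, a dictionary, and token embeddings, is omitted here); with `src_lengths` now optional, the forward call no longer needs a lengths tensor:

```python
import torch

# `encoder` is assumed to be a fairseq TransformerEncoder built elsewhere;
# the token ids below are placeholders for illustration only.
src_tokens = torch.tensor([[4, 5, 6, 2]])  # (batch, src_len)

# Before this patch the call required a src_lengths argument even though the
# encoder ignored it; after the patch src_lengths defaults to None.
encoder_out = encoder(src_tokens, return_all_hiddens=False)
```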