How to do inference with a Transformer? decoder.decode_seq() vs decoder()? #632
-
Hi guys, I trained a Transformer net recently, but I ran into a problem. I thought decoder.decode_seq() is used for training: we have to feed in target_seq, so the net is trained with teacher forcing. Am I wrong? There are two different pieces of code below: the first one uses the decoder.decode_seq() API, and the second one calls decoder() step by step to implement teacher forcing. I thought these two would give the same result, but they don't, and I can't tell why. Does anyone know why, or how I should do inference correctly? Thanks.
decoder.decode_seq()
decoder()
Replies: 13 comments
-
@chenjunweii you're right in that decode_seq is implementing teacher forcing. Note that in teacher forcing the decoder knows the ground truth of the previous step, whereas a regular decoder only takes its own prediction from the last step. This means that with teacher forcing the decoder is actually dealing with a simpler problem than the free decode in the second case. Small errors in free decoding can accumulate from the previous steps and make the prediction worse for longer sequences.
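To make the contrast concrete, here is a minimal, framework-agnostic sketch of the two loops; `decoder`, `state`, and the token handling are placeholders for illustration rather than the GluonNLP API:

```python
# Teacher forcing: every step is conditioned on the ground-truth prefix,
# so a mistake at one step cannot corrupt the inputs of later steps.
def teacher_forcing_decode(decoder, state, target_seq):
    outputs = []
    for gold_token in target_seq[:-1]:          # feed ground-truth tokens
        out, state = decoder(gold_token, state)
        outputs.append(out)
    return outputs

# Free-running (greedy) decoding: each step is conditioned on the model's
# own previous prediction, so early mistakes are fed back in and accumulate.
def free_running_decode(decoder, state, bos_token, max_len):
    outputs, token = [], bos_token
    for _ in range(max_len):
        out, state = decoder(token, state)
        token = out.argmax(axis=-1)             # feed back the prediction
        outputs.append(out)
    return outputs
```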
-
@szha Thanks for your reply. I know the difference between free running and teacher forcing, but what I mean is that I thought the second piece of code is also teacher forcing: the decoder takes each step of the target sequence as input, so I expected the two to give the same result, but they don't, and I can't figure out what is wrong with the second one. Thanks
-
🤦‍♂️ Sorry that I completely misunderstood you, as I missed that the …
-
@chenjunweii For inference, you can use the following code: gluon-nlp/scripts/machine_translation/train_transformer.py, lines 241 to 248 in 6ec0c84. It uses the translator to output the sampled sequences. Internally, it uses beam search and the decode_step function.
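As a rough sketch of what that snippet does (assuming the `BeamSearchTranslator` defined in `scripts/machine_translation/translation.py` around that commit; the hyperparameter values here are illustrative, so check the linked lines for the exact call):

```python
import gluonnlp as nlp
from translation import BeamSearchTranslator   # scripts/machine_translation/translation.py

# Wrap the trained model in a beam-search translator.
translator = BeamSearchTranslator(
    model=model,                                # trained NMT Transformer model
    beam_size=4,
    scorer=nlp.model.BeamSearchScorer(alpha=0.6, K=5),
    max_length=100)

# src_seq: (batch_size, src_length) token ids; src_valid_length: (batch_size,)
samples, scores, sample_valid_length = translator.translate(
    src_seq=src_seq, src_valid_length=src_valid_length)

# samples: (batch_size, beam_size, length); the first beam is the best-scoring one.
best_hypothesis = samples[:, 0, :]
```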
-
@szhengac What are the shapes of the inputs to translator.translate? I tried to run it as in your example, but I get a mismatch:
MXNetError: Shape inconsistent, Provided = [n_hid,3200], inferred shape=(n_hid,n_hid)
where 3200 = sequence length * n_hid. Details: example_input_batch is a tensor of shape (batch_size=100, length=25). I run either …
I added print-debugs to your code to show the intermediate tensor shapes; when I print before the offending line ( …
-
@lambdaofgod The input has the following shape: gluon-nlp/scripts/machine_translation/translation.py, lines 60 to 61 in e1910c5.
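In other words, an illustrative example of the expected shapes (the values here are made up; the real inputs would come from your data pipeline):

```python
import mxnet as mx

batch_size, src_length = 100, 25
# Token ids, shape (batch_size, src_length)
src_seq = mx.nd.random.randint(low=0, high=10000,
                               shape=(batch_size, src_length)).astype('float32')
# True length of each source sentence, shape (batch_size,)
src_valid_length = mx.nd.full((batch_size,), src_length).astype('float32')

samples, scores, sample_valid_length = translator.translate(
    src_seq=src_seq, src_valid_length=src_valid_length)
```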
-
@szhengac My prints seem to confirm that. If so, do you have an idea why it doesn't work? To me it looks like the second decoder_state should only have shape (batch_size,), but I've also checked your NMT script for that and its shapes are consistent with what I get here. Still, your NMT script runs, and I get that error, even though my input is consistent with the one from your example.
-
@lambdaofgod Can you comment out …
-
@szhengac Thanks for your reply, that's exactly what I need! By the way, what is ParallelTransformer? Could it speed up the training process?
-
@chenjunweii ParallelTransformer uses multi-threading for multi-GPU training. The naive implementation of multi-GPU training in Gluon does not fully achieve parallelization.
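A minimal sketch of that pattern, assuming the `Parallelizable`/`Parallel` helpers in `gluonnlp.utils` (the actual `ParallelTransformer` in the training script wraps the loss computation along these lines, but the model/loss signatures here are assumptions and may differ between versions):

```python
import mxnet as mx
import gluonnlp as nlp

class ParallelLoss(nlp.utils.Parallelizable):
    """Runs forward/backward for one device's shard in a worker thread."""
    def __init__(self, model, loss_fn):
        self._model = model
        self._loss_fn = loss_fn

    def forward_backward(self, shard):
        src, tgt, src_len, tgt_len = shard
        with mx.autograd.record():
            out, _ = self._model(src, tgt[:, :-1], src_len, tgt_len - 1)
            loss = self._loss_fn(out, tgt[:, 1:], tgt_len - 1).mean()
        loss.backward()
        return loss

# One worker thread per device; shards are dispatched with put() and the
# results collected with get(), so the GPUs compute concurrently.
parallel = nlp.utils.Parallel(len(ctx_list), ParallelLoss(model, loss_fn))
for shard in shards:                    # one shard per GPU
    parallel.put(shard)
losses = [parallel.get() for _ in shards]
```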
-
OK, this is getting unwieldy; I created this issue
-
Let me know if you need this issue reopened.