dtype mismatch in beam_search.py #10

Open
keyakkie opened this issue Nov 11, 2021 · 0 comments

Sorry for the rudimentary question.
To predict queries, I ran the following command, as instructed in README.md:

python ./OpenNMT-py/translate.py \
  -gpu 0 \
  -model ${DATA_DIR}/doc2query_step_10000.pt \
  -src ${DATA_DIR}/opennmt_format/src-collection.txt \
  -output ${DATA_DIR}/opennmt_format/pred-collection_beam5.txt \
  -batch_size 32 \
  -beam_size 5 \
  --n_best 5 \
  -replace_unk \
  -report_time

Then I got the following error:

Traceback (most recent call last):
  File "/home/work/doc2query/./OpenNMT-py/translate.py", line 46, in <module>
    main(opt)
  File "/home/work/doc2query/./OpenNMT-py/translate.py", line 25, in main
    translator.translate(
  File "/home/work/doc2query/OpenNMT-py/onmt/translate/translator.py", line 314, in translate
    batch_data = self.translate_batch(
  File "/home/work/doc2query/OpenNMT-py/onmt/translate/translator.py", line 498, in translate_batch
    return self._translate_batch(
  File "/home/work/doc2query/OpenNMT-py/onmt/translate/translator.py", line 650, in _translate_batch
    beam.advance(log_probs, attn)
  File "/home/work/doc2query/OpenNMT-py/onmt/translate/beam_search.py", line 155, in advance
    torch.div(self.topk_ids, vocab_size, out=self._batch_index)
RuntimeError: result type Float can't be cast to the desired output type Long
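A minimal reproduction, assuming PyTorch 1.7 or later, where `torch.div` on integer tensors performs true division and returns a Float result. The tensor names and values below are illustrative only and are not taken from OpenNMT-py:

```python
import torch

# Illustrative (hypothetical) values: ids into a flattened (beam_size * vocab_size) space.
vocab_size = 10
topk_ids = torch.tensor([[3, 17, 25]], dtype=torch.long)
batch_index = torch.empty_like(topk_ids)  # long, like self._batch_index


def reproduce():
    """Return the RuntimeError message from torch.div with a long out= tensor, or None."""
    try:
        # True division yields a Float result, which cannot be written into a Long tensor.
        torch.div(topk_ids, vocab_size, out=batch_index)
    except RuntimeError as err:
        return str(err)
    return None


print(reproduce())
```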

For reference, self._batch_index is allocated as a long tensor:

self._batch_index = torch.empty([batch_size, beam_size],
  dtype=torch.long, device=mb_device)

It seems the dtype of self._batch_index would need to be torch.float, but changing it to torch.float causes another error.

Does anyone know how to fix it?
Thanks in advance!
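One possible workaround, sketched under the assumption of PyTorch 1.8 or later (which added the `rounding_mode` argument to `torch.div`): request floor division so the result stays integral and can be written into the long `self._batch_index`. This is not an official OpenNMT-py patch, and the helper `split_topk_ids` is hypothetical. Keeping the tensor long likely matters too, because beam search uses it to select beams and index tensors must be integer typed, which may explain the second error seen after switching to `torch.float`.

```python
import torch


def split_topk_ids(topk_ids, vocab_size, batch_index):
    """Hypothetical helper: split flattened (beam * vocab) ids into beam origin and word id.

    Floor division on long inputs returns a long result, so writing it into the
    long ``batch_index`` no longer raises a dtype error (needs PyTorch >= 1.8).
    """
    torch.div(topk_ids, vocab_size, rounding_mode="floor", out=batch_index)
    word_ids = topk_ids.fmod(vocab_size)  # remainder = position within the vocabulary
    return batch_index, word_ids


vocab_size = 10
topk_ids = torch.tensor([[3, 17, 25]], dtype=torch.long)
batch_index = torch.empty_like(topk_ids)
beams, words = split_topk_ids(topk_ids, vocab_size, batch_index)
print(beams.tolist(), words.tolist())
```

In beam_search.py this would amount to adding `rounding_mode="floor"` to the `torch.div` call at line 155. Because the top-k ids are non-negative, `rounding_mode="trunc"` should give the same values.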