Skip to content

Commit

Permalink
Correct the new defaults (#34377)
Browse files Browse the repository at this point in the history
* Correct the new defaults

* CIs

* add check

* Update utils.py

* Update utils.py

* Add the max_length in generate test checking shape without passing length

* style

* CIs

* fix fx CI issue
  • Loading branch information
Cyrilvallez authored Oct 24, 2024
1 parent 1c5918d commit 4c6e0c9
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,8 +1440,11 @@ def _prepare_generated_length(
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]
else: # by default let's always generate 10 new tokens
elif has_default_max_length: # by default let's always generate 20 new tokens
generation_config.max_length = generation_config.max_length + input_ids_length
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
if max_position_embeddings is not None:
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

# same for min length
if generation_config.min_new_tokens is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,9 @@ def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config

# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
input_ids,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ def check_encoder_decoder_model_generate(

# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val

# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))

Expand Down Expand Up @@ -873,6 +875,7 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
Expand Down Expand Up @@ -990,6 +993,7 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
Expand Down Expand Up @@ -1107,6 +1111,7 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
Expand Down

0 comments on commit 4c6e0c9

Please sign in to comment.