Skip to content

Commit

Permalink
deprecate forced/suppress tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Apr 1, 2024
1 parent 55295c8 commit a20a96f
Showing 1 changed file with 22 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,16 @@ class ModelArguments:
)
forced_decoder_ids: List[List[int]] = field(
default=None,
metadata={
metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
)
suppress_tokens: List[int] = field(
default=None, metadata={
"help": (
"A list of pairs of integers which indicates a mapping from generation indices to token indices "
"that will be forced before sampling. For example, [[0, 123]] means the first generated token "
"will always be a token of index 123."
"Deprecated. The use of `suppress_tokens` should not be required for the majority of fine-tuning examples."
"Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly."
)
},
)
suppress_tokens: List[int] = field(
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
)
apply_spec_augment: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -400,8 +399,6 @@ def main():
trust_remote_code=model_args.trust_remote_code,
)

config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})

# SpecAugment for whisper models
if getattr(config, "model_type", None) == "whisper":
config.update({"apply_spec_augment": model_args.apply_spec_augment})
Expand Down Expand Up @@ -455,9 +452,22 @@ def main():
"only be set for multilingual checkpoints."
)

if hasattr(model.generation_config, "forced_decoder_ids"):
# forced decoder ids are now handled entirely by the decoder input ids
model.generation_config.forced_decoder_ids = None
# TODO (Sanchit): deprecate these arguments in v4.40
if model_args.forced_decoder_ids is not None:
logger.warning(
"The use of `forced_decoder_ids` is deprecated and will be removed in v4.40."
"Please use the `language` and `task` arguments instead"
)
model.generation_config["forced_decoder_ids"] = model_args.forced_decoder_ids
else:
model.generation_config["forced_decoder_ids"] = None

if model_args.suppress_tokens is not None:
logger.warning(
"The use of `suppress_tokens` is deprecated and will be removed in v4.40."
"Should you need `suppress_tokens`, please manually set them in the fine-tuning script."
)
model.generation_config["suppress_tokens"] = model_args.suppress_tokens

# 6. Resample speech dataset if necessary
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
Expand Down

0 comments on commit a20a96f

Please sign in to comment.