
[examples] update whisper fine-tuning (#29938)
* [examples] update whisper fine-tuning

* deprecate forced/suppress tokens

* item assignment

* update readme

* final fix
sanchit-gandhi authored Apr 26, 2024
1 parent aafa7ce commit 38b53da
Showing 2 changed files with 38 additions and 16 deletions.
9 changes: 4 additions & 5 deletions examples/pytorch/speech-recognition/README.md
@@ -368,6 +368,7 @@ python run_speech_recognition_seq2seq.py \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
+ --task="transcribe" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
@@ -384,12 +385,10 @@ python run_speech_recognition_seq2seq.py \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
- --length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
- --group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \
@@ -399,7 +398,8 @@ python run_speech_recognition_seq2seq.py \
```
On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.

- If training on a different language, you should be sure to change the `language` argument. The `language` argument should be omitted for English speech recognition.
+ If training on a different language, you should be sure to change the `language` argument. The `language` and `task`
+ arguments should be omitted for English speech recognition.

#### Multi GPU Whisper Training
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision:
@@ -410,6 +410,7 @@ torchrun \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
+ --task="transcribe" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
@@ -425,12 +426,10 @@ torchrun \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
- --length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
- --group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \
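As an aside (not part of the diff): the `--language` and `--task` flags control which Whisper prompt tokens the tokenizer prepends to every label sequence during preprocessing. A minimal sketch of that behaviour, assuming the `openai/whisper-small` checkpoint from the README example; the Hindi sample text is an assumption for illustration only:

```python
# Illustrative sketch only; the example text is an assumption, not taken from this commit.
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
# Mirrors what the fine-tuning script does when `--language`/`--task` are set:
tokenizer.set_prefix_tokens(language="hindi", task="transcribe")

label_ids = tokenizer("नमस्ते").input_ids
# The decoded labels should start with <|startoftranscript|><|hi|><|transcribe|>.
print(tokenizer.decode(label_ids))
```

For English-only checkpoints the README now says to omit both flags, which corresponds to leaving the prefix tokens at their defaults.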
45 changes: 34 additions & 11 deletions examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
@@ -119,17 +119,16 @@ class ModelArguments:
    )
    forced_decoder_ids: List[List[int]] = field(
        default=None,
-         metadata={
+         metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
+     )
+     suppress_tokens: List[int] = field(
+         default=None, metadata={
            "help": (
-                 "A list of pairs of integers which indicates a mapping from generation indices to token indices "
-                 "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
-                 "will always be a token of index 123."
+                 "Deprecated. The use of `suppress_tokens` should not be required for the majority of fine-tuning examples."
+                 "Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly."
            )
        },
    )
-     suppress_tokens: List[int] = field(
-         default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
-     )
    apply_spec_augment: bool = field(
        default=False,
        metadata={
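A side note on the deprecation above (an illustration, not code from this commit): the pairs that `forced_decoder_ids` used to carry can still be produced explicitly from the tokenizer should a custom setup need them; for standard fine-tuning the `language`/`task` arguments make this unnecessary. A sketch, assuming the `openai/whisper-small` checkpoint:

```python
# Hypothetical illustration; `openai/whisper-small` is an assumed checkpoint.
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
# Returns (position, token_id) pairs, e.g. the <|hi|> and <|transcribe|> tokens
# forced at the first decoder positions.
forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language="hindi", task="transcribe")
print(forced_decoder_ids)
```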
@@ -400,8 +399,6 @@ def main():
        trust_remote_code=model_args.trust_remote_code,
    )

-     config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
-
    # SpecAugment for whisper models
    if getattr(config, "model_type", None) == "whisper":
        config.update({"apply_spec_augment": model_args.apply_spec_augment})
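For context on the hunk above (not part of the change): SpecAugment for Whisper is toggled through model config attributes, so a single `config.update` call is enough here. A sketch with assumed masking probabilities, not values taken from the script:

```python
# Sketch only; the checkpoint and probabilities are assumptions.
from transformers import WhisperConfig

config = WhisperConfig.from_pretrained("openai/whisper-small")
config.update({"apply_spec_augment": True, "mask_time_prob": 0.05, "mask_feature_prob": 0.05})
print(config.apply_spec_augment, config.mask_time_prob, config.mask_feature_prob)
```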
@@ -440,9 +437,35 @@ def main():
        model.freeze_encoder()
        model.model.encoder.gradient_checkpointing = False

-     if data_args.language is not None:
-         # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+     if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
+         # We only need to set the language and task ids in a multilingual setting
        tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
+         model.generation_config.update(
+             **{
+                 "language": data_args.language,
+                 "task": data_args.task,
+             }
+         )
+     elif data_args.language is not None:
+         raise ValueError(
+             "Setting language token for an English-only checkpoint is not permitted. The language argument should "
+             "only be set for multilingual checkpoints."
+         )
+
+     # TODO (Sanchit): deprecate these arguments in v4.41
+     if model_args.forced_decoder_ids is not None:
+         logger.warning(
+             "The use of `forced_decoder_ids` is deprecated and will be removed in v4.41."
+             "Please use the `language` and `task` arguments instead"
+         )
+         model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
+
+     if model_args.suppress_tokens is not None:
+         logger.warning(
+             "The use of `suppress_tokens` is deprecated and will be removed in v4.41."
+             "Should you need `suppress_tokens`, please manually set them in the fine-tuning script."
+         )
+         model.generation_config.suppress_tokens = model_args.suppress_tokens

    # 6. Resample speech dataset if necessary
    dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
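Finally, a usage sketch (not from the commit) of how `language` and `task` play out at inference time; the checkpoint name and the silent placeholder audio are assumptions:

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# "openai/whisper-small" stands in for the fine-tuned `--output_dir`; substitute
# your own path after training.
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

audio = np.zeros(16_000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

# `language`/`task` may be passed per call; with the updated script they are also
# written into the saved generation_config, in which case they can be omitted here.
predicted_ids = model.generate(inputs.input_features, language="hindi", task="transcribe")
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```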
