diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 4edbeaa4600..3e11c7e614a 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -1460,9 +1460,6 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig):
         feature_size="num_mel_bins",
         allow_new=True,
     )
-    # a custom generation mixin was introduced in 4.37.0 with batched long form generation
-    # https://github.com/huggingface/transformers/pull/27658
-    MIN_TRANSFORMERS_VERSION = version.parse("4.37.0")
     ATOL_FOR_VALIDATION = 1e-3
 
     @property
diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index 7248be4e37a..03d51978c7f 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -39,7 +39,6 @@
 from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 from transformers.models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
-from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
 
 import onnxruntime as ort
 
@@ -67,6 +66,14 @@
 )
 
 
+if check_if_transformers_greater("4.37.0"):
+    from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
+else:
+
+    class WhisperGenerationMixin:
+        generate = WhisperForConditionalGeneration.generate
+
+
 if check_if_transformers_greater("4.25.0"):
     from transformers.generation import GenerationMixin
 else: