class MoonshineOnnxConfig(AudioToTextOnnxConfig):
    """ONNX export configuration for Moonshine audio-to-text models.

    Declares which tensors each exported sub-model takes as input and which of
    their axes are dynamic, depending on the configured export behavior.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

    # Opset 14 is the minimum that works here. With the older default the
    # exporter fails with:
    #   torch.onnx.errors.UnsupportedOperatorError: Exporting the operator
    #   'aten::triu' to ONNX opset version 11 is not supported. Support for
    #   this operator was added in version 14, try exporting with this version.
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the input names mapped to their dynamic axes.

        The mapping is built in a fixed order: ``input_values`` (unless the
        behavior is decoder-only), then ``decoder_input_ids`` followed by any
        past key/value entries (unless the behavior is encoder-only), then
        ``encoder_outputs`` (decoder-only behavior).

        Returns:
            A dict ``{input_name: {axis_index: axis_name}}``.
        """
        behavior = self._behavior
        is_decoder_only = behavior is ConfigBehavior.DECODER
        is_encoder_only = behavior is ConfigBehavior.ENCODER

        dynamic_axes: Dict[str, Dict[int, str]] = {}

        # Raw audio is an input for every behavior except decoder-only.
        if not is_decoder_only:
            dynamic_axes["input_values"] = {0: "batch_size", 1: "num_samples"}

        if not is_encoder_only:
            with_past = self.use_past_in_inputs
            decoder_ids_axes = {0: "batch_size"}
            if not with_past:
                # Without cached key/values the decoder sequence axis is
                # declared dynamic as well.
                decoder_ids_axes[1] = "decoder_sequence_length"
            dynamic_axes["decoder_input_ids"] = decoder_ids_axes
            if with_past:
                # Must come after "decoder_input_ids" so the input order is
                # kept: ids first, then the past key/value entries.
                self.add_past_key_values(dynamic_axes, direction="inputs")

        # A decoder-only behavior receives the encoder result as an input.
        if is_decoder_only:
            dynamic_axes["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

        return dynamic_axes
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 4db4130302..16e3dc77ad 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -869,6 +869,13 @@ class TasksManager: "image-classification", onnx="MobileNetV2OnnxConfig", ), + "moonshine": supported_tasks_mapping( + "feature-extraction", + "feature-extraction-with-past", + "automatic-speech-recognition", + "automatic-speech-recognition-with-past", + onnx="MoonshineOnnxConfig", + ), "mpnet": supported_tasks_mapping( "feature-extraction", "fill-mask",