Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Sep 21, 2023
1 parent 7fc27f6 commit 2dd5209
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,11 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs["last_hidden_state"][1] = f"{common_outputs['last_hidden_state'][1]} / 2"
return common_outputs

class SpeechT5OnnxConfig():
NORMALIZED_CONFIG_CLASS =




class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
Expand Down
36 changes: 24 additions & 12 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,26 +159,27 @@ class TasksManager:
# task in a Hub repo that has no pipeline_tag, and no transformersInfo.pipeline_tag, as we then rely on
# on transformersInfo["auto_model"] and this dictionary.
_TRANSFORMERS_TASKS_TO_MODEL_LOADERS = {
"audio-classification": "AutoModelForAudioClassification",
"audio-frame-classification": "AutoModelForAudioFrameClassification",
"audio-xvector": "AutoModelForAudioXVector",
"automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"),
"conversational": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
"feature-extraction": "AutoModel",
"fill-mask": "AutoModelForMaskedLM",
"text-generation": "AutoModelForCausalLM",
"text2text-generation": "AutoModelForSeq2SeqLM",
"text-classification": "AutoModelForSequenceClassification",
"token-classification": "AutoModelForTokenClassification",
"multiple-choice": "AutoModelForMultipleChoice",
"object-detection": "AutoModelForObjectDetection",
"question-answering": "AutoModelForQuestionAnswering",
"image-classification": "AutoModelForImageClassification",
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"image-to-text": "AutoModelForVision2Seq",
"mask-generation": "AutoModel",
"masked-im": "AutoModelForMaskedImageModeling",
"multiple-choice": "AutoModelForMultipleChoice",
"object-detection": "AutoModelForObjectDetection",
"question-answering": "AutoModelForQuestionAnswering",
"semantic-segmentation": "AutoModelForSemanticSegmentation",
"automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"),
"audio-classification": "AutoModelForAudioClassification",
"audio-frame-classification": "AutoModelForAudioFrameClassification",
"audio-xvector": "AutoModelForAudioXVector",
"image-to-text": "AutoModelForVision2Seq",
"text-to-speech": "AutoModelForTextToSpectrogram",
"text-generation": "AutoModelForCausalLM",
"text2text-generation": "AutoModelForSeq2SeqLM",
"text-classification": "AutoModelForSequenceClassification",
"token-classification": "AutoModelForTokenClassification",
"zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
"zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
}
Expand Down Expand Up @@ -264,6 +265,8 @@ class TasksManager:
("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
# VisionEncoderDecoderModel is not registered in AutoModelForDocumentQuestionAnswering
("pt", "vision-encoder-decoder", "document-question-answering"): ("transformers", "VisionEncoderDecoderModel"),
# audio-to-audio task has no AutoModel class.
("pt", "speecht5", "audio-to-audio"): ("transformers", "SpeechT5ForSpeechToSpeech"),
}

# TODO: why feature-extraction-with-past is here?
Expand Down Expand Up @@ -838,6 +841,15 @@ class TasksManager:
"automatic-speech-recognition-with-past",
onnx="Speech2TextOnnxConfig",
),
"speecht5": supported_tasks_mapping(
"audio-to-audio",
"audio-to-audio-with-past",
"automatic-speech-recognition",
"automatic-speech-recognition-with-past",
"text-to-speech",
"text-to-speech-with-past",
onnx="SpeechT5OnnxConfig",
),
"splinter": supported_tasks_mapping(
"feature-extraction",
"question-answering",
Expand Down

0 comments on commit 2dd5209

Please sign in to comment.