Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Oct 5, 2023
1 parent 918893e commit 74ba08c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
13 changes: 2 additions & 11 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,20 +366,11 @@ def main_export(
if not is_stable_diffusion:
if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE:
raise ValueError(
f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. "
f"{model_type} is not supported yet. Only {list(TasksManager._SUPPORTED_CLI_MODEL_TYPE.keys())} are supported. "
f"If you want to support {model_type} please propose a PR or open up an issue."
)
if model.config.model_type.replace("-", "_") not in TasksManager._SUPPORTED_MODEL_TYPE:
if model.config.model_type.replace("_", "-") not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True
elif task not in TasksManager.get_supported_tasks_for_model_type(model.config.model_type, "onnx"):
if original_task == "auto":
autodetected_message = " (auto-detected)"
else:
autodetected_message = ""
model_tasks = TasksManager.get_supported_tasks_for_model_type(model.config.model_type, exporter="onnx")
raise ValueError(
f"Asked to export a {model.config.model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model.config.model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model.config.model_type}."
)

# TODO: support onnx_config.py in the model repo
if custom_architecture and custom_onnx_configs is None:
Expand Down
4 changes: 2 additions & 2 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,8 +903,8 @@ def post_process_exported_models(
path, models_and_onnx_configs, onnx_files_subpaths
)

# Attempt to merge only if the decoder was exported without/with past
if self.use_past is True or self.variant == "with-past":
# Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task
if len(onnx_files_subpaths) >= 3 and self.use_past is True or self.variant == "with-past":
decoder_path = Path(path, onnx_files_subpaths[1])
decoder_with_past_path = Path(path, onnx_files_subpaths[2])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
Expand Down
12 changes: 11 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,7 @@ def get_supported_tasks_for_model_type(
`TaskNameToExportConfigDict`: The dictionary mapping each task to a corresponding `ExportConfig`
constructor.
"""
model_type = model_type.lower()
model_type = model_type.lower().replace("_", "-")
model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
raise KeyError(
Expand Down Expand Up @@ -1687,6 +1687,16 @@ def get_model_from_task(
if original_task == "auto" and config.architectures is not None:
model_class_name = config.architectures[0]

if task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
if original_task == "auto":
autodetected_message = " (auto-detected)"
else:
autodetected_message = ""
model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
raise ValueError(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)

model_class = TasksManager.get_model_class_for_task(
task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
)
Expand Down

0 comments on commit 74ba08c

Please sign in to comment.