From c5a8a1d89a336bcf41ece1c89598d9f6d4f09aaf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 5 Oct 2023 15:52:38 +0200
Subject: [PATCH] fix custom models

---
 optimum/exporters/onnx/__main__.py | 26 +++++++++++++++++++++++++-
 optimum/exporters/tasks.py         | 12 +-----------
 2 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 3dbee581062..fcafd3fd8d9 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -19,7 +19,7 @@
 from pathlib import Path
 
 from requests.exceptions import ConnectionError as RequestsConnectionError
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer
 from transformers.utils import is_torch_available
 
 from ...commands.export.onnx import parse_args_onnx
@@ -338,6 +338,30 @@ def main_export(
             f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
         )
 
+    if library_name == "transformers":
+        config = AutoConfig.from_pretrained(
+            model_name_or_path,
+            subfolder=subfolder,
+            revision=revision,
+            cache_dir=cache_dir,
+            use_auth_token=use_auth_token,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            trust_remote_code=trust_remote_code,
+        )
+        model_type = config.model_type.replace("_", "-")
+        if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
+            custom_architecture = True
+        elif task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
+            if original_task == "auto":
+                autodetected_message = " (auto-detected)"
+            else:
+                autodetected_message = ""
+            model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
+            raise ValueError(
+                f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
+            )
+
     model = TasksManager.get_model_from_task(
         task,
         model_name_or_path,
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 8cf8cae4863..baf163b1691 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -1058,7 +1058,7 @@ def get_supported_tasks_for_model_type(
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
             raise KeyError(
                 f"{model_type_and_model_name} is not supported yet. "
-                f"Only {TasksManager._SUPPORTED_MODEL_TYPE} are supported. "
+                f"Only {list(TasksManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
                 f"If you want to support {model_type} please propose a PR or open up an issue."
             )
         elif exporter not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]:
@@ -1687,16 +1687,6 @@ def get_model_from_task(
         if original_task == "auto" and config.architectures is not None:
             model_class_name = config.architectures[0]
 
-        if task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
-            if original_task == "auto":
-                autodetected_message = " (auto-detected)"
-            else:
-                autodetected_message = ""
-            model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
-            raise ValueError(
-                f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
-            )
-
         model_class = TasksManager.get_model_class_for_task(
             task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
         )
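
Reviewer note (not part of the patch itself): a minimal sketch of the two behaviors this change targets, assuming an optimum checkout with the patch applied and the main_export signature from this branch. The repo id "some-user/custom-arch-model" below is a hypothetical placeholder, not a real Hub repository.

    from optimum.exporters.onnx import main_export

    # A known model type asked for an unsupported task now fails fast in
    # main_export, before any model weights are downloaded or loaded:
    try:
        main_export("gpt2", output="gpt2_onnx", task="image-classification")
    except ValueError as e:
        print(e)  # "Asked to export a gpt2 model for the task image-classification ..."

    # A model whose config.model_type is absent from
    # TasksManager._SUPPORTED_MODEL_TYPE is flagged as a custom architecture
    # up front, instead of tripping the check removed from get_model_from_task:
    main_export(
        "some-user/custom-arch-model",  # hypothetical trust_remote_code repo
        output="custom_onnx",
        task="text-generation",
        trust_remote_code=True,
    )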