fix custom models
fxmarty committed Oct 5, 2023
1 parent 74ba08c commit c5a8a1d
Showing 2 changed files with 26 additions and 12 deletions.
26 changes: 25 additions & 1 deletion optimum/exporters/onnx/__main__.py
@@ -19,7 +19,7 @@
 from pathlib import Path
 
 from requests.exceptions import ConnectionError as RequestsConnectionError
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer
 from transformers.utils import is_torch_available
 
 from ...commands.export.onnx import parse_args_onnx
@@ -338,6 +338,30 @@ def main_export(
                 f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
             )
 
+    if library_name == "transformers":
+        config = AutoConfig.from_pretrained(
+            model_name_or_path,
+            subfolder=subfolder,
+            revision=revision,
+            cache_dir=cache_dir,
+            use_auth_token=use_auth_token,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            trust_remote_code=trust_remote_code,
+        )
+        model_type = config.model_type.replace("_", "-")
+        if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
+            custom_architecture = True
+        elif task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
+            if original_task == "auto":
+                autodetected_message = " (auto-detected)"
+            else:
+                autodetected_message = ""
+            model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
+            raise ValueError(
+                f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
+            )
+
     model = TasksManager.get_model_from_task(
         task,
         model_name_or_path,
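
For context, a minimal sketch of how the check added above surfaces to a caller of main_export; the checkpoint and task are hypothetical, chosen only to hit the new error path:

from optimum.exporters.onnx import main_export

# The (model_type, task) pair is now validated right after the config is loaded,
# so an unsupported task fails before TasksManager.get_model_from_task loads any weights.
main_export(
    "hf-internal-testing/tiny-random-BertModel",  # hypothetical checkpoint
    output="bert_onnx",
    task="automatic-speech-recognition",  # not an ONNX-exportable task for bert -> ValueError
)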
12 changes: 1 addition & 11 deletions optimum/exporters/tasks.py
@@ -1058,7 +1058,7 @@ def get_supported_tasks_for_model_type(
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
             raise KeyError(
                 f"{model_type_and_model_name} is not supported yet. "
-                f"Only {TasksManager._SUPPORTED_MODEL_TYPE} are supported. "
+                f"Only {list(TasksManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
                 f"If you want to support {model_type} please propose a PR or open up an issue."
             )
         elif exporter not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]:
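
The one-line change above only affects the error text: the KeyError now lists the supported model type names instead of dumping the entire _SUPPORTED_MODEL_TYPE mapping. A rough illustration, with a hypothetical unsupported model type:

from optimum.exporters.tasks import TasksManager

try:
    # "my-custom-arch" is a made-up model type that is not in _SUPPORTED_MODEL_TYPE
    TasksManager.get_supported_tasks_for_model_type("my-custom-arch", exporter="onnx")
except KeyError as e:
    # The message now enumerates supported model type names only
    print(e)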
@@ -1687,16 +1687,6 @@ def get_model_from_task(
         if original_task == "auto" and config.architectures is not None:
             model_class_name = config.architectures[0]
 
-        if task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
-            if original_task == "auto":
-                autodetected_message = " (auto-detected)"
-            else:
-                autodetected_message = ""
-            model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
-            raise ValueError(
-                f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
-            )
-
         model_class = TasksManager.get_model_class_for_task(
             task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
         )
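
With the block above removed, the task-support check lives only in main_export. A minimal sketch of the lookup it relies on (the model type below is hypothetical, standing in for config.model_type):

from optimum.exporters.tasks import TasksManager

# Mirror of the relocated check: normalize the model_type the way main_export does,
# then ask the TasksManager which tasks the ONNX exporter supports for it.
model_type = "gpt_bigcode".replace("_", "-")  # hypothetical model_type from an AutoConfig
if model_type in TasksManager._SUPPORTED_MODEL_TYPE:
    onnx_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
    print(sorted(onnx_tasks.keys()))
else:
    # Unknown model types are now flagged as custom architectures in main_export
    # instead of raising here.
    print(f"{model_type} would be exported as a custom architecture")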
