From 1b5a63da593599b1e6e178754146e0109d3305d9 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:52:35 +0200 Subject: [PATCH] Remove the need for the config to be in the subfolder (#2044) * remove the need for the config to be in the subfolder * fix * fix for offline mode * add log * fix * enable load local model in subfolder * fix windows --- optimum/modeling_base.py | 36 ++++++++++++++++++----------- optimum/onnxruntime/modeling_ort.py | 6 ++--- tests/onnxruntime/test_modeling.py | 15 ++++++++++++ 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py index 29521b7c0c6..48c738514ae 100644 --- a/optimum/modeling_base.py +++ b/optimum/modeling_base.py @@ -380,27 +380,35 @@ def from_pretrained( ) model_id, revision = model_id.split("@") + all_files, _ = TasksManager.get_model_files( + model_id, + subfolder=subfolder, + cache_dir=cache_dir, + revision=revision, + token=token, + ) + + config_folder = subfolder + if cls.config_name not in all_files: + logger.info( + f"{cls.config_name} not found in the specified subfolder {subfolder}. Using the top level {cls.config_name}." + ) + config_folder = "" + library_name = TasksManager.infer_library_from_model( - model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token ) if library_name == "timm": config = PretrainedConfig.from_pretrained( - model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token ) if config is None: - if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME: - if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)): - config = AutoConfig.from_pretrained( - os.path.join(model_id, subfolder), trust_remote_code=trust_remote_code - ) - elif CONFIG_NAME in os.listdir(model_id): + if os.path.isdir(os.path.join(model_id, config_folder)) and cls.config_name == CONFIG_NAME: + if CONFIG_NAME in os.listdir(os.path.join(model_id, config_folder)): config = AutoConfig.from_pretrained( - os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code - ) - logger.info( - f"config.json not found in the specified subfolder {subfolder}. Using the top level config.json." + os.path.join(model_id, config_folder), trust_remote_code=trust_remote_code ) else: raise OSError(f"config.json not found in {model_id} local folder") @@ -411,7 +419,7 @@ def from_pretrained( cache_dir=cache_dir, token=token, force_download=force_download, - subfolder=subfolder, + subfolder=config_folder, trust_remote_code=trust_remote_code, ) elif isinstance(config, (str, os.PathLike)): @@ -421,7 +429,7 @@ def from_pretrained( cache_dir=cache_dir, token=token, force_download=force_download, - subfolder=subfolder, + subfolder=config_folder, trust_remote_code=trust_remote_code, ) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 9b29afa566b..ce1d68536ac 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -510,13 +510,12 @@ def _from_pretrained( if file_name is None: if model_path.is_dir(): - onnx_files = list(model_path.glob("*.onnx")) + onnx_files = list((model_path / subfolder).glob("*.onnx")) else: repo_files, _ = TasksManager.get_model_files( model_id, revision=revision, cache_dir=cache_dir, token=token ) repo_files = map(Path, repo_files) - pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx" onnx_files = [p for p in repo_files if p.match(pattern)] @@ -983,10 +982,9 @@ def _cached_file( token = use_auth_token model_path = Path(model_path) - # locates a file in a local folder and repo, downloads and cache it if necessary. if model_path.is_dir(): - model_cache_path = model_path / file_name + model_cache_path = model_path / subfolder / file_name preprocessors = maybe_load_preprocessors(model_path.as_posix()) else: model_cache_path = hf_hub_download( diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 665f253c480..501c7dac240 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -28,6 +28,7 @@ import requests import timm import torch +from huggingface_hub import HfApi from huggingface_hub.constants import default_cache_path from parameterized import parameterized from PIL import Image @@ -1263,6 +1264,20 @@ def test_trust_remote_code(self): torch.allclose(pt_logits, ort_logits, atol=1e-4), f" Maxdiff: {torch.abs(pt_logits - ort_logits).max()}" ) + @parameterized.expand(("", "onnx")) + def test_loading_with_config_not_from_subfolder(self, subfolder): + # config.json file in the root directory and not in the subfolder + model_id = "sentence-transformers-testing/stsb-bert-tiny-onnx" + # hub model + ORTModelForFeatureExtraction.from_pretrained(model_id, subfolder=subfolder, export=subfolder == "") + # local model + api = HfApi() + with tempfile.TemporaryDirectory() as tmpdirname: + local_dir = Path(tmpdirname) / "model" + api.snapshot_download(repo_id=model_id, local_dir=local_dir) + ORTModelForFeatureExtraction.from_pretrained(local_dir, subfolder=subfolder, export=subfolder == "") + remove_directory(tmpdirname) + class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin): SUPPORTED_ARCHITECTURES = [