From 85e8ee94326614b653a44d86c993c4dd8dc34ec9 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 24 Jul 2024 17:02:31 +0100 Subject: [PATCH] remove huggingface mixing --- pvnet/models/base_model.py | 42 ++++++++++++------- .../models/multimodal/test_from_pretrained.py | 4 +- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index df5c2d80..dd14fe73 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -141,33 +141,35 @@ def minimize_data_config(input_path, output_path, model): yaml.dump(config, outfile, default_flow_style=False) -class PVNetModelHubMixin(PyTorchModelHubMixin): +class PVNetModelHubMixin: """ Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities. """ @classmethod - @_deprecate_positional_args(version="0.16") - def _from_pretrained( + def from_pretrained( cls, *, model_id: str, revision: str, - cache_dir: str, - force_download: bool, - proxies: Optional[Dict], - resume_download: bool, - local_files_only: bool, - token: Union[str, bool, None], + cache_dir: Optional[Union[str, Path]] = None, + force_download: bool = False, + proxies: Optional[Dict] = None, + resume_download: Optional[bool] = None, + local_files_only: bool = False, + token: Union[str, bool, None] = None, map_location: str = "cpu", strict: bool = False, - **model_kwargs, ): """Load Pytorch pretrained weights and return the loaded model.""" + if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) + config_file = os.path.join(model_id, CONFIG_NAME) else: + + # load model file model_file = hf_hub_download( repo_id=model_id, filename=PYTORCH_WEIGHTS_NAME, @@ -180,11 +182,23 @@ def _from_pretrained( local_files_only=local_files_only, ) - if "config" in model_kwargs: - logger.debug("Removing config from model_kwargs to avoid conflicts with model init.") - model_kwargs.update(model_kwargs.pop("config")) + # load config file + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) - model = hydra.utils.instantiate(model_kwargs) + model = hydra.utils.instantiate(config) state_dict = torch.load(model_file, map_location=torch.device(map_location)) model.load_state_dict(state_dict, strict=strict) # type: ignore diff --git a/tests/models/multimodal/test_from_pretrained.py b/tests/models/multimodal/test_from_pretrained.py index 7c264ff1..65223957 100644 --- a/tests/models/multimodal/test_from_pretrained.py +++ b/tests/models/multimodal/test_from_pretrained.py @@ -3,9 +3,9 @@ def test_from_pretrained(): model_name = "openclimatefix/pvnet_uk_region" - model_version = "aa73cdafd1db8df3c8b7f5ecfdb160989e7639ac" + model_version = "92266cd9040c590a9e90ee33eafd0e7b92548be8" _ = BaseModel.from_pretrained( - model_name, + model_id=model_name, revision=model_version, )