diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index b4a23d3d..bf042c6e 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -13,11 +13,10 @@ import torch.nn.functional as F import wandb import yaml -from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin +from huggingface_hub import ModelCard, ModelCardData from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import HfApi -from huggingface_hub.utils._deprecation import _deprecate_positional_args from ocf_datapipes.batch import BatchKey from ocf_ml_metrics.evaluation.evaluation import evaluation @@ -141,33 +140,34 @@ def minimize_data_config(input_path, output_path, model): yaml.dump(config, outfile, default_flow_style=False) -class PVNetModelHubMixin(PyTorchModelHubMixin): +class PVNetModelHubMixin: """ Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities. """ @classmethod - @_deprecate_positional_args(version="0.16") - def _from_pretrained( + def from_pretrained( cls, *, model_id: str, revision: str, - cache_dir: str, - force_download: bool, - proxies: Optional[Dict], - resume_download: bool, - local_files_only: bool, - token: Union[str, bool, None], + cache_dir: Optional[Union[str, Path]] = None, + force_download: bool = False, + proxies: Optional[Dict] = None, + resume_download: Optional[bool] = None, + local_files_only: bool = False, + token: Union[str, bool, None] = None, map_location: str = "cpu", strict: bool = False, - **model_kwargs, ): """Load Pytorch pretrained weights and return the loaded model.""" + if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) + config_file = os.path.join(model_id, CONFIG_NAME) else: + # load model file model_file = hf_hub_download( repo_id=model_id, filename=PYTORCH_WEIGHTS_NAME, @@ -180,11 +180,23 @@ def _from_pretrained( local_files_only=local_files_only, ) - if "config" not in model_kwargs: - raise ValueError("Config must be supplied to instantiate model") + # load config file + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) - model_kwargs.update(model_kwargs.pop("config")) - model = hydra.utils.instantiate(model_kwargs) + model = hydra.utils.instantiate(config) state_dict = torch.load(model_file, map_location=torch.device(map_location)) model.load_state_dict(state_dict, strict=strict) # type: ignore diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index 7c8bcdba..b06086e4 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -149,7 +149,8 @@ def __init__(self, decay_rate: Optional[int] = None, forecast_length: int = 6): self.decay_rate = math.log(2) # make weights from decay rate - weights = torch.from_numpy(np.exp(-self.decay_rate * np.arange(self.forecast_length))) + weights = np.exp(-self.decay_rate * np.arange(self.forecast_length)) + weights = torch.tensor(weights) # normalized the weights, so there mean is 1. # To calculate the loss, we times the weights by the differences between truth diff --git a/tests/models/multimodal/test_from_pretrained.py b/tests/models/multimodal/test_from_pretrained.py new file mode 100644 index 00000000..65223957 --- /dev/null +++ b/tests/models/multimodal/test_from_pretrained.py @@ -0,0 +1,11 @@ +from pvnet.models.base_model import BaseModel + + +def test_from_pretrained(): + model_name = "openclimatefix/pvnet_uk_region" + model_version = "92266cd9040c590a9e90ee33eafd0e7b92548be8" + + _ = BaseModel.from_pretrained( + model_id=model_name, + revision=model_version, + )