diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 081e9bc5..9faff47d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -13,7 +13,7 @@ import torch.nn.functional as F import wandb import yaml -from huggingface_hub import ModelCard, ModelCardData +from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import HfApi @@ -140,7 +140,7 @@ def minimize_data_config(input_path, output_path, model): yaml.dump(config, outfile, default_flow_style=False) -class PVNetModelHubMixin: +class PVNetModelHubMixin(PyTorchModelHubMixin): """ Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities. """