Skip to content

Commit

Permalink
add test from loading from pre-trained (#238)
Browse files Browse the repository at this point in the history
* add test from loading from pre-trained

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* remove config in model kwargs - optional

* use more recent models

* remove huggingface mixing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
peterdudfield and pre-commit-ci[bot] authored Jul 25, 2024
1 parent 4473d33 commit e3427e5
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 17 deletions.
44 changes: 28 additions & 16 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
import torch.nn.functional as F
import wandb
import yaml
from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
from huggingface_hub import ModelCard, ModelCardData
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils._deprecation import _deprecate_positional_args
from ocf_datapipes.batch import BatchKey
from ocf_ml_metrics.evaluation.evaluation import evaluation

Expand Down Expand Up @@ -141,33 +140,34 @@ def minimize_data_config(input_path, output_path, model):
yaml.dump(config, outfile, default_flow_style=False)


class PVNetModelHubMixin(PyTorchModelHubMixin):
class PVNetModelHubMixin:
"""
Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities.
"""

@classmethod
@_deprecate_positional_args(version="0.16")
def _from_pretrained(
def from_pretrained(
cls,
*,
model_id: str,
revision: str,
cache_dir: str,
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Union[str, bool, None],
cache_dir: Optional[Union[str, Path]] = None,
force_download: bool = False,
proxies: Optional[Dict] = None,
resume_download: Optional[bool] = None,
local_files_only: bool = False,
token: Union[str, bool, None] = None,
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
):
"""Load Pytorch pretrained weights and return the loaded model."""

if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
config_file = os.path.join(model_id, CONFIG_NAME)
else:
# load model file
model_file = hf_hub_download(
repo_id=model_id,
filename=PYTORCH_WEIGHTS_NAME,
Expand All @@ -180,11 +180,23 @@ def _from_pretrained(
local_files_only=local_files_only,
)

if "config" not in model_kwargs:
raise ValueError("Config must be supplied to instantiate model")
# load config file
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)

with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)

model_kwargs.update(model_kwargs.pop("config"))
model = hydra.utils.instantiate(model_kwargs)
model = hydra.utils.instantiate(config)

state_dict = torch.load(model_file, map_location=torch.device(map_location))
model.load_state_dict(state_dict, strict=strict) # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion pvnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def __init__(self, decay_rate: Optional[int] = None, forecast_length: int = 6):
self.decay_rate = math.log(2)

# make weights from decay rate
weights = torch.from_numpy(np.exp(-self.decay_rate * np.arange(self.forecast_length)))
weights = np.exp(-self.decay_rate * np.arange(self.forecast_length))
weights = torch.tensor(weights)

# normalized the weights, so there mean is 1.
# To calculate the loss, we times the weights by the differences between truth
Expand Down
11 changes: 11 additions & 0 deletions tests/models/multimodal/test_from_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from pvnet.models.base_model import BaseModel


def test_from_pretrained():
model_name = "openclimatefix/pvnet_uk_region"
model_version = "92266cd9040c590a9e90ee33eafd0e7b92548be8"

_ = BaseModel.from_pretrained(
model_id=model_name,
revision=model_version,
)

0 comments on commit e3427e5

Please sign in to comment.