Commit

Merge commit 'd96be19c75457375b719c9baf097ee1fd40170eb' into fix-pypi-release
peterdudfield committed May 9, 2024
2 parents 0944d44 + d96be19 commit 4d11849
Showing 7 changed files with 550 additions and 299 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,7 +1,7 @@
 [bumpversion]
 commit = True
 tag = True
-current_version = 3.0.31
+current_version = 3.0.32
 message = Bump version: {current_version} → {new_version} [skip ci]
 
 [bumpversion:file:pvnet/__init__.py]
2 changes: 1 addition & 1 deletion pvnet/__init__.py
@@ -1,2 +1,2 @@
 """PVNet"""
-__version__ = "3.0.31"
+__version__ = "3.0.32"
70 changes: 70 additions & 0 deletions pvnet/load_model.py
@@ -0,0 +1,70 @@
+""" Load a model from its checkpoint directory """
+import glob
+import os
+
+import hydra
+import torch
+from pyaml_env import parse_config
+
+from pvnet.models.ensemble import Ensemble
+from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel
+
+
+def get_model_from_checkpoints(
+    checkpoint_dir_paths: list[str],
+    val_best: bool = True,
+):
+    """Load a model from its checkpoint directory"""
+    is_ensemble = len(checkpoint_dir_paths) > 1
+
+    model_configs = []
+    models = []
+    data_configs = []
+
+    for path in checkpoint_dir_paths:
+        # Load the model
+        model_config = parse_config(f"{path}/model_config.yaml")
+
+        model = hydra.utils.instantiate(model_config)
+
+        if val_best:
+            # Only one epoch (best) saved per model
+            files = glob.glob(f"{path}/epoch*.ckpt")
+            if len(files) != 1:
+                raise ValueError(
+                    f"Found {len(files)} checkpoints @ {path}/epoch*.ckpt. Expected one."
+                )
+            checkpoint = torch.load(files[0], map_location="cpu")
+        else:
+            checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu")
+
+        model.load_state_dict(state_dict=checkpoint["state_dict"])
+
+        if isinstance(model, UMTModel):
+            model, model_config = model.convert_to_multimodal_model(model_config)
+
+        # Check for data config
+        data_config = f"{path}/data_config.yaml"
+
+        if os.path.isfile(data_config):
+            data_configs.append(data_config)
+        else:
+            data_configs.append(None)
+
+        model_configs.append(model_config)
+        models.append(model)
+
+    if is_ensemble:
+        model_config = {
+            "_target_": "pvnet.models.ensemble.Ensemble",
+            "model_list": model_configs,
+        }
+        model = Ensemble(model_list=models)
+        data_config = data_configs[0]
+
+    else:
+        model_config = model_configs[0]
+        model = models[0]
+        data_config = data_configs[0]
+
+    return model, model_config, data_config
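
For context, a minimal usage sketch of the new loader (not part of the commit; the checkpoint directory paths are hypothetical). Per the code above, each directory is expected to contain model_config.yaml, a single epoch*.ckpt (or last.ckpt when val_best=False), and optionally data_config.yaml:

from pvnet.load_model import get_model_from_checkpoints

# Hypothetical checkpoint directories from two training runs
paths = ["checkpoints/run_a", "checkpoints/run_b"]

# With more than one directory, the loaded models are wrapped in an
# Ensemble and the first run's data config is used for the combined model.
model, model_config, data_config = get_model_from_checkpoints(paths, val_best=True)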
29 changes: 3 additions & 26 deletions pvnet/utils.py
@@ -1,6 +1,5 @@
 """Utils"""
 import logging
-import os
 import warnings
 from collections.abc import Sequence
 from typing import Optional
@@ -13,34 +12,13 @@
 import rich.syntax
 import rich.tree
 import xarray as xr
-import yaml
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.utilities import rank_zero_only
 from ocf_datapipes.batch import BatchKey
 from ocf_datapipes.utils import Location
 from ocf_datapipes.utils.geospatial import osgb_to_lon_lat
 from omegaconf import DictConfig, OmegaConf
 
-import pvnet
-
-
-def load_config(config_file):
-    """
-    Open yaml configuration file, and get rid of the '_target_' line
-    """
-
-    # get full path of config file
-    path = os.path.dirname(pvnet.__file__)
-    config_file = f"{path}/../{config_file}"
-
-    with open(config_file) as cfg:
-        config = yaml.load(cfg, Loader=yaml.FullLoader)
-
-    if "_target_" in config.keys():
-        config.pop("_target_")  # This is only for Hydra
-
-    return config
-
 
 def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
     """Initializes multi-GPU-friendly python logger."""
@@ -236,11 +214,10 @@ def finish(
     """Makes sure everything closed properly."""
 
     # without this sweeps with wandb logger might crash!
-    for logger in loggers:
-        if isinstance(logger, pl.loggers.wandb.WandbLogger):
-            import wandb
+    if any([isinstance(logger, pl.loggers.wandb.WandbLogger) for logger in loggers]):
+        import wandb
 
-            wandb.finish()
+        wandb.finish()
 
 
 def plot_batch_forecasts(
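
As a side note on the finish() change above, a small sketch of the new behaviour (not part of the commit; logger names are illustrative): wandb.finish() is now called once if any attached logger is a WandbLogger, rather than once per matching logger as in the old loop.

import lightning.pytorch as pl

loggers = [
    pl.loggers.wandb.WandbLogger(name="run-a"),
    pl.loggers.wandb.WandbLogger(name="run-b"),
]

# A single wandb.finish() call, no matter how many WandbLoggers are attached
if any(isinstance(logger, pl.loggers.wandb.WandbLogger) for logger in loggers):
    import wandb

    wandb.finish()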
