Add first draft backtest script (#175)
* first draft backtest script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactoring and tidying

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* re-add function

* minor fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update backtest_uk_gsp.py

* add model loading functionality

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docs

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: peterdudfield <[email protected]>
3 people authored May 9, 2024
1 parent 7db4e74 commit 860e5e0
Showing 5 changed files with 548 additions and 297 deletions.
70 changes: 70 additions & 0 deletions pvnet/load_model.py
@@ -0,0 +1,70 @@
""" Load a model from its checkpoint directory """
import glob
import os

import hydra
import torch
from pyaml_env import parse_config

from pvnet.models.ensemble import Ensemble
from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel


def get_model_from_checkpoints(
checkpoint_dir_paths: list[str],
val_best: bool = True,
):
"""Load a model from its checkpoint directory"""
is_ensemble = len(checkpoint_dir_paths) > 1

model_configs = []
models = []
data_configs = []

for path in checkpoint_dir_paths:
# Load the model
model_config = parse_config(f"{path}/model_config.yaml")

model = hydra.utils.instantiate(model_config)

if val_best:
# Only one epoch (best) saved per model
files = glob.glob(f"{path}/epoch*.ckpt")
if len(files) != 1:
raise ValueError(
f"Found {len(files)} checkpoints @ {path}/epoch*.ckpt. Expected one."
)
checkpoint = torch.load(files[0], map_location="cpu")
else:
checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu")

model.load_state_dict(state_dict=checkpoint["state_dict"])

if isinstance(model, UMTModel):
model, model_config = model.convert_to_multimodal_model(model_config)

# Check for data config
data_config = f"{path}/data_config.yaml"

if os.path.isfile(data_config):
data_configs.append(data_config)
else:
data_configs.append(None)

model_configs.append(model_config)
models.append(model)

if is_ensemble:
model_config = {
"_target_": "pvnet.models.ensemble.Ensemble",
"model_list": model_configs,
}
model = Ensemble(model_list=models)
data_config = data_configs[0]

else:
model_config = model_configs[0]
model = models[0]
data_config = data_configs[0]

return model, model_config, data_config
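As a usage sketch (the checkpoint directory paths below are hypothetical, not part of this commit), the new helper loads either a single model or an ensemble depending on how many directories are passed; each directory is expected to contain model_config.yaml, an epoch*.ckpt or last.ckpt file, and optionally data_config.yaml:

from pvnet.load_model import get_model_from_checkpoints

# Single model: one checkpoint directory, best-validation epoch loaded by default.
model, model_config, data_config = get_model_from_checkpoints(
    checkpoint_dir_paths=["checkpoints/pvnet_run"],  # hypothetical path
)

# Ensemble: several directories are wrapped in pvnet.models.ensemble.Ensemble,
# and the first directory's data_config.yaml is returned for the ensemble.
model, model_config, data_config = get_model_from_checkpoints(
    checkpoint_dir_paths=["checkpoints/run_a", "checkpoints/run_b"],  # hypothetical paths
    val_best=False,  # load last.ckpt instead of the best epoch*.ckpt
)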
29 changes: 3 additions & 26 deletions pvnet/utils.py
@@ -1,6 +1,5 @@
 """Utils"""
 import logging
-import os
 import warnings
 from collections.abc import Sequence
 from typing import Optional
@@ -13,34 +12,13 @@
 import rich.syntax
 import rich.tree
 import xarray as xr
-import yaml
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.utilities import rank_zero_only
 from ocf_datapipes.batch import BatchKey
 from ocf_datapipes.utils import Location
 from ocf_datapipes.utils.geospatial import osgb_to_lon_lat
 from omegaconf import DictConfig, OmegaConf
-
-import pvnet
-
-
-def load_config(config_file):
-    """
-    Open yam configruation file, and get rid eof '_target_' line
-    """
-
-    # get full path of config file
-    path = os.path.dirname(pvnet.__file__)
-    config_file = f"{path}/../{config_file}"
-
-    with open(config_file) as cfg:
-        config = yaml.load(cfg, Loader=yaml.FullLoader)
-
-    if "_target_" in config.keys():
-        config.pop("_target_")  # This is only for Hydra
-
-    return config
 
 
 def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
     """Initializes multi-GPU-friendly python logger."""
@@ -236,11 +214,10 @@ def finish(
     """Makes sure everything closed properly."""
 
     # without this sweeps with wandb logger might crash!
-    for logger in loggers:
-        if isinstance(logger, pl.loggers.wandb.WandbLogger):
-            import wandb
+    if any([isinstance(logger, pl.loggers.wandb.WandbLogger) for logger in loggers]):
+        import wandb
 
-            wandb.finish()
+        wandb.finish()
 
 
 def plot_batch_forecasts(
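For context, a minimal sketch of the refactored shutdown logic in finish (the logger list below is hypothetical): wandb.finish() is now called at most once when any configured logger is a WandbLogger, rather than once per matching logger inside the old loop.

import lightning.pytorch as pl

# Hypothetical logger list for illustration only.
loggers = [
    pl.loggers.CSVLogger(save_dir="logs"),
    pl.loggers.wandb.WandbLogger(project="pvnet", offline=True),
]

# Mirrors the new check in pvnet.utils.finish: a single wandb.finish() call
# if any wandb logger is present.
if any([isinstance(logger, pl.loggers.wandb.WandbLogger) for logger in loggers]):
    import wandb

    wandb.finish()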