From 860e5e08d8955241b493f2a4c3095dce05fa7167 Mon Sep 17 00:00:00 2001 From: James Fulton <41546094+dfulu@users.noreply.github.com> Date: Thu, 9 May 2024 15:49:15 +0100 Subject: [PATCH] Add first draft backtest script (#175) * first draft backtest script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactoring and tidying * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * re-add function * minor fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update backtest_uk_gsp.py * add model loading functionality * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docs --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: peterdudfield --- pvnet/load_model.py | 70 ++++ pvnet/utils.py | 29 +- scripts/backtest_uk_gsp.py | 467 +++++++++++++++++++++++++++ scripts/checkpoint_to_huggingface.py | 61 +--- scripts/hindcast.py | 218 ------------- 5 files changed, 548 insertions(+), 297 deletions(-) create mode 100644 pvnet/load_model.py create mode 100644 scripts/backtest_uk_gsp.py delete mode 100644 scripts/hindcast.py diff --git a/pvnet/load_model.py b/pvnet/load_model.py new file mode 100644 index 00000000..ba6de363 --- /dev/null +++ b/pvnet/load_model.py @@ -0,0 +1,70 @@ +""" Load a model from its checkpoint directory """ +import glob +import os + +import hydra +import torch +from pyaml_env import parse_config + +from pvnet.models.ensemble import Ensemble +from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel + + +def get_model_from_checkpoints( + checkpoint_dir_paths: list[str], + val_best: bool = True, +): + """Load a model from its checkpoint directory""" + is_ensemble = len(checkpoint_dir_paths) > 1 + + model_configs = [] + models = [] + data_configs = [] + + for path in checkpoint_dir_paths: + # Load the model + model_config = parse_config(f"{path}/model_config.yaml") + + model = hydra.utils.instantiate(model_config) + + if val_best: + # Only one epoch (best) saved per model + files = glob.glob(f"{path}/epoch*.ckpt") + if len(files) != 1: + raise ValueError( + f"Found {len(files)} checkpoints @ {path}/epoch*.ckpt. Expected one." + ) + checkpoint = torch.load(files[0], map_location="cpu") + else: + checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu") + + model.load_state_dict(state_dict=checkpoint["state_dict"]) + + if isinstance(model, UMTModel): + model, model_config = model.convert_to_multimodal_model(model_config) + + # Check for data config + data_config = f"{path}/data_config.yaml" + + if os.path.isfile(data_config): + data_configs.append(data_config) + else: + data_configs.append(None) + + model_configs.append(model_config) + models.append(model) + + if is_ensemble: + model_config = { + "_target_": "pvnet.models.ensemble.Ensemble", + "model_list": model_configs, + } + model = Ensemble(model_list=models) + data_config = data_configs[0] + + else: + model_config = model_configs[0] + model = models[0] + data_config = data_configs[0] + + return model, model_config, data_config diff --git a/pvnet/utils.py b/pvnet/utils.py index ccfa6826..d9a5ba3d 100644 --- a/pvnet/utils.py +++ b/pvnet/utils.py @@ -1,6 +1,5 @@ """Utils""" import logging -import os import warnings from collections.abc import Sequence from typing import Optional @@ -13,7 +12,6 @@ import rich.syntax import rich.tree import xarray as xr -import yaml from lightning.pytorch.loggers import Logger from lightning.pytorch.utilities import rank_zero_only from ocf_datapipes.batch import BatchKey @@ -21,26 +19,6 @@ from ocf_datapipes.utils.geospatial import osgb_to_lon_lat from omegaconf import DictConfig, OmegaConf -import pvnet - - -def load_config(config_file): - """ - Open yam configruation file, and get rid eof '_target_' line - """ - - # get full path of config file - path = os.path.dirname(pvnet.__file__) - config_file = f"{path}/../{config_file}" - - with open(config_file) as cfg: - config = yaml.load(cfg, Loader=yaml.FullLoader) - - if "_target_" in config.keys(): - config.pop("_target_") # This is only for Hydra - - return config - def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: """Initializes multi-GPU-friendly python logger.""" @@ -236,11 +214,10 @@ def finish( """Makes sure everything closed properly.""" # without this sweeps with wandb logger might crash! - for logger in loggers: - if isinstance(logger, pl.loggers.wandb.WandbLogger): - import wandb + if any([isinstance(logger, pl.loggers.wandb.WandbLogger) for logger in loggers]): + import wandb - wandb.finish() + wandb.finish() def plot_batch_forecasts( diff --git a/scripts/backtest_uk_gsp.py b/scripts/backtest_uk_gsp.py new file mode 100644 index 00000000..33d7eb62 --- /dev/null +++ b/scripts/backtest_uk_gsp.py @@ -0,0 +1,467 @@ +""" +A script to run backtest for PVNet and the summation model for UK regional and national + +Use: + +- This script uses hydra to construct the config, just like in `run.py`. So you need to make sure + that the data config is set up appropriate for the model being run in this script +- The PVNet and summation model checkpoints; the time range over which to make predictions are made; + and the output directory where the results near the top of the script as hard coded user + variables. These should be changed. + + +``` +python backtest_uk_gsp.py +``` + +""" + +try: + import torch.multiprocessing as mp + + mp.set_start_method("spawn", force=True) + mp.set_sharing_strategy("file_system") +except RuntimeError: + pass + +import logging +import os +import sys + +import hydra +import numpy as np +import pandas as pd +import torch +import xarray as xr +from ocf_datapipes.batch import ( + BatchKey, + NumpyBatch, + batch_to_tensor, + copy_batch_to_device, + stack_np_examples_into_batch, +) +from ocf_datapipes.config.load import load_yaml_configuration +from ocf_datapipes.load import OpenGSP +from ocf_datapipes.training.common import create_t0_and_loc_datapipes +from ocf_datapipes.training.pvnet import ( + _get_datapipes_dict, + construct_sliced_data_pipeline, +) +from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD +from omegaconf import DictConfig + +# TODO: Having this script rely on pvnet_app sets up a circular dependency. The function +# `preds_to_dataarray()` should probably be moved here +from pvnet_app.utils import preds_to_dataarray +from torch.utils.data import DataLoader +from torch.utils.data.datapipes.iter import IterableWrapper +from tqdm import tqdm + +from pvnet.load_model import get_model_from_checkpoints +from pvnet.utils import GSPLocationLookup + +# ------------------------------------------------------------------ +# USER CONFIGURED VARIABLES +output_dir = "/mnt/disks/backtest/test_backtest" + +# Local directory to load the PVNet checkpoint from. By default this should pull the best performing +# checkpoint on the val set +model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/kqaknmuc" + +# Local directory to load the summation model checkpoint from. By default this should pull the best +# performing checkpoint on the val set. If set to None a simple sum is used instead +summation_chckpoint_dir = ( + "/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/nw673nw2" +) + +# Forecasts will be made for all available init times between these +start_datetime = "2022-05-08 00:00" +end_datetime = "2022-05-08 00:30" + +# ------------------------------------------------------------------ +# SET UP LOGGING + +logger = logging.getLogger(__name__) +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +# ------------------------------------------------------------------ +# DERIVED VARIABLES + +# This will run on GPU if it exists +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# ------------------------------------------------------------------ +# GLOBAL VARIABLES + +# The frequency of the GSP data +FREQ_MINS = 30 + +# When sun as elevation below this, the forecast is set to zero +MIN_DAY_ELEVATION = 0 + +# All regional GSP IDs - not including national which is treated separately +ALL_GSP_IDS = np.arange(1, 318) + +# ------------------------------------------------------------------ +# FUNCTIONS + + +def get_gsp_ds(config_path: str) -> xr.Dataset: + """Load GSP data from the path in the data config. + + Args: + config_path: Path to the data configuration file + + Returns: + xarray.Dataset of PVLive truths and capacities + """ + + config = load_yaml_configuration(config_path) + gsp_datapipe = OpenGSP(gsp_pv_power_zarr_path=config.input_data.gsp.gsp_zarr_path) + ds_gsp = next(iter(gsp_datapipe)) + + return ds_gsp + + +def get_available_t0_times(start_datetime, end_datetime, config_path): + """Filter a list of t0 init-times to those for which all required input data is available. + + Args: + start_datetime: First potential t0 time + end_datetime: Last potential t0 time + config_path: Path to data config file + + Returns: + pandas.DatetimeIndex of the init-times available for required inputs + """ + + start_datetime = pd.Timestamp(start_datetime) + end_datetime = pd.Timestamp(end_datetime) + # Open all the input data so we can check what of the potential data init times we have input + # data for + datapipes_dict = _get_datapipes_dict(config_path, production=False) + + # Pop out the config file + config = datapipes_dict.pop("config") + + # We are going to abuse the `create_t0_and_loc_datapipes()` function to find the init-times in + # potential_init_times which we have input data for. To do this, we will feed in some fake GSP + # data which has the potential_init_times as timestamps. This is a bit hacky but works for now + + # Set up init-times we would like to make predictions for + potential_init_times = pd.date_range(start_datetime, end_datetime, freq=f"{FREQ_MINS}min") + + # We buffer the potential init-times so that we don't lose any init-times from the + # start and end. Again this is a hacky step + history_duration = pd.Timedelta(config.input_data.gsp.history_minutes, "min") + forecast_duration = pd.Timedelta(config.input_data.gsp.forecast_minutes, "min") + buffered_potential_init_times = pd.date_range( + start_datetime - history_duration, end_datetime + forecast_duration, freq=f"{FREQ_MINS}min" + ) + + ds_fake_gsp = buffered_potential_init_times.to_frame().to_xarray().rename({"index": "time_utc"}) + ds_fake_gsp = ds_fake_gsp.rename({0: "gsp_pv_power_mw"}) + ds_fake_gsp = ds_fake_gsp.expand_dims("gsp_id", axis=1) + ds_fake_gsp = ds_fake_gsp.assign_coords( + gsp_id=[0], + x_osgb=("gsp_id", [0]), + y_osgb=("gsp_id", [0]), + ) + ds_fake_gsp = ds_fake_gsp.gsp_pv_power_mw.astype(float) * 1e-18 + + # Overwrite the GSP data which is already in the datapipes dict + datapipes_dict["gsp"] = IterableWrapper([ds_fake_gsp]) + + # Use create_t0_and_loc_datapipes to get datapipe of init-times + location_pipe, t0_datapipe = create_t0_and_loc_datapipes( + datapipes_dict, + configuration=config, + key_for_t0="gsp", + shuffle=False, + ) + + # Create a full list of available init-times. Note that we need to loop over the t0s AND + # locations to avoid the torch datapipes buffer overflow but we don't actually use the location + available_init_times = [t0 for _, t0 in zip(location_pipe, t0_datapipe)] + available_init_times = pd.to_datetime(available_init_times) + + logger.info( + f"{len(available_init_times)} out of {len(potential_init_times)} " + "requested init-times have required input data" + ) + + return available_init_times + + +def get_loctimes_datapipes(config_path): + """Create location and init-time datapipes + + Args: + config_path: Path to data config file + + Returns: + tuple: A tuple of datapipes + - Datapipe yielding locations + - Datapipe yielding init-times + """ + + # Set up ID location query object + ds_gsp = get_gsp_ds(config_path) + gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb) + + # Filter the init-times to times we have all input data for + available_target_times = get_available_t0_times( + start_datetime, + end_datetime, + config_path, + ) + num_t0s = len(available_target_times) + + # Save the init-times which predictions are being made for. This is really helpful to check + # whilst the backtest is running since it takes a long time. This lets you see what init-times + # the backtest will end up producing + available_target_times.to_frame().to_csv(f"{output_dir}/t0_times.csv") + + # Cycle the GSP locations + location_pipe = IterableWrapper([[gsp_id_to_loc(gsp_id) for gsp_id in ALL_GSP_IDS]]).repeat( + num_t0s + ) + + # Shard and then unbatch the locations so that each worker will generate all samples for all + # GSPs and for a single init-time + location_pipe = location_pipe.sharding_filter() + location_pipe = location_pipe.unbatch(unbatch_level=1) + + # Create times datapipe so each worker receives 317 copies of the same datetime for its batch + t0_datapipe = IterableWrapper([[t0 for gsp_id in ALL_GSP_IDS] for t0 in available_target_times]) + t0_datapipe = t0_datapipe.sharding_filter() + t0_datapipe = t0_datapipe.unbatch(unbatch_level=1) + + t0_datapipe = t0_datapipe.set_length(num_t0s * len(ALL_GSP_IDS)) + location_pipe = location_pipe.set_length(num_t0s * len(ALL_GSP_IDS)) + + return location_pipe, t0_datapipe + + +class ModelPipe: + """A class to conveniently make and process predictions from batches""" + + def __init__(self, model, summation_model, ds_gsp: xr.Dataset): + """A class to conveniently make and process predictions from batches + + Args: + model: PVNet GSP level model + summation_model: Summation model to make national forecast from GSP level forecasts + ds_gsp:xarray dataset of PVLive true values and capacities + """ + self.model = model + self.summation_model = summation_model + self.ds_gsp = ds_gsp + + def predict_batch(self, batch: NumpyBatch) -> xr.Dataset: + """Run the batch through the model and compile the predictions into an xarray DataArray + + Args: + batch: A batch of samples with inputs for each GSP for the same init-time + + Returns: + xarray.Dataset of all GSP and national forecasts for the batch + """ + + # Unpack some variables from the batch + id0 = batch[BatchKey.gsp_t0_idx] + t0 = batch[BatchKey.gsp_time_utc].cpu().numpy().astype("datetime64[s]")[0, id0] + n_valid_times = len(batch[BatchKey.gsp_time_utc][0, id0 + 1 :]) + ds_gsp = self.ds_gsp + model = self.model + summation_model = self.summation_model + + # Get valid times for this forecast + valid_times = pd.to_datetime( + [t0 + np.timedelta64((i + 1) * FREQ_MINS, "m") for i in range(n_valid_times)] + ) + + # Get effective capacities for this forecast + gsp_capacities = ds_gsp.effective_capacity_mwp.sel( + time_utc=t0, gsp_id=slice(1, None) + ).values + national_capacity = ds_gsp.effective_capacity_mwp.sel(time_utc=t0, gsp_id=0).item() + + # Get the solar elevations. We need to un-normalise these from the values in the batch + elevation = batch[BatchKey.gsp_solar_elevation] * ELEVATION_STD + ELEVATION_MEAN + # We only need elevation mask for forecasted values, not history + elevation = elevation[:, id0 + 1 :] + + # Make mask dataset for sundown + da_sundown_mask = xr.DataArray( + data=elevation < MIN_DAY_ELEVATION, + dims=["gsp_id", "target_datetime_utc"], + coords=dict( + gsp_id=ALL_GSP_IDS, + target_datetime_utc=valid_times, + ), + ) + + with torch.no_grad(): + # Run batch through model to get 0-1 predictions for all GSPs + device_batch = copy_batch_to_device(batch_to_tensor(batch), device) + y_normed_gsp = model(device_batch).detach().cpu().numpy() + + da_normed_gsp = preds_to_dataarray(y_normed_gsp, model, valid_times, ALL_GSP_IDS) + + # Multiply normalised forecasts by capacities and clip negatives + da_abs_gsp = da_normed_gsp.clip(0, None) * gsp_capacities[:, None, None] + + # Apply sundown mask + da_abs_gsp = da_abs_gsp.where(~da_sundown_mask).fillna(0.0) + + # Make national predictions using summation model + if summation_model is not None: + with torch.no_grad(): + # Construct sample for the summation model + summation_inputs = { + "pvnet_outputs": torch.Tensor(y_normed_gsp[np.newaxis]).to(device), + "effective_capacity": ( + torch.Tensor(gsp_capacities / national_capacity) + .to(device) + .unsqueeze(0) + .unsqueeze(-1) + ), + } + + # Run batch through the summation model + y_normed_national = ( + summation_model(summation_inputs).detach().squeeze().cpu().numpy() + ) + + # Convert national predictions to DataArray + da_normed_national = preds_to_dataarray( + y_normed_national[np.newaxis], summation_model, valid_times, gsp_ids=[0] + ) + + # Multiply normalised forecasts by capacities and clip negatives + da_abs_national = da_normed_national.clip(0, None) * national_capacity + + # Apply sundown mask - All GSPs must be masked to mask national + da_abs_national = da_abs_national.where(~da_sundown_mask.all(dim="gsp_id")).fillna(0.0) + + # If no summation model, make national predictions using simple sum + else: + da_abs_national = ( + da_abs_gsp.sum(dim="gsp_id") + .expand_dims(dim="gsp_id", axis=0) + .assign_coords(gsp_id=[0]) + ) + + # Concat the regional GSP and national predictions + da_abs_all = xr.concat([da_abs_national, da_abs_gsp], dim="gsp_id") + ds_abs_all = da_abs_all.to_dataset(name="hindcast") + + ds_abs_all = ds_abs_all.expand_dims(dim="init_time_utc", axis=0).assign_coords( + init_time_utc=[t0] + ) + + return ds_abs_all + + +def get_datapipe(config_path: str) -> NumpyBatch: + """Construct datapipe yielding batches of concurrent samples for all GSPs + + Args: + config_path: Path to the data configuration file + + Returns: + NumpyBatch: Concurrent batch of samples for each GSP + """ + + # Construct location and init-time datapipes + location_pipe, t0_datapipe = get_loctimes_datapipes(config_path) + + # Get the number of init-times + num_batches = len(t0_datapipe) // len(ALL_GSP_IDS) + + # Construct sample datapipes + data_pipeline = construct_sliced_data_pipeline( + config_path, + location_pipe, + t0_datapipe, + ) + + # Batch so that each worker returns a batch of all locations for a single init-time + # Also convert to tensor for model + data_pipeline = ( + data_pipeline.batch(len(ALL_GSP_IDS)).map(stack_np_examples_into_batch).map(batch_to_tensor) + ) + + data_pipeline = data_pipeline.set_length(num_batches) + + return data_pipeline + + +@hydra.main(config_path="../configs", config_name="config.yaml", version_base="1.2") +def main(config: DictConfig): + """Runs the backtest""" + + dataloader_kwargs = dict( + shuffle=False, + batch_size=None, + sampler=None, + batch_sampler=None, + # Number of workers set in the config file + num_workers=config.datamodule.num_workers, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + prefetch_factor=config.datamodule.prefetch_factor, + persistent_workers=False, + ) + + # Set up output dir + os.makedirs(output_dir) + + # Create concurrent batch datapipe + # Each batch includes a sample for each of the 317 GSPs for a single init-time + batch_pipe = get_datapipe(config.datamodule.configuration) + num_batches = len(batch_pipe) + + # Load the GSP data as an xarray object + ds_gsp = get_gsp_ds(config.datamodule.configuration) + + # Create a dataloader for the concurrent batches and use multiprocessing + dataloader = DataLoader(batch_pipe, **dataloader_kwargs) + + # Load the PVNet model and summation model + model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True) + model = model.eval().to(device) + if summation_chckpoint_dir is None: + summation_model = None + else: + summation_model, *_ = get_model_from_checkpoints([summation_chckpoint_dir], val_best=True) + summation_model = summation_model.eval().to(device) + + # Create object to make predictions for each input batch + model_pipe = ModelPipe(model, summation_model, ds_gsp) + + # Loop through the batches + pbar = tqdm(total=num_batches) + for i, batch in zip(range(num_batches), dataloader): + # Make predictions for the init-time + ds_abs_all = model_pipe.predict_batch(batch) + + t0 = ds_abs_all.init_time_utc.values[0] + + # Save the predictioons + filename = f"{output_dir}/{t0}.nc" + ds_abs_all.to_netcdf(filename) + + pbar.update() + + # Close down + pbar.close() + del dataloader + + +if __name__ == "__main__": + main() diff --git a/scripts/checkpoint_to_huggingface.py b/scripts/checkpoint_to_huggingface.py index dadaaf96..6ad2ec81 100644 --- a/scripts/checkpoint_to_huggingface.py +++ b/scripts/checkpoint_to_huggingface.py @@ -5,19 +5,16 @@ --local-path="~/tmp/this_model" \ --no-push-to-hub """ -import glob -import os + import tempfile from typing import Optional -import hydra -import torch import typer import wandb -from pyaml_env import parse_config -from pvnet.models.ensemble import Ensemble -from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel +from pvnet.load_model import get_model_from_checkpoints + +wandb_repo = "openclimatefix/pvnet2.1" def push_to_huggingface( @@ -39,13 +36,11 @@ def push_to_huggingface( assert push_to_hub or local_path is not None - os.path.dirname(os.path.abspath(__file__)) - is_ensemble = len(checkpoint_dir_paths) > 1 # Check if checkpoint dir name is wandb run ID if wandb_ids == []: - all_wandb_ids = [run.id for run in wandb.Api().runs(path="openclimatefix/pvnet2.1")] + all_wandb_ids = [run.id for run in wandb.Api().runs(path=wandb_repo)] for path in checkpoint_dir_paths: dirname = path.split("/")[-1] if dirname in all_wandb_ids: @@ -53,49 +48,9 @@ def push_to_huggingface( else: wandb_ids.append(None) - model_configs = [] - models = [] - data_configs = [] - - for path in checkpoint_dir_paths: - # Load the model - model_config = parse_config(f"{path}/model_config.yaml") - - model = hydra.utils.instantiate(model_config) - - if val_best: - # Only one epoch (best) saved per model - files = glob.glob(f"{path}/epoch*.ckpt") - assert len(files) == 1 - checkpoint = torch.load(files[0], map_location="cpu") - else: - checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu") + model, model_config, data_config = get_model_from_checkpoints(checkpoint_dir_paths, val_best) - model.load_state_dict(state_dict=checkpoint["state_dict"]) - - if isinstance(model, UMTModel): - model, model_config = model.convert_to_multimodal_model(model_config) - - # Check for data config - data_config = f"{path}/data_config.yaml" - assert os.path.isfile(data_config) - - model_configs.append(model_config) - models.append(model) - data_configs.append(data_config) - - if is_ensemble: - model_config = { - "_target_": "pvnet.models.ensemble.Ensemble", - "model_list": model_configs, - } - model = Ensemble(model_list=models) - data_config = data_configs[0] - - else: - model_config = model_configs[0] - model = models[0] - data_config = data_configs[0] + if not is_ensemble: wandb_ids = wandb_ids[0] # Push to hub @@ -111,7 +66,7 @@ def push_to_huggingface( data_config=data_config, wandb_ids=wandb_ids, push_to_hub=push_to_hub, - repo_id="openclimatefix/pvnet_v2" if push_to_hub else None, + repo_id=wandb_repo if push_to_hub else None, ) if local_path is None: diff --git a/scripts/hindcast.py b/scripts/hindcast.py deleted file mode 100644 index fd0951e4..00000000 --- a/scripts/hindcast.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Script to run hindcasts for a given PVNet model on dates from 2021""" - -import logging -import os -import warnings -from datetime import timedelta -from functools import reduce - -import numpy as np -import pandas as pd -import torch -import xarray as xr -from ocf_datapipes.batch import BatchKey -from ocf_datapipes.load import OpenGSP -from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline -from ocf_datapipes.utils.utils import stack_np_examples_into_batch -from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService -from torchdata.datapipes.iter import IterableWrapper -from tqdm import tqdm - -from pvnet.data.datamodule import batch_to_tensor, copy_batch_to_device -from pvnet.models.base_model import BaseModel -from pvnet.utils import GSPLocationLookup - - -def get_dataloader_for_loctimes(loc_list, t0_list, num_workers=0, batch_size=None): - """Get the datalolader for given""" - batch_size = len(loc_list) if batch_size is None else batch_size - - readingservice_config = dict( - num_workers=num_workers, - multiprocessing_context="spawn", - worker_prefetch_cnt=0 if num_workers == 0 else 2, - ) - - # This iterates though all times for loc_list[0] before moving on to loc_list[1] - # This stops us wasting time if some timestamp in the day has missing values - location_pipe = IterableWrapper(loc_list).repeat(len(t0_list)) - t0_datapipe = IterableWrapper(t0_list).cycle(len(loc_list)) - - location_pipe = location_pipe.sharding_filter() - t0_datapipe = t0_datapipe.sharding_filter() - - batch_pipe = construct_sliced_data_pipeline( - config_filename="../configs/datamodule/configuration/gcp_configuration.yaml", - location_pipe=location_pipe, - t0_datapipe=t0_datapipe, - block_sat=False, - block_nwp=False, - ) - - batch_pipe = ( - batch_pipe.batch(batch_size=batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - - rs = MultiProcessingReadingService(**readingservice_config) - dataloader = DataLoader2(batch_pipe, reading_service=rs) - return dataloader - - -def save_date_preds(x, date, path_root): - """Save the predictions for date to zarr""" - a = np.zeros((len(x.keys()), 317, 16)) - - for i, k in enumerate(list(x.keys())): - df = pd.DataFrame(x[k])[np.arange(1, 318)] - a[i] = df.values.T - - ds = xr.DataArray( - a, - dims=["t0_time", "gsp_id", "step"], - coords=dict(t0_time=list(x.keys()), gsp_id=np.arange(1, 318), step=np.arange(16)), - ).to_dataset(name="preds") - ds.to_zarr(f"{path_root}/{date}.zarr") - - -if __name__ == "__main__": - # --------------------------------------------------------------------------- - # User params - - path_root = "/mnt/disks/batches2/may-july_hindcast" - model_name = "openclimatefix/pvnet_v2" - model_version = "7cc7e9f8e5fc472a753418c45b2af9f123547b6c" - - # We assume we only ever look at 2021 - # If set to None this sctipt will find all days where we have input data - # date_range = ("2021-01-01", "2021-01-02") - will give 2 days - date_range = ("2021-04-19", "2021-06-19") - - gsp_ids = np.arange(1, 318) - - times = [timedelta(minutes=i * 30) for i in range(48)] - batch_size = 10 - num_workers = 20 - - # --------------------------------------------------------------------------- - # Initial set-up - - logger = logging.getLogger(__name__) - logger.setLevel(logging.INFO) - logger.info("Setting up") - - # Set up directory first in case of path already existing error - os.makedirs(path_root, exist_ok=False) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - ds_gsp = next(iter(OpenGSP("gs://solar-pv-nowcasting-data/PV/GSP/v5/pv_gsp.zarr"))) - - # Set up ID location query object - gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb) - - location_list = [gsp_id_to_loc(gsp_id) for gsp_id in gsp_ids] - - # --------------------------------------------------------------------------- - # Construct array of dates - - logger.info("Constructing date arrays") - - if date_range is None: - ds_nwp = xr.open_zarr("/mnt/disks/nwp/UKV_intermediate_version_7.zarr") - - ds_sat = xr.open_zarr("/mnt/disks/data_ssd/2021_nonhrv.zarr") - - potential_dates = reduce( - np.intersect1d, - [ - np.unique(ds_nwp.init_time.dt.date), - np.unique(ds_sat.time.dt.date), - np.unique(ds_gsp.time_utc.dt.date), - ], - ) - del ds_sat, ds_nwp - - else: - potential_dates = pd.date_range(*date_range, freq=timedelta(days=1)).date - - # --------------------------------------------------------------------------- - # Load the model - - logger.info("Loading model") - - model = BaseModel.from_pretrained(model_name, revision=model_version) - model = model.to(device) - model = model.eval() - - # --------------------------------------------------------------------------- - # Run - - logger.info("Beginning hindcasts") - - pbar = tqdm(total=len(potential_dates) * len(times) * len(location_list)) - - # Expected n on pbar after next date iteration - # Store this in case some dates fail. Allows pbar to be kept up to dat regardless of failure on - # some dates. - pbar_n = 0 - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - for date in potential_dates: - pbar_n += len(location_list) * len(times) - - t0_list = [np.datetime64(pd.Timestamp(date) + dt, "s") for dt in times] - - date_preds = {t0: dict() for t0 in t0_list} - - dataloader = get_dataloader_for_loctimes( - loc_list=location_list, - t0_list=t0_list, - num_workers=num_workers, - batch_size=batch_size, - ) - # We lump all times in a day together. We either complete the entire day of forecasts or - # fail on the entire day - try: - for i, batch in enumerate(dataloader): - with torch.no_grad(): - preds = model(copy_batch_to_device(batch, device)).detach().cpu().numpy() - - batch_times = ( - (batch[BatchKey.gsp_time_utc][:, batch[BatchKey.gsp_t0_idx]]) - .numpy() - .astype("datetime64[s]") - ) - - for id, pred, time in zip(batch[BatchKey.gsp_id], preds, batch_times): - if id in date_preds[time]: - logger.warning( - f"ID {id} already exists in entry for datetime {time}" - ) - date_preds[time][id.item()] = pred - pbar.update() - should_save = True - except Exception: - logger.exception(f"Date: {date} failed") - # Round up the progress bar - pbar.update(pbar_n - pbar.n) - should_save = False - - # This gives a hacky way to stop this program. Deleting the output dir will cause it to - # error out - if should_save: - save_date_preds(date_preds, date, path_root=path_root) - - pbar.close() - # --------------------------------------------------------------------------- - # Consolidate up all the zarr stores - ds = xr.open_mfdataset(f"{path_root}/*.zarr", engine="zarr").compute() - ds.to_zarr(f"{path_root}.zarr") - - os.system(f"rm -r {path_root}/*.zarr") - os.system(f"rmdir {path_root}") - - logger.info("Hindcast complete")