diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index 47e23fd7..b502113a 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -1,57 +1,116 @@ """ Data module for pytorch lightning """ - -import resource +from glob import glob import torch -from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.pvnet import pvnet_datapipe -from torch.utils.data.datapipes.iter import FileLister +from lightning.pytorch import LightningDataModule +from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset +from ocf_datapipes.batch import ( + NumpyBatch, + TensorBatch, + batch_to_tensor, + stack_np_examples_into_batch, +) +from torch.utils.data import DataLoader, Dataset + + +class NumpybatchPremadeSamplesDataset(Dataset): + """Dataset to load NumpyBatch samples""" + + def __init__(self, sample_dir): + """Dataset to load NumpyBatch samples + + Args: + sample_dir: Path to the directory of pre-saved samples. + """ + self.sample_paths = glob(f"{sample_dir}/*.pt") -from pvnet.data.base import BaseDataModule + def __len__(self): + return len(self.sample_paths) -rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) -resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1])) + def __getitem__(self, idx): + return torch.load(self.sample_paths[idx]) -class DataModule(BaseDataModule): +def collate_fn(samples: list[NumpyBatch]) -> TensorBatch: + """Convert a list of NumpyBatch samples to a tensor batch""" + return batch_to_tensor(stack_np_examples_into_batch(samples)) + + +class DataModule(LightningDataModule): """Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`.""" - def _get_datapipe(self, start_time, end_time): - data_pipeline = pvnet_datapipe( - self.configuration, - start_time=start_time, - end_time=end_time, - ) + def __init__( + self, + configuration: str | None = None, + sample_dir: str | None = None, + batch_size: int = 16, + num_workers: int = 0, + prefetch_factor: int | None = None, + train_period: list[str | None] = [None, None], + val_period: list[str | None] = [None, None], + ): + """Datamodule for training pvnet architecture. + + Can also be used with pre-made batches if `sample_dir` is set. + + Args: + configuration: Path to datapipe configuration file. + sample_dir: Path to the directory of pre-saved samples. Cannot be used together with + `configuration` or '[train/val]_period'. + batch_size: Batch size. + num_workers: Number of workers to use in multiprocess batch loading. + prefetch_factor: Number of data will be prefetched at the end of each worker process. + train_period: Date range filter for train dataloader. + val_period: Date range filter for val dataloader. - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) + """ + super().__init__() + + if not ((sample_dir is not None) ^ (configuration is not None)): + raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.") + + if sample_dir is not None: + if any([period != [None, None] for period in [train_period, val_period]]): + raise ValueError("Cannot set `(train/val)_period` with presaved samples") + + self.configuration = configuration + self.sample_dir = sample_dir + self.train_period = train_period + self.val_period = val_period + + self._common_dataloader_kwargs = dict( + batch_size=batch_size, + sampler=None, + batch_sampler=None, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + prefetch_factor=prefetch_factor, + persistent_workers=False, ) - return data_pipeline - - def _get_premade_batches_datapipe(self, subdir, shuffle=False): - data_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False) - if shuffle: - data_pipeline = ( - data_pipeline.shuffle(buffer_size=10_000) - .sharding_filter() - .map(torch.load) - # Split the batches and reshuffle them to be combined into new batches - .split_batches() - .shuffle(buffer_size=self.shuffle_factor * self.batch_size) - ) + + def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: + return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time) + + def _get_premade_samples_dataset(self, subdir) -> Dataset: + split_dir = f"{self.sample_dir}/{subdir}" + return NumpybatchPremadeSamplesDataset(split_dir) + + def train_dataloader(self) -> DataLoader: + """Construct train dataloader""" + if self.sample_dir is not None: + dataset = self._get_premade_samples_dataset("train") else: - data_pipeline = ( - data_pipeline.sharding_filter().map(torch.load) - # Split the batches so we can use any batch-size - .split_batches() - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) + dataset = self._get_streamed_samples_dataset(*self.train_period) + return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs) - return data_pipeline + def val_dataloader(self) -> DataLoader: + """Construct val dataloader""" + if self.sample_dir is not None: + dataset = self._get_premade_samples_dataset("val") + else: + dataset = self._get_streamed_samples_dataset(*self.val_period) + return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 24ce5bfc..7546bf34 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -18,17 +18,15 @@ from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import HfApi -from ocf_datapipes.batch import BatchKey -from ocf_ml_metrics.evaluation.evaluation import evaluation +from ocf_datapipes.batch import BatchKey, copy_batch_to_device from pvnet.models.utils import ( BatchAccumulator, MetricAccumulator, PredAccumulator, - WeightedLosses, ) from pvnet.optimizers import AbstractOptimizer -from pvnet.utils import construct_ocf_ml_metrics_batch_df, plot_batch_forecasts +from pvnet.utils import plot_batch_forecasts DATA_CONFIG_NAME = "data_config.yaml" @@ -236,6 +234,11 @@ def get_data_config( return data_config_file + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights from a Pytorch model to a local directory.""" + model_to_save = self.module if hasattr(self, "module") else self # type: ignore + torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME) + def save_pretrained( self, save_directory: Union[str, Path], @@ -348,7 +351,6 @@ def __init__( target_key: str = "gsp", interval_minutes: int = 30, timestep_intervals_to_plot: Optional[list[int]] = None, - use_weighted_loss: bool = False, forecast_minutes_ignore: Optional[int] = 0, ): """Abtstract base class for PVNet submodels. @@ -362,7 +364,6 @@ def __init__( target_key: The key of the target variable in the batch interval_minutes: The interval in minutes between each timestep in the data timestep_intervals_to_plot: Intervals, in timesteps, to plot during training - use_weighted_loss: Whether to use a weighted loss function forecast_minutes_ignore: Number of forecast minutes to ignore when calculating losses. For example if set to 60, the model doesnt predict the first 60 minutes """ @@ -394,8 +395,6 @@ def __init__( self.forecast_len = (forecast_minutes - forecast_minutes_ignore) // interval_minutes self.forecast_len_ignore = forecast_minutes_ignore // interval_minutes - self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len) - self._accumulated_metrics = MetricAccumulator() self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key_name) self._accumulated_y_hat = PredAccumulator() @@ -403,7 +402,6 @@ def __init__( # Store whether the model should use quantile regression or simply predict the mean self.use_quantile_regression = self.output_quantiles is not None - self.use_weighted_loss = use_weighted_loss # Store the number of ouput features that the model should predict for if self.use_quantile_regression: @@ -414,6 +412,10 @@ def __init__( # save all validation results to array, so we can save these to weights n biases self.validation_epoch_results = [] + def transfer_batch_to_device(self, batch, device, dataloader_idx): + """Method to move custom batches to a given device""" + return copy_batch_to_device(batch, device) + def _quantiles_to_prediction(self, y_quantiles): """ Convert network prediction into a point prediction. @@ -455,13 +457,11 @@ def _calculate_quantile_loss(self, y_quantiles, y): errors = y - y_quantiles[..., i] losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1)) losses = 2 * torch.cat(losses, dim=2) - if self.use_weighted_loss: - weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0).to(y.device) - losses = losses * weights + return losses.mean() def _calculate_common_losses(self, y, y_hat): - """Calculate losses common to train, test, and val""" + """Calculate losses common to train, and val""" losses = {} @@ -473,10 +473,6 @@ def _calculate_common_losses(self, y, y_hat): mse_loss = F.mse_loss(y_hat, y) mae_loss = F.l1_loss(y_hat, y) - # calculate mse, mae with exp weighted loss - mse_exp = self.weighted_losses.get_mse_exp(output=y_hat, target=y) - mae_exp = self.weighted_losses.get_mae_exp(output=y_hat, target=y) - # TODO: Compute correlation coef using np.corrcoef(tensor with # shape (2, num_timesteps))[0, 1] on each example, and taking # the mean across the batch? @@ -484,8 +480,6 @@ def _calculate_common_losses(self, y, y_hat): { "MSE": mse_loss, "MAE": mae_loss, - "MSE_EXP": mse_exp, - "MAE_EXP": mae_exp, } ) @@ -531,12 +525,6 @@ def _calculate_val_losses(self, y, y_hat): losses.update(self._step_mae_and_mse(y, y_persist, dict_key_root="persistence")) return losses - def _calculate_test_losses(self, y, y_hat): - """Calculate additional test losses""" - # No additional test losses - losses = {} - return losses - def _training_accumulate_log(self, batch, batch_idx, losses, y_hat): """Internal function to accumulate training batches and log results. @@ -582,7 +570,7 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat): def training_step(self, batch, batch_idx): """Run training step""" y_hat = self(batch) - y = batch[self._target_key][:, -self.forecast_len :, 0] + y = batch[self._target_key][:, -self.forecast_len :] losses = self._calculate_common_losses(y, y_hat) losses = {f"{k}/train": v for k, v in losses.items()} @@ -617,7 +605,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): """Append validation results to self.validation_epoch_results""" # get truth values, shape (b, forecast_len) - y = batch[self._target_key][:, -self.forecast_len :, 0] + y = batch[self._target_key][:, -self.forecast_len :] y = y.detach().cpu().numpy() batch_size = y.shape[0] @@ -663,8 +651,8 @@ def validation_step(self, batch: dict, batch_idx): accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches y_hat = self(batch) - # Sensor seems to be in batch, station, time order - y = batch[self._target_key][:, -self.forecast_len :, 0] + + y = batch[self._target_key][:, -self.forecast_len :] if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: self._log_validation_results(batch, y_hat, accum_batch_num) @@ -763,37 +751,6 @@ def on_validation_epoch_end(self): print("Failed to log horizon_loss_curve to wandb") print(e) - def test_step(self, batch, batch_idx): - """Run test step""" - y_hat = self(batch) - y = batch[self._target_key][:, -self.forecast_len :, 0] - - losses = self._calculate_common_losses(y, y_hat) - losses.update(self._calculate_val_losses(y, y_hat)) - losses.update(self._calculate_test_losses(y, y_hat)) - logged_losses = {f"{k}/test": v for k, v in losses.items()} - - self.log_dict( - logged_losses, - on_step=False, - on_epoch=True, - ) - - if self.use_quantile_regression: - y_hat = self._quantiles_to_prediction(y_hat) - - return construct_ocf_ml_metrics_batch_df(batch, y, y_hat) - - def on_test_epoch_end(self, outputs): - """Evalauate test results using oc_ml_metrics""" - results_df = pd.concat(outputs) - # setting model_name="test" gives us keys like "test/mw/forecast_horizon_30_minutes/mae" - metrics = evaluation(results_df=results_df, model_name="test", outturn_unit="mw") - - self.log_dict( - metrics, - ) - def configure_optimizers(self): """Configure the optimizers using learning rate found with LR finder if used""" if self.lr is not None: diff --git a/pvnet/models/baseline/last_value.py b/pvnet/models/baseline/last_value.py index 2e1c82dc..1fc40b10 100644 --- a/pvnet/models/baseline/last_value.py +++ b/pvnet/models/baseline/last_value.py @@ -36,7 +36,7 @@ def forward(self, x: dict): # take the last value non forecaster value and the first in the pv yeild # (this is the pv site we are preditcting for) - y_hat = gsp_yield[:, -self.forecast_len - 1, 0] + y_hat = gsp_yield[:, -self.forecast_len - 1] # expand the last valid forward n predict steps out = y_hat.unsqueeze(1).repeat(1, self.forecast_len) diff --git a/pvnet/models/baseline/single_value.py b/pvnet/models/baseline/single_value.py index 980feb95..77f83b3f 100644 --- a/pvnet/models/baseline/single_value.py +++ b/pvnet/models/baseline/single_value.py @@ -33,5 +33,5 @@ def __init__( def forward(self, x: dict): """Run model forward on dict batch of data""" # Returns a single value at all steps - y_hat = torch.zeros_like(x[BatchKey.gsp][:, : self.forecast_len, 0]) + self._value + y_hat = torch.zeros_like(x[BatchKey.gsp][:, : self.forecast_len]) + self._value return y_hat diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 1541f69c..2e68542d 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -71,7 +71,6 @@ def __init__( num_embeddings: Optional[int] = 318, timestep_intervals_to_plot: Optional[list[int]] = None, adapt_batches: Optional[bool] = False, - use_weighted_loss: Optional[bool] = False, forecast_minutes_ignore: Optional[int] = 0, ): """Neural network which combines information from different sources. @@ -133,7 +132,6 @@ def __init__( adapt_batches: If set to true, we attempt to slice the batches to the expected shape for the model to use. This allows us to overprepare batches and slice from them for the data we need for a model run. - use_weighted_loss: Whether to use a weighted loss function forecast_minutes_ignore: Number of forecast minutes to ignore when calculating losses. For example if set to 60, the model doesnt predict the first 60 minutes """ @@ -160,7 +158,6 @@ def __init__( target_key=target_key, interval_minutes=interval_minutes, timestep_intervals_to_plot=timestep_intervals_to_plot, - use_weighted_loss=use_weighted_loss, forecast_minutes_ignore=forecast_minutes_ignore, ) @@ -321,7 +318,7 @@ def forward(self, x): sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels if self.add_image_embedding_channel: - id = x[BatchKey[f"{self._target_key_name}_id"]][:, 0].int() + id = x[BatchKey[f"{self._target_key_name}_id"]].int() sat_data = self.sat_embed(sat_data, id) modes["sat"] = self.sat_encoder(sat_data) @@ -338,7 +335,7 @@ def forward(self, x): nwp_data = torch.clip(nwp_data, min=-50, max=50) if self.add_image_embedding_channel: - id = x[BatchKey[f"{self._target_key_name}_id"]][:, 0].int() + id = x[BatchKey[f"{self._target_key_name}_id"]].int() nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id) nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data) @@ -365,7 +362,7 @@ def forward(self, x): # ********************** Embedding of GSP ID ******************** if self.embedding_dim: - id = x[BatchKey[f"{self._target_key_name}_id"]][:, 0].int() + id = x[BatchKey[f"{self._target_key_name}_id"]].int() id_embedding = self.embed(id) modes["id"] = id_embedding @@ -380,16 +377,6 @@ def forward(self, x): # This needs to be a Batch as input modes["wind"] = self.wind_encoder(x_tmp) - # *********************** Sensor Data ************************************ - if self.include_sensor: - if self._target_key_name != "sensor": - modes["sensor"] = self.sensor_encoder(x) - else: - x_tmp = x.copy() - x_tmp[BatchKey.sensor] = x_tmp[BatchKey.sensor][:, : self.history_len + 1] - # This needs to be a Batch as input - modes["sensor"] = self.sensor_encoder(x_tmp) - if self.include_sun: sun = torch.cat( ( diff --git a/pvnet/models/multimodal/unimodal_teacher.py b/pvnet/models/multimodal/unimodal_teacher.py index 8720f973..71692222 100644 --- a/pvnet/models/multimodal/unimodal_teacher.py +++ b/pvnet/models/multimodal/unimodal_teacher.py @@ -219,7 +219,7 @@ def teacher_forward(self, x): sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels if self.add_image_embedding_channel: - id = x[BatchKey.gsp_id][:, 0].int() + id = x[BatchKey.gsp_id].int() sat_data = teacher_model.sat_embed(sat_data, id) modes[mode] = teacher_model.sat_encoder(sat_data) @@ -233,7 +233,7 @@ def teacher_forward(self, x): nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels nwp_data = torch.clip(nwp_data, min=-50, max=50) if teacher_model.add_image_embedding_channel: - id = x[BatchKey.gsp_id][:, 0].int() + id = x[BatchKey.gsp_id].int() nwp_data = teacher_model.nwp_embed_dict[nwp_source](nwp_data, id) nwp_out = teacher_model.nwp_encoders_dict[nwp_source](nwp_data) @@ -260,7 +260,7 @@ def forward(self, x, return_modes=False): sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels if self.add_image_embedding_channel: - id = x[BatchKey.gsp_id][:, 0].int() + id = x[BatchKey.gsp_id].int() sat_data = self.sat_embed(sat_data, id) modes["sat"] = self.sat_encoder(sat_data) @@ -276,7 +276,7 @@ def forward(self, x, return_modes=False): nwp_data = torch.clip(nwp_data, min=-50, max=50) if self.add_image_embedding_channel: - id = x[BatchKey.gsp_id][:, 0].int() + id = x[BatchKey.gsp_id].int() nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id) nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data) @@ -301,7 +301,7 @@ def forward(self, x, return_modes=False): # ********************** Embedding of GSP ID ******************** if self.embedding_dim: - id = x[BatchKey.gsp_id][:, 0].int() + id = x[BatchKey.gsp_id].int() id_embedding = self.embed(id) modes["id"] = id_embedding diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index b06086e4..2bbe78f2 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -1,8 +1,6 @@ """Utility functions""" import logging -import math -from typing import Optional import numpy as np import torch @@ -124,55 +122,3 @@ def flush(self) -> dict[BatchKey, list[torch.Tensor]]: batch[k] = torch.cat(v, dim=0) self._batches = {} return batch - - -class WeightedLosses: - """Class: Weighted loss depending on the forecast horizon.""" - - def __init__(self, decay_rate: Optional[int] = None, forecast_length: int = 6): - """ - Want to set up the MSE loss function so the weights only have to be calculated once. - - Args: - decay_rate: The weights exponentially decay depending on the 'decay_rate'. - forecast_length: The forecast length is needed to make sure the weights sum to 1 - """ - self.decay_rate = decay_rate - self.forecast_length = forecast_length - - logger.debug( - f"Setting up weights with decay rate {decay_rate} and of length {forecast_length}" - ) - - # set default rate of ln(2) if not set - if self.decay_rate is None: - self.decay_rate = math.log(2) - - # make weights from decay rate - weights = np.exp(-self.decay_rate * np.arange(self.forecast_length)) - weights = torch.tensor(weights) - - # normalized the weights, so there mean is 1. - # To calculate the loss, we times the weights by the differences between truth - # and predictions and then take the mean across all forecast horizons and the batch - self.weights = weights / weights.mean() - - def get_mse_exp(self, output, target): - """Loss function weighted MSE""" - - weights = self.weights.to(target.device) - # get the differences weighted by the forecast horizon weights - diff_with_weights = weights * ((output - target) ** 2) - - # average across batches - return torch.mean(diff_with_weights) - - def get_mae_exp(self, output, target): - """Loss function weighted MAE""" - - weights = self.weights.to(target.device) - # get the differences weighted by the forecast horizon weights - diff_with_weights = weights * torch.abs(output - target) - - # average across batches - return torch.mean(diff_with_weights) diff --git a/pvnet/training.py b/pvnet/training.py index cafcdac4..ea8535e8 100644 --- a/pvnet/training.py +++ b/pvnet/training.py @@ -117,7 +117,7 @@ def train(config: DictConfig) -> Optional[float]: if data_config is None: # Data config can be none if using presaved batches. We go to the presaved # batches to get the data config - data_config = f"{config.datamodule.batch_dir}/data_configuration.yaml" + data_config = f"{config.datamodule.sample_dir}/data_configuration.yaml" assert os.path.isfile(data_config), f"Data config file not found: {data_config}" shutil.copyfile(data_config, f"{callback.dirpath}/data_config.yaml") @@ -162,11 +162,6 @@ def train(config: DictConfig) -> Optional[float]: # Train the model completely trainer.fit(model=model, datamodule=datamodule) - if config.test_after_training: - # Evaluate model on test set, using the best model achieved during training - log.info("Starting testing!") - trainer.test(model=model, datamodule=datamodule, ckpt_path="best") - # Make sure everything closed properly log.info("Finalizing!") utils.finish( diff --git a/pvnet/utils.py b/pvnet/utils.py index b2c4d99a..ee0ad6aa 100644 --- a/pvnet/utils.py +++ b/pvnet/utils.py @@ -6,7 +6,6 @@ import lightning.pytorch as pl import matplotlib.pyplot as plt -import numpy as np import pandas as pd import pylab import rich.syntax @@ -16,7 +15,6 @@ from lightning.pytorch.utilities import rank_zero_only from ocf_datapipes.batch import BatchKey from ocf_datapipes.utils import Location -from ocf_datapipes.utils.geospatial import osgb_to_lon_lat from omegaconf import DictConfig, OmegaConf @@ -265,14 +263,14 @@ def _get_numpy(key): y_id_key = BatchKey[f"{key_to_plot}_id"] BatchKey[f"{key_to_plot}_t0_idx"] time_utc_key = BatchKey[f"{key_to_plot}_time_utc"] - y = batch[y_key][:, :, 0].cpu().numpy() # Select the one it is trained on + y = batch[y_key].cpu().numpy() # Select the one it is trained on y_hat = y_hat.cpu().numpy() # Select between the timesteps in timesteps to plot plotting_name = key_to_plot.upper() gsp_ids = batch[y_id_key].cpu().numpy().squeeze() - times_utc = batch[time_utc_key].cpu().numpy().squeeze().astype("datetime64[s]") + times_utc = batch[time_utc_key].cpu().numpy().squeeze().astype("datetime64[ns]") times_utc = [pd.to_datetime(t) for t in times_utc] if timesteps_to_plot is not None: y = y[:, timesteps_to_plot[0] : timesteps_to_plot[1]] @@ -323,45 +321,3 @@ def _get_numpy(key): plt.tight_layout() return fig - - -def construct_ocf_ml_metrics_batch_df(batch, y, y_hat): - """Helper function tot construct DataFrame for ocf_ml_metrics""" - - def _repeat(x): - return np.repeat(x.squeeze(), n_times) - - def _get_numpy(key): - return batch[key].cpu().numpy().squeeze() - - t0_idx = batch[BatchKey.gsp_t0_idx] - times_utc = _get_numpy(BatchKey.gsp_time_utc) - n_times = len(times_utc[0]) - t0_idx - 1 - - y_osgb_centre = _get_numpy(BatchKey.gsp_y_osgb) - x_osgb_centre = _get_numpy(BatchKey.gsp_x_osgb) - longitude, latitude = osgb_to_lon_lat(x=x_osgb_centre, y=y_osgb_centre) - - # Store df columns in dict - df_dict = {} - - # Repeat these features for each forecast time - df_dict["latitude"] = _repeat(latitude) - df_dict["longitude"] = _repeat(longitude) - df_dict["id"] = _repeat(_get_numpy(BatchKey.gsp_id)) - df_dict["t0_datetime_utc"] = _repeat(times_utc[:, t0_idx]).astype("datetime64[s]") - df_dict["capacity_mwp"] = _repeat(_get_numpy(BatchKey.gsp_capacity_megawatt_power)) - - # TODO: Some (10%) of these values are NaN -> 0 for time t0 for pvnet pipeline - # Better to search for last non-nan (non-zero)? - df_dict["t0_actual_pv_outturn_mw"] = _repeat( - (_get_numpy(BatchKey.gsp_capacity_megawatt_power)[:, None] * _get_numpy(BatchKey.gsp))[ - :, t0_idx - ] - ) - - # Flatten the forecasts times to 1D - df_dict["target_datetime_utc"] = times_utc[:, t0_idx + 1 :].flatten().astype("datetime64[s]") - df_dict["actual_pv_outturn_mw"] = y.cpu().numpy().flatten() * df_dict["capacity_mwp"] - df_dict["forecast_pv_outturn_mw"] = y_hat.cpu().numpy().flatten() * df_dict["capacity_mwp"] - return pd.DataFrame(df_dict) diff --git a/pyproject.toml b/pyproject.toml index 0a066326..cc3ea1cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ dynamic = ["version", "readme"] license={file="LICENCE"} dependencies = [ + "ocf_data_sampler==0.0.26", "ocf_datapipes>=3.3.34", "ocf_ml_metrics>=0.0.11", "numpy", diff --git a/scripts/save_batches.py b/scripts/save_batches.py deleted file mode 100644 index 953e7244..00000000 --- a/scripts/save_batches.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -Constructs batches and saves them to disk. - -Currently a slightly hacky implementation due to the way the configs are done. This script will use -the same config file currently set to train the model. - -use: -``` -python save_batches.py -``` -if setting all values in the datamodule config file, or - -``` -python save_batches.py \ - datamodule.batch_output_dir="/mnt/disks/bigbatches/batches_v0" \ - datamodule.batch_size=2 \ - datamodule.num_workers=2 \ - datamodule.num_train_batches=0 \ - datamodule.num_val_batches=2 -``` -if wanting to override these values for example -""" -# This is needed to get multiprocessing/multiple workers to behave -try: - import torch.multiprocessing as mp - - mp.set_start_method("spawn", force=True) -except RuntimeError: - pass - -import logging -import os -import shutil -import sys - -# Tired of seeing these warnings -import warnings - -import hydra -import torch -from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.pvnet import pvnet_datapipe -from ocf_datapipes.training.pvnet_site import pvnet_site_datapipe -from ocf_datapipes.training.windnet import windnet_datapipe -from omegaconf import DictConfig, OmegaConf -from sqlalchemy import exc as sa_exc -from torch.utils.data import DataLoader -from torch.utils.data.datapipes.iter import IterableWrapper -from tqdm import tqdm - -from pvnet.utils import print_config - -warnings.filterwarnings("ignore", category=sa_exc.SAWarning) - -logger = logging.getLogger(__name__) - -logging.basicConfig(stream=sys.stdout, level=logging.ERROR) - - -class _save_batch_func_factory: - def __init__(self, batch_dir, output_format: str = "torch"): - self.batch_dir = batch_dir - self.output_format = output_format - - def __call__(self, input): - i, batch = input - if self.output_format == "torch": - torch.save(batch, f"{self.batch_dir}/{i:06}.pt") - elif self.output_format == "netcdf": - batch.to_netcdf(f"{self.batch_dir}/{i:06}.nc", mode="w", engine="h5netcdf") - - -def _get_datapipe(config_path, start_time, end_time, batch_size, renewable: str = "pv"): - if renewable == "pv": - data_pipeline_fn = pvnet_datapipe - elif renewable == "wind": - data_pipeline_fn = windnet_datapipe - elif renewable in ["pv_india", "pv_site"]: - data_pipeline_fn = pvnet_site_datapipe - else: - raise ValueError(f"Unknown renewable: {renewable}") - data_pipeline = data_pipeline_fn( - config_path, - start_time=start_time, - end_time=end_time, - ) - if renewable == "pv": - data_pipeline = ( - data_pipeline.batch(batch_size).map(stack_np_examples_into_batch).map(batch_to_tensor) - ) - return data_pipeline - - -def _save_batches_with_dataloader( - batch_pipe, batch_dir, num_batches, dataloader_kwargs, output_format: str = "torch" -): - save_func = _save_batch_func_factory(batch_dir, output_format=output_format) - filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter() - save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func) - - dataloader = DataLoader(save_pipe, **dataloader_kwargs) - - pbar = tqdm(total=num_batches) - for i, batch in zip(range(num_batches), dataloader): - pbar.update() - pbar.close() - del dataloader - - -@hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2") -def main(config: DictConfig): - """Constructs and saves validation and training batches.""" - config_dm = config.datamodule - - print_config(config, resolve=False) - - # Set up directory - os.makedirs(config_dm.batch_output_dir, exist_ok=False) - - with open(f"{config_dm.batch_output_dir}/datamodule.yaml", "w") as f: - f.write(OmegaConf.to_yaml(config_dm)) - - shutil.copyfile( - config_dm.configuration, f"{config_dm.batch_output_dir}/data_configuration.yaml" - ) - - dataloader_kwargs = dict( - shuffle=False, - batch_size=None, # batched in datapipe step - sampler=None, - batch_sampler=None, - num_workers=config_dm.num_workers, - collate_fn=None, - pin_memory=False, - drop_last=False, - timeout=0, - worker_init_fn=None, - prefetch_factor=config_dm.prefetch_factor, - persistent_workers=False, - ) - - if config_dm.num_val_batches > 0: - os.mkdir(f"{config_dm.batch_output_dir}/val") - print("----- Saving val batches -----") - - val_batch_pipe = _get_datapipe( - config_dm.configuration, - *config_dm.val_period, - config_dm.batch_size, - renewable=config.renewable, - ) - - _save_batches_with_dataloader( - batch_pipe=val_batch_pipe, - batch_dir=f"{config_dm.batch_output_dir}/val", - num_batches=config_dm.num_val_batches, - dataloader_kwargs=dataloader_kwargs, - output_format="torch" if config.renewable == "pv" else "netcdf", - ) - - if config_dm.num_train_batches > 0: - os.mkdir(f"{config_dm.batch_output_dir}/train") - print("----- Saving train batches -----") - - train_batch_pipe = _get_datapipe( - config_dm.configuration, - *config_dm.train_period, - config_dm.batch_size, - renewable=config.renewable, - ) - - _save_batches_with_dataloader( - batch_pipe=train_batch_pipe, - batch_dir=f"{config_dm.batch_output_dir}/train", - num_batches=config_dm.num_train_batches, - dataloader_kwargs=dataloader_kwargs, - output_format="torch" if config.renewable == "pv" else "netcdf", - ) - - print("done") - - -if __name__ == "__main__": - main() diff --git a/scripts/save_samples.py b/scripts/save_samples.py new file mode 100644 index 00000000..d38a45f9 --- /dev/null +++ b/scripts/save_samples.py @@ -0,0 +1,208 @@ +""" +Constructs samples and saves them to disk. + +Currently a slightly hacky implementation due to the way the configs are done. This script will use +the same config file currently set to train the model. + +use: +``` +python save_samples.py +``` +if setting all values in the datamodule config file, or + +``` +python save_samples.py \ + +datamodule.sample_output_dir="/mnt/disks/bigbatches/samples_v0" \ + +datamodule.num_train_samples=0 \ + +datamodule.num_val_samples=2 \ + datamodule.num_workers=2 \ + datamodule.prefetch_factor=2 +``` +if wanting to override these values for example +""" + +# Ensure this block of code runs only in the main process to avoid issues with worker processes. +if __name__ == "__main__": + import torch.multiprocessing as mp + + # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be + # compatible with dask's multiprocessing. + mp.set_start_method("forkserver") + + # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is + # important because libraries like Zarr may open many files, which can exhaust the file + # descriptor limit if too many workers are used. + mp.set_sharing_strategy("file_system") + + +import logging +import os +import shutil +import sys +import warnings + +import dask +import hydra +import torch +from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset +from omegaconf import DictConfig, OmegaConf +from sqlalchemy import exc as sa_exc +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from pvnet.utils import print_config + +dask.config.set(scheduler="threads", num_workers=4) + + +# ------- filter warning and set up config ------- + +warnings.filterwarnings("ignore", category=sa_exc.SAWarning) + +logger = logging.getLogger(__name__) + +logging.basicConfig(stream=sys.stdout, level=logging.ERROR) + +# ------------------------------------------------- + + +class SaveFuncFactory: + """Factory for creating a function to save a sample to disk.""" + + def __init__(self, save_dir: str, renewable: str = "pv"): + """Factory for creating a function to save a sample to disk.""" + self.save_dir = save_dir + self.renewable = renewable + + def __call__(self, sample, sample_num: int): + """Save a sample to disk""" + if self.renewable == "pv": + torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt") + elif self.renewable in ["wind", "pv_india", "pv_site"]: + raise NotImplementedError + else: + raise ValueError(f"Unknown renewable: {self.renewable}") + + +def get_dataset(config_path: str, start_time: str, end_time: str, renewable: str = "pv") -> Dataset: + """Get the dataset for the given renewable type.""" + if renewable == "pv": + dataset_cls = PVNetUKRegionalDataset + elif renewable in ["wind", "pv_india", "pv_site"]: + raise NotImplementedError + else: + raise ValueError(f"Unknown renewable: {renewable}") + + return dataset_cls(config_path, start_time=start_time, end_time=end_time) + + +def save_samples_with_dataloader( + dataset: Dataset, + save_dir: str, + num_samples: int, + dataloader_kwargs: dict, + renewable: str = "pv", +) -> None: + """Save samples from a dataset using a dataloader.""" + save_func = SaveFuncFactory(save_dir, renewable=renewable) + + dataloader = DataLoader(dataset, **dataloader_kwargs) + + pbar = tqdm(total=num_samples) + for i, sample in zip(range(num_samples), dataloader): + save_func(sample, i) + pbar.update() + pbar.close() + + +@hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2") +def main(config: DictConfig) -> None: + """Constructs and saves validation and training samples.""" + config_dm = config.datamodule + + print_config(config, resolve=False) + + # Set up directory + os.makedirs(config_dm.sample_output_dir, exist_ok=False) + + # Copy across configs which define the samples into the new sample directory + with open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f: + f.write(OmegaConf.to_yaml(config_dm)) + + shutil.copyfile( + config_dm.configuration, f"{config_dm.sample_output_dir}/data_configuration.yaml" + ) + + # Define the keywargs going into the train and val dataloaders + dataloader_kwargs = dict( + shuffle=True, + batch_size=None, + sampler=None, + batch_sampler=None, + num_workers=config_dm.num_workers, + collate_fn=None, + pin_memory=False, # Only using CPU to prepare samples so pinning is not beneficial + drop_last=False, + timeout=0, + worker_init_fn=None, + prefetch_factor=config_dm.prefetch_factor, + persistent_workers=False, # Not needed since we only enter the dataloader loop once + ) + + if config_dm.num_val_samples > 0: + print("----- Saving val samples -----") + + val_output_dir = f"{config_dm.sample_output_dir}/val" + + # Make directory for val samples + os.mkdir(val_output_dir) + + # Get the dataset + val_dataset = get_dataset( + config_dm.configuration, + *config_dm.val_period, + renewable=config.renewable, + ) + + # Save samples + save_samples_with_dataloader( + dataset=val_dataset, + save_dir=val_output_dir, + num_samples=config_dm.num_val_samples, + dataloader_kwargs=dataloader_kwargs, + renewable=config.renewable, + ) + + del val_dataset + + if config_dm.num_train_samples > 0: + print("----- Saving train samples -----") + + train_output_dir = f"{config_dm.sample_output_dir}/train" + + # Make directory for train samples + os.mkdir(train_output_dir) + + # Get the dataset + train_dataset = get_dataset( + config_dm.configuration, + *config_dm.train_period, + renewable=config.renewable, + ) + + # Save samples + save_samples_with_dataloader( + dataset=train_dataset, + save_dir=train_output_dir, + num_samples=config_dm.num_train_samples, + dataloader_kwargs=dataloader_kwargs, + renewable=config.renewable, + ) + + del train_dataset + + print("----- Saving complete -----") + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py index 8d2ab630..ba657af5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -106,25 +106,22 @@ def sample_train_val_datamodule(): file_n = 0 - for file in glob.glob("tests/test_data/sample_batches/train/*.pt"): - batch = torch.load(file) + for file_n, file in enumerate(glob.glob("tests/test_data/presaved_samples/train/*.pt")): + sample = torch.load(file) for i in range(n_duplicates): # Save fopr both train and val - torch.save(batch, f"{tmpdirname}/train/{file_n:06}.pt") - torch.save(batch, f"{tmpdirname}/val/{file_n:06}.pt") - - file_n += 1 + torch.save(sample, f"{tmpdirname}/train/{file_n:06}.pt") + torch.save(sample, f"{tmpdirname}/val/{file_n:06}.pt") dm = DataModule( configuration=None, + sample_dir=f"{tmpdirname}", batch_size=2, num_workers=0, prefetch_factor=None, train_period=[None, None], val_period=[None, None], - test_period=[None, None], - batch_dir=f"{tmpdirname}", ) yield dm @@ -132,14 +129,13 @@ def sample_train_val_datamodule(): @pytest.fixture() def sample_datamodule(): dm = DataModule( + sample_dir="tests/test_data/presaved_samples", configuration=None, batch_size=2, num_workers=0, prefetch_factor=None, train_period=[None, None], val_period=[None, None], - test_period=[None, None], - batch_dir="tests/test_data/sample_batches", ) return dm @@ -157,9 +153,10 @@ def sample_satellite_batch(sample_batch): @pytest.fixture() -def sample_pv_batch(sample_batch): - pv_data = sample_batch[BatchKey.pv] - return pv_data +def sample_pv_batch(): + # TODO: Once PV site inputs are available from ocf-data-sampler UK regional remove these + # old batches. For now we use the old batches to test the site encoder models + return torch.load("tests/test_data/presaved_batches/train/000000.pt") @pytest.fixture() @@ -191,7 +188,7 @@ def model_minutes_kwargs(): def encoder_model_kwargs(): # Used to test encoder model on satellite data kwargs = dict( - sequence_length=(90 - 30) // 5 + 1, + sequence_length=7, # 30 minutes of 5 minutely satellite data = 7 time steps image_size_pixels=24, in_channels=11, out_features=128, @@ -240,7 +237,7 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs): "ukv": dict( _target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet, _partial_=True, - in_channels=2, + in_channels=11, out_features=128, number_of_conv3d_layers=6, conv3d_channels=32, @@ -248,15 +245,8 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs): ), }, add_image_embedding_channel=True, - pv_encoder=dict( - _target_=pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork, - _partial_=True, - num_sites=349, - out_features=40, - num_heads=4, - kdim=40, - id_embed_dim=20, - ), + # ocf-data-sampler doesn't supprt PV site inputs yet + pv_encoder=None, output_network=dict( _target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2, _partial_=True, @@ -268,11 +258,10 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs): embedding_dim=16, include_sun=True, include_gsp_yield_history=True, - sat_history_minutes=90, + sat_history_minutes=30, nwp_history_minutes={"ukv": 120}, nwp_forecast_minutes={"ukv": 480}, - pv_history_minutes=180, - min_sat_delay_minutes=30, + min_sat_delay_minutes=0, ) kwargs.update(model_minutes_kwargs) @@ -297,14 +286,6 @@ def multimodal_quantile_model(multimodal_model_kwargs): return model -@pytest.fixture() -def multimodal_weighted_quantile_model(multimodal_model_kwargs): - model = Model( - output_quantiles=[0.1, 0.5, 0.9], **multimodal_model_kwargs, use_weighted_loss=True - ) - return model - - @pytest.fixture() def multimodal_quantile_model_ignore_minutes(multimodal_model_kwargs): """Only forecsat second half of the 8 hours""" diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index ff07fe0f..00b1705d 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,3 +1,4 @@ +import pytest from pvnet.data.datamodule import DataModule from pvnet.data.wind_datamodule import WindDataModule from pvnet.data.pv_site_datamodule import PVSiteDataModule @@ -8,16 +9,16 @@ def test_init(): dm = DataModule( configuration=None, + sample_dir="tests/test_data/presaved_samples", batch_size=2, num_workers=0, prefetch_factor=None, train_period=[None, None], val_period=[None, None], - test_period=[None, None], - batch_dir="tests/test_data/sample_batches", ) +@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet") def test_wind_init(): dm = WindDataModule( configuration=None, @@ -31,6 +32,7 @@ def test_wind_init(): ) +@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet") def test_wind_init_with_nwp_filter(): dm = WindDataModule( configuration=None, @@ -53,6 +55,7 @@ def test_wind_init_with_nwp_filter(): assert batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp].shape[2] == 2 +@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet") def test_pv_site_init(): dm = PVSiteDataModule( configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_batches/data_configuration.yaml", @@ -69,13 +72,12 @@ def test_pv_site_init(): def test_iter(): dm = DataModule( configuration=None, + sample_dir="tests/test_data/presaved_samples", batch_size=2, num_workers=0, prefetch_factor=None, train_period=[None, None], val_period=[None, None], - test_period=[None, None], - batch_dir="tests/test_data/sample_batches", ) batch = next(iter(dm.train_dataloader())) @@ -84,15 +86,21 @@ def test_iter(): def test_iter_multiprocessing(): dm = DataModule( configuration=None, - batch_size=2, + sample_dir="tests/test_data/presaved_samples", + batch_size=1, num_workers=2, - prefetch_factor=2, + prefetch_factor=1, train_period=[None, None], val_period=[None, None], - test_period=[None, None], - batch_dir="tests/test_data/sample_batches", ) - batch = next(iter(dm.train_dataloader())) + served_batches = 0 for batch in dm.train_dataloader(): - pass + served_batches += 1 + + # Stop once we've got 2 batches + if served_batches == 2: + break + + # Make sure we've served 2 batches + assert served_batches == 2 diff --git a/tests/models/multimodal/site_encoders/test_encoders.py b/tests/models/multimodal/site_encoders/test_encoders.py index 3d109f76..41969b22 100644 --- a/tests/models/multimodal/site_encoders/test_encoders.py +++ b/tests/models/multimodal/site_encoders/test_encoders.py @@ -10,10 +10,10 @@ import pytest -def _test_model_forward(batch, model_class, kwargs): +def _test_model_forward(batch, model_class, kwargs, batch_size): model = model_class(**kwargs) y = model(batch) - assert tuple(y.shape) == (2, kwargs["out_features"]), y.shape + assert tuple(y.shape) == (batch_size, kwargs["out_features"]), y.shape def _test_model_backward(batch, model_class, kwargs): @@ -24,22 +24,37 @@ def _test_model_backward(batch, model_class, kwargs): # Test model forward on all models -def test_simplelearnedaggregator_forward(sample_batch, site_encoder_model_kwargs): - _test_model_forward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs) +def test_simplelearnedaggregator_forward(sample_pv_batch, site_encoder_model_kwargs): + _test_model_forward( + sample_pv_batch, + SimpleLearnedAggregator, + site_encoder_model_kwargs, + batch_size=8, + ) -def test_singleattentionnetwork_forward(sample_batch, site_encoder_model_kwargs): - _test_model_forward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs) +def test_singleattentionnetwork_forward(sample_pv_batch, site_encoder_model_kwargs): + _test_model_forward( + sample_pv_batch, + SingleAttentionNetwork, + site_encoder_model_kwargs, + batch_size=8, + ) def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs): - _test_model_forward(sample_wind_batch, SingleAttentionNetwork, site_encoder_sensor_model_kwargs) + _test_model_forward( + sample_wind_batch, + SingleAttentionNetwork, + site_encoder_sensor_model_kwargs, + batch_size=2, + ) # Test model backward on all models -def test_simplelearnedaggregator_backward(sample_batch, site_encoder_model_kwargs): - _test_model_backward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs) +def test_simplelearnedaggregator_backward(sample_pv_batch, site_encoder_model_kwargs): + _test_model_backward(sample_pv_batch, SimpleLearnedAggregator, site_encoder_model_kwargs) -def test_singleattentionnetwork_backward(sample_batch, site_encoder_model_kwargs): - _test_model_backward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs) +def test_singleattentionnetwork_backward(sample_pv_batch, site_encoder_model_kwargs): + _test_model_backward(sample_pv_batch, SingleAttentionNetwork, site_encoder_model_kwargs) diff --git a/tests/models/multimodal/test_multimodal.py b/tests/models/multimodal/test_multimodal.py index 3e82e171..83f74657 100644 --- a/tests/models/multimodal/test_multimodal.py +++ b/tests/models/multimodal/test_multimodal.py @@ -1,5 +1,6 @@ from torch.optim import SGD import pytest +from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey def test_model_forward(multimodal_model, sample_batch): @@ -36,23 +37,6 @@ def test_quantile_model_backward(multimodal_quantile_model, sample_batch): y_quantiles.sum().backward() -def test_weighted_quantile_model_forward(multimodal_weighted_quantile_model, sample_batch): - y_quantiles = multimodal_weighted_quantile_model(sample_batch) - - # check output is the correct shape - # batch size=2, forecast_len=15, num_quantiles=3 - assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape - - -def test_weighted_quantile_model_backward(multimodal_weighted_quantile_model, sample_batch): - opt = SGD(multimodal_weighted_quantile_model.parameters(), lr=0.001) - - y_quantiles = multimodal_weighted_quantile_model(sample_batch) - - # Backwards on sum drives sum to zero - y_quantiles.sum().backward() - - def test_weighted_quantile_model_forward(multimodal_quantile_model_ignore_minutes, sample_batch): y_quantiles = multimodal_quantile_model_ignore_minutes(sample_batch) diff --git a/tests/models/multimodal/test_unimodal_teacher.py b/tests/models/multimodal/test_unimodal_teacher.py index 758b32a3..fbed5e92 100644 --- a/tests/models/multimodal/test_unimodal_teacher.py +++ b/tests/models/multimodal/test_unimodal_teacher.py @@ -76,6 +76,7 @@ def test_model_init(unimodal_model_kwargs): def test_model_forward(unimodal_teacher_model, sample_batch): + # assert False y, _ = unimodal_teacher_model(sample_batch, return_modes=True) # check output is the correct shape diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py deleted file mode 100644 index 15a26a58..00000000 --- a/tests/models/test_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -import torch - -from pvnet.models.utils import WeightedLosses - - -def test_weight_losses_weights(): - """Test weighted loss""" - forecast_length = 2 - w = WeightedLosses(forecast_length=forecast_length) - - assert w.weights.cpu().numpy()[0] == pytest.approx(4 / 3) - assert w.weights.cpu().numpy()[1] == pytest.approx(2 / 3) - - -def test_mae_exp(): - """Test MAE exp with weighted loss""" - forecast_length = 2 - w = WeightedLosses(forecast_length=forecast_length) - - output = torch.Tensor([[1, 3], [1, 3]]) - target = torch.Tensor([[1, 5], [1, 9]]) - - loss = w.get_mae_exp(output=output, target=target) - - # 0.5((1-1)*2/3 + (5-3)*1/3) + 0.5((1-1)*2/3 + (9-3)*1/3) = 1/3 + 3/3 - assert loss == pytest.approx(4 / 3) - - -def test_mse_exp(): - """Test MSE exp with weighted loss""" - forecast_length = 2 - w = WeightedLosses(forecast_length=forecast_length) - - output = torch.Tensor([[1, 3], [1, 3]]) - target = torch.Tensor([[1, 5], [1, 9]]) - - loss = w.get_mse_exp(output=output, target=target) - - # 0.5((1-1)^2*2/3 + (5-3)^2*1/3) + 0.5((1-1)^2*2/3 + (9-3)^2*1/3) = 2/3 + 18/3 - assert loss == pytest.approx(20 / 3) - - -def test_mae_exp_rand(): - """Test MAE exp with weighted loss with random tensors""" - forecast_length = 6 - batch_size = 32 - - w = WeightedLosses(forecast_length=6) - - output = torch.randn(batch_size, forecast_length) - target = torch.randn(batch_size, forecast_length) - - loss = w.get_mae_exp(output=output, target=target) - assert loss > 0 - - -def test_mse_exp_rand(): - """Test MSE exp with weighted loss with random tensors""" - forecast_length = 6 - batch_size = 32 - - w = WeightedLosses(forecast_length=6) - - output = torch.randn(batch_size, forecast_length) - target = torch.randn(batch_size, forecast_length) - - loss = w.get_mse_exp(output=output, target=target) - assert loss > 0 diff --git a/tests/test_data/sample_batches/data_configuration.yaml b/tests/test_data/presaved_batches/data_configuration.yaml similarity index 100% rename from tests/test_data/sample_batches/data_configuration.yaml rename to tests/test_data/presaved_batches/data_configuration.yaml diff --git a/tests/test_data/sample_batches/datamodule.yaml b/tests/test_data/presaved_batches/datamodule.yaml similarity index 100% rename from tests/test_data/sample_batches/datamodule.yaml rename to tests/test_data/presaved_batches/datamodule.yaml diff --git a/tests/test_data/sample_batches/train/000000.pt b/tests/test_data/presaved_batches/train/000000.pt similarity index 100% rename from tests/test_data/sample_batches/train/000000.pt rename to tests/test_data/presaved_batches/train/000000.pt diff --git a/tests/test_data/sample_batches/train/000001.pt b/tests/test_data/presaved_batches/train/000001.pt similarity index 100% rename from tests/test_data/sample_batches/train/000001.pt rename to tests/test_data/presaved_batches/train/000001.pt diff --git a/tests/test_data/presaved_samples/data_configuration.yaml b/tests/test_data/presaved_samples/data_configuration.yaml new file mode 100644 index 00000000..c65402eb --- /dev/null +++ b/tests/test_data/presaved_samples/data_configuration.yaml @@ -0,0 +1,130 @@ +general: + description: Config for producing batches on GCP + name: gcp_pvnet + +input_data: + gsp: + gsp_zarr_path: /mnt/disks/nwp_rechunk/pv_gsp/pvlive_gsp.zarr + interval_start_minutes: -120 + interval_end_minutes: 480 + time_resolution_minutes: 30 + # A random value from the list below will be chosen as the delay when dropout is used + # If set to null no dropout is applied. Only values before t0 are dropped out for GSP. + # Values after t0 are assumed as targets and cannot be dropped. + dropout_timedeltas_minutes: null + dropout_fraction: 0 # Fraction of samples with dropout + + nwp: + ukv: + nwp_provider: ukv + nwp_zarr_path: + - /mnt/disks/nwp_rechunk/nwp/ukv/UKV_intermediate_version_7.1.zarr + - /mnt/disks/nwp_rechunk/nwp/ukv/UKV_2021_missing.zarr + - /mnt/disks/nwp_rechunk/nwp/ukv/UKV_2022.zarr + - /mnt/disks/nwp_rechunk/nwp/ukv/UKV_2023.zarr + interval_start_minutes: -120 + interval_end_minutes: 480 + time_resolution_minutes: 60 + nwp_channels: + # These variables exist in the CEDA training set and in the live MetOffice live service + - t # 2-metre temperature + - dswrf # downwards short-wave radiation flux + - dlwrf # downwards long-wave radiation flux + - hcc # high cloud cover + - mcc # medium cloud cover + - lcc # low cloud cover + - sde # snow depth water equivalent + - r # relative humidty + - vis # visibility + - si10 # 10-metre wind speed + - prate # precipitation rate + nwp_image_size_pixels_height: 24 + nwp_image_size_pixels_width: 24 + dropout_timedeltas_minutes: [-180] + dropout_fraction: 1.0 + max_staleness_minutes: null + + ecmwf: + nwp_provider: ecmwf + nwp_zarr_path: /mnt/disks/nwp_rechunk/nwp/ecmwf/UK_v2.zarr + interval_start_minutes: -120 + interval_end_minutes: 480 + time_resolution_minutes: 60 + + nwp_channels: + - t2m # 2-metre temperature + - dswrf # downwards short-wave radiation flux + - dlwrf # downwards long-wave radiation flux + - hcc # high cloud cover + - mcc # medium cloud cover + - lcc # low cloud cover + - tcc # total cloud cover + - sde # snow depth water equivalent + - sr # direct solar radiation + - duvrs # downwards UV radiation at surface + - u10 # 10-metre U component of wind speed + - v10 # 10-metre V component of wind speed + + # The following channels are accumulated and need to be diffed + nwp_accum_channels: + - dswrf + - dlwrf + - sr + - duvrs + + nwp_image_size_pixels_height: 12 # roughly equivalent to ukv 48 + nwp_image_size_pixels_width: 12 # roughly equivalent to ukv 48 + dropout_timedeltas_minutes: [-360] + dropout_fraction: 1.0 + max_staleness_minutes: null + + sat_pred: + nwp_provider: sat_pred + nwp_zarr_path: /mnt/disks/sat_preds/simvp_preds/*.zarr + interval_start_minutes: 15 + interval_end_minutes: 180 + time_resolution_minutes: 15 + nwp_channels: + - IR_016 + - IR_039 + - IR_087 + - IR_097 + - IR_108 + - IR_120 + - IR_134 + - VIS006 + - VIS008 + - WV_062 + - WV_073 + nwp_image_size_pixels_height: 24 + nwp_image_size_pixels_width: 24 + dropout_timedeltas_minutes: null + dropout_fraction: 0 + max_staleness_minutes: null + + satellite: + satellite_zarr_path: + - /mnt/disks/nwp_rechunk/sat/2019_nonhrv.zarr + - /mnt/disks/nwp_rechunk/sat/2020_nonhrv.zarr + - /mnt/disks/nwp_rechunk/sat/2021_nonhrv.zarr + - /mnt/disks/nwp_rechunk/sat/2022_nonhrv.zarr + - /mnt/disks/nwp_rechunk/sat/2023_nonhrv.zarr + interval_start_minutes: -30 + interval_end_minutes: 0 + time_resolution_minutes: 5 + satellite_channels: + - IR_016 + - IR_039 + - IR_087 + - IR_097 + - IR_108 + - IR_120 + - IR_134 + - VIS006 + - VIS008 + - WV_062 + - WV_073 + satellite_image_size_pixels_height: 24 + satellite_image_size_pixels_width: 24 + dropout_timedeltas_minutes: null + dropout_fraction: 0. diff --git a/tests/test_data/presaved_samples/datamodule.yaml b/tests/test_data/presaved_samples/datamodule.yaml new file mode 100644 index 00000000..4b661250 --- /dev/null +++ b/tests/test_data/presaved_samples/datamodule.yaml @@ -0,0 +1,18 @@ +_target_: pvnet.data.datamodule.DataModule +configuration: /home/jamesfulton/repos/PVNet/configs/datamodule/configuration/gcp_configuration.yaml +num_workers: 160 +prefetch_factor: 1 +batch_size: 8 +train_period: + - null + - "2022-05-07" +val_period: + - "2023-01-01" + - "2023-05-08" +test_period: + - "2022-05-08" + - "2023-05-08" +seed: ${seed} +sample_output_dir: /mnt/disks/extra_batches/samples_with_simvp_sat_pred +num_train_samples: 1600000 +num_val_samples: 64000 diff --git a/tests/test_data/presaved_samples/train/00000000.pt b/tests/test_data/presaved_samples/train/00000000.pt new file mode 100644 index 00000000..292e1a65 Binary files /dev/null and b/tests/test_data/presaved_samples/train/00000000.pt differ diff --git a/tests/test_data/presaved_samples/train/00000001.pt b/tests/test_data/presaved_samples/train/00000001.pt new file mode 100644 index 00000000..4a46e05b Binary files /dev/null and b/tests/test_data/presaved_samples/train/00000001.pt differ diff --git a/tests/test_data/presaved_samples/train/00000002.pt b/tests/test_data/presaved_samples/train/00000002.pt new file mode 100644 index 00000000..f7fa3986 Binary files /dev/null and b/tests/test_data/presaved_samples/train/00000002.pt differ diff --git a/tests/test_data/presaved_samples/train/00000003.pt b/tests/test_data/presaved_samples/train/00000003.pt new file mode 100644 index 00000000..6c4a4666 Binary files /dev/null and b/tests/test_data/presaved_samples/train/00000003.pt differ diff --git a/tests/test_data/presaved_samples/train/00000004.pt b/tests/test_data/presaved_samples/train/00000004.pt new file mode 100644 index 00000000..13169acb Binary files /dev/null and b/tests/test_data/presaved_samples/train/00000004.pt differ diff --git a/tests/test_data/presaved_samples/train/00000005.pt b/tests/test_data/presaved_samples/train/00000005.pt new file mode 100644 index 00000000..670f5978 Binary files /dev/null and b/tests/test_data/presaved_samples/train/00000005.pt differ diff --git a/tests/test_data/presaved_samples/train/00000006.pt b/tests/test_data/presaved_samples/train/00000006.pt new file mode 100644 index 00000000..8e238b6b Binary files /dev/null and b/tests/test_data/presaved_samples/train/00000006.pt differ diff --git a/tests/test_data/presaved_samples/train/00000007.pt b/tests/test_data/presaved_samples/train/00000007.pt new file mode 100644 index 00000000..c62cd85c Binary files /dev/null and b/tests/test_data/presaved_samples/train/00000007.pt differ diff --git a/tests/test_end2end.py b/tests/test_end2end.py index 7a7161d5..90a02222 100644 --- a/tests/test_end2end.py +++ b/tests/test_end2end.py @@ -2,5 +2,8 @@ def test_model_trainer_fit(multimodal_model, sample_train_val_datamodule): + batch = next(iter(sample_train_val_datamodule.train_dataloader())) + y = multimodal_model(batch) + trainer = lightning.pytorch.trainer.trainer.Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model=multimodal_model, datamodule=sample_train_val_datamodule)