From 2143caa3382655319690914f4a506853bce9ceb7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 21 Jul 2023 10:00:31 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 configs/config.yaml                  |   4 +-
 configs/datamodule/default.yaml      |   2 +-
 configs/model/default.yaml           |   1 -
 pvnet_summation/__init__.py          |   2 +-
 pvnet_summation/data/datamodule.py   | 157 ++++++++++++---------------
 pvnet_summation/models/base_model.py |  40 +++----
 pvnet_summation/models/model.py      |  39 +++----
 pvnet_summation/training.py          |  40 +++----
 pvnet_summation/utils.py             |   7 +-
 requirements.txt                     |   2 +-
 run.py                               |   5 +-
 tests/conftest.py                    |  39 +++----
 tests/data/test_datamodule.py        |  32 +++---
 tests/test_end2end.py                |   2 +-
 14 files changed, 159 insertions(+), 213 deletions(-)

diff --git a/configs/config.yaml b/configs/config.yaml
index 08c8fb4..0cb36a1 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -10,8 +10,8 @@ defaults:
   - hydra: default.yaml

 # Whether to loop through the PVNet outputs and save them out before training
-presave_pvnet_outputs: True
-
+presave_pvnet_outputs:
+  True

 # enable color logging
 # - override hydra/hydra_logging: colorlog
diff --git a/configs/datamodule/default.yaml b/configs/datamodule/default.yaml
index 3ed7921..4e9f6af 100644
--- a/configs/datamodule/default.yaml
+++ b/configs/datamodule/default.yaml
@@ -3,4 +3,4 @@ batch_dir: "/mnt/disks/bigbatches/concurrent_batches_v3.6_-60mins"
 gsp_zarr_path: "/mnt/disks/nwp/pv_gsp.zarr"
 batch_size: 8
 num_workers: 20
-prefetch_factor: 2
\ No newline at end of file
+prefetch_factor: 2
diff --git a/configs/model/default.yaml b/configs/model/default.yaml
index 0418c70..d481aaf 100644
--- a/configs/model/default.yaml
+++ b/configs/model/default.yaml
@@ -18,7 +18,6 @@ output_network_kwargs:
   res_block_layers: 2
   dropout_frac: 0.0

-
 # Forecast and time settings
 forecast_minutes: 480
diff --git a/pvnet_summation/__init__.py b/pvnet_summation/__init__.py
index ed2c3c5..ed53582 100644
--- a/pvnet_summation/__init__.py
+++ b/pvnet_summation/__init__.py
@@ -1 +1 @@
-"""PVNet_summation"""
\ No newline at end of file
+"""PVNet_summation"""
diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py
index b91d0c4..029a952 100644
--- a/pvnet_summation/data/datamodule.py
+++ b/pvnet_summation/data/datamodule.py
@@ -2,36 +2,30 @@

 import torch
 from lightning.pytorch import LightningDataModule
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
-from torchdata.datapipes.iter import FileLister, IterDataPipe
-from ocf_datapipes.utils.consts import BatchKey
 from ocf_datapipes.load import OpenGSP
 from ocf_datapipes.training.pvnet import normalize_gsp
-from torchdata.datapipes.iter import Zipper
+from ocf_datapipes.utils.consts import BatchKey
+from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
+from torchdata.datapipes.iter import FileLister, IterDataPipe, Zipper

-from pvnet.data.datamodule import (
-    copy_batch_to_device,
-    batch_to_tensor,
-    split_batches,
-)
 # https://github.com/pytorch/pytorch/issues/973
-torch.multiprocessing.set_sharing_strategy('file_system')
+torch.multiprocessing.set_sharing_strategy("file_system")


 class GetNationalPVLive(IterDataPipe):
     """Select national output targets for given times"""
+
     def __init__(self, gsp_data, times_datapipe):
         """Select national output targets for given times
-
+
         Args:
             gsp_data: xarray DataArray of the national outputs
             times_datapipe: IterDataPipe yielding arrays of target times.
         """
         self.gsp_data = gsp_data
         self.times_datapipe = times_datapipe
-
+
     def __iter__(self):
-
         gsp_data = self.gsp_data
         for times in self.times_datapipe:
             national_outputs = torch.as_tensor(
@@ -42,54 +36,54 @@ def __iter__(self):

 class GetBatchTime(IterDataPipe):
     """Extract the valid times from the concurrent sample batch"""
-
+
     def __init__(self, sample_datapipe):
         """Extract the valid times from the concurrent sample batch
-
+
         Args:
             sample_datapipe: IterDataPipe yielding concurrent sample batches
         """
         self.sample_datapipe = sample_datapipe
-
+
     def __iter__(self):
         for sample in self.sample_datapipe:
-            # Times for each GSP in the sample batch should be the same - take first
+            # Times for each GSP in the sample batch should be the same - take first
             id0 = sample[BatchKey.gsp_t0_idx]
-            times = sample[BatchKey.gsp_time_utc][0, id0+1:]
+            times = sample[BatchKey.gsp_time_utc][0, id0 + 1 :]
             yield times
-
+

 class PivotDictList(IterDataPipe):
     """Convert list of dicts to dict of lists"""
-
+
     def __init__(self, source_datapipe):
         """Convert list of dicts to dict of lists
-
+
         Args:
-            source_datapipe:
+            source_datapipe: Datapipe yielding lists of dicts
         """
         self.source_datapipe = source_datapipe
-
+
     def __iter__(self):
         for list_of_dicts in self.source_datapipe:
             keys = list_of_dicts[0].keys()
             batch_dict = {k: [d[k] for d in list_of_dicts] for k in keys}
             yield batch_dict
-
-
+
+
 class DictApply(IterDataPipe):
     """Apply functions to elements of a dictionary and return processed dictionary."""
-
+
     def __init__(self, source_datapipe, **transforms):
         """Apply functions to elements of a dictionary and return processed dictionary.
-
+
         Args:
             source_datapipe: Datapipe which yields dicts
             **transforms: key-function pairs
         """
         self.source_datapipe = source_datapipe
         self.transforms = transforms
-
+
     def __iter__(self):
         for d in self.source_datapipe:
             for key, function in self.transforms.items():
@@ -99,21 +93,21 @@ def __iter__(self):

 class ZipperDict(IterDataPipe):
     """Yield samples from multiple datapipes as a dict"""
-
+
     def __init__(self, **datapipes):
         """Yield samples from multiple datapipes as a dict.
-
+
         Args:
             **datapipes: Named datapipes
         """
         self.keys = list(datapipes.keys())
         self.source_datapipes = Zipper(*[datapipes[key] for key in self.keys])
-
+
     def __iter__(self):
         for outputs in self.source_datapipes:
             yield {key: value for key, value in zip(self.keys, outputs)}
-
+

 class DataModule(LightningDataModule):
     """Datamodule for training pvnet_summation."""
@@ -144,95 +138,83 @@ def __init__(
             multiprocessing_context="spawn",
             worker_prefetch_cnt=prefetch_factor,
         )
-
+
     def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=False):
-
         # Load presaved concurrent sample batches
         file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)
-
+
         if shuffle:
             file_pipeline = file_pipeline.shuffle(buffer_size=1000)
-
+
         file_pipeline = file_pipeline.sharding_filter()
-
+
         if add_filename:
             file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5)
-
+
         sample_pipeline = file_pipeline.map(torch.load)
-
+
         # Find national output simultaneous to concurrent samples
         gsp_data = (
-            next(iter(
-                OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path)
-                .map(normalize_gsp)
-            ))
+            next(iter(OpenGSP(gsp_pv_power_zarr_path=self.gsp_zarr_path).map(normalize_gsp)))
             .sel(gsp_id=0)
             .compute()
         )
-
+
         sample_pipeline, sample_pipeline_copy = sample_pipeline.fork(2, buffer_size=5)
-
-        times_datapipe, times_datapipe_copy = (
-            GetBatchTime(sample_pipeline_copy).fork(2, buffer_size=5)
+
+        times_datapipe, times_datapipe_copy = GetBatchTime(sample_pipeline_copy).fork(
+            2, buffer_size=5
         )
-
+
         national_targets_datapipe = GetNationalPVLive(gsp_data, times_datapipe_copy)
-
+
         # Compile the samples
         if add_filename:
             data_pipeline = ZipperDict(
-                pvnet_inputs = sample_pipeline,
-                national_targets = national_targets_datapipe,
-                times = times_datapipe,
-                filepath = file_pipeline_copy,
+                pvnet_inputs=sample_pipeline,
+                national_targets=national_targets_datapipe,
+                times=times_datapipe,
+                filepath=file_pipeline_copy,
             )
         else:
             data_pipeline = ZipperDict(
-                pvnet_inputs = sample_pipeline,
-                national_targets = national_targets_datapipe,
-                times = times_datapipe,
-            )
-
+                pvnet_inputs=sample_pipeline,
+                national_targets=national_targets_datapipe,
+                times=times_datapipe,
+            )
+
         if self.batch_size is not None:
-
             data_pipeline = PivotDictList(data_pipeline.batch(self.batch_size))
             data_pipeline = DictApply(
-                data_pipeline,
-                national_targets=torch.stack,
+                data_pipeline,
+                national_targets=torch.stack,
                 times=torch.stack,
             )
-
+
         return data_pipeline

-
     def train_dataloader(self, shuffle=True, add_filename=False):
         """Construct train dataloader"""
         datapipe = self._get_premade_batches_datapipe(
-            "train",
-            shuffle=shuffle,
-            add_filename=add_filename
+            "train", shuffle=shuffle, add_filename=add_filename
         )

         rs = MultiProcessingReadingService(**self.readingservice_config)
         return DataLoader2(datapipe, reading_service=rs)

-
     def val_dataloader(self, shuffle=False, add_filename=False):
         """Construct val dataloader"""
         datapipe = self._get_premade_batches_datapipe(
-            "val",
-            shuffle=shuffle,
-            add_filename=add_filename
-        )
+            "val", shuffle=shuffle, add_filename=add_filename
+        )

         rs = MultiProcessingReadingService(**self.readingservice_config)
         return DataLoader2(datapipe, reading_service=rs)

-
     def test_dataloader(self):
         """Construct test dataloader"""
         raise NotImplementedError
-
-
+
+
 class PVNetPresavedDataModule(LightningDataModule):
     """Datamodule for loading pre-saved PVNet predictions to train pvnet_summation."""
@@ -260,34 +242,32 @@ def __init__(
             multiprocessing_context="spawn",
             worker_prefetch_cnt=prefetch_factor,
         )
-
+
     def _get_premade_batches_datapipe(self, subdir, shuffle=False):
-
         # Load presaved concurrent sample batches
         file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)
-
+
         if shuffle:
             file_pipeline = file_pipeline.shuffle(buffer_size=1000)
-
-        sample_pipeline = file_pipeline.sharding_filter().map(torch.load)
-
+
+        sample_pipeline = file_pipeline.sharding_filter().map(torch.load)
+
         if self.batch_size is not None:
-
             batch_pipeline = PivotDictList(sample_pipeline.batch(self.batch_size))
             batch_pipeline = DictApply(
                 batch_pipeline,
                 pvnet_outputs=torch.stack,
-                national_targets=torch.stack,
+                national_targets=torch.stack,
                 times=torch.stack,
             )
-
+
         return batch_pipeline

     def train_dataloader(self, shuffle=True):
         """Construct train dataloader"""
         datapipe = self._get_premade_batches_datapipe(
-            "train",
-            shuffle=shuffle,
+            "train",
+            shuffle=shuffle,
         )

         rs = MultiProcessingReadingService(**self.readingservice_config)
@@ -296,13 +276,12 @@ def train_dataloader(self, shuffle=True):
     def val_dataloader(self, shuffle=False):
         """Construct val dataloader"""
         datapipe = self._get_premade_batches_datapipe(
-            "val",
-            shuffle=shuffle,
-        )
+            "val",
+            shuffle=shuffle,
+        )

         rs = MultiProcessingReadingService(**self.readingservice_config)
         return DataLoader2(datapipe, reading_service=rs)

     def test_dataloader(self):
         """Construct test dataloader"""
         raise NotImplementedError
-
diff --git a/pvnet_summation/models/base_model.py b/pvnet_summation/models/base_model.py
index dec091f..16debf4 100644
--- a/pvnet_summation/models/base_model.py
+++ b/pvnet_summation/models/base_model.py
@@ -1,28 +1,22 @@
 """Base model for all PVNet submodels"""
-import json
 import logging
-import os
-from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Optional

-import hydra
+import lightning.pytorch as pl
 import torch
 import wandb
 from nowcasting_utils.models.loss import WeightedLosses
-import lightning.pytorch as pl
-from torch import nn
-from pvnet.models.base_model import PVNetModelHubMixin, BaseModel as PVNetBaseModel
-
-from pvnet.models.utils import (
+from pvnet.models.base_model import BaseModel as PVNetBaseModel
+from pvnet.models.base_model import PVNetModelHubMixin
+from pvnet.models.utils import (
     MetricAccumulator,
     PredAccumulator,
 )
-
-
 from pvnet.optimizers import AbstractOptimizer
+
 from pvnet_summation.utils import plot_forecasts

-#from pvnet.models.base_model import BaseModel as PVNetBaseModel
+# from pvnet.models.base_model import BaseModel as PVNetBaseModel

 logger = logging.getLogger(__name__)
@@ -54,7 +48,7 @@ def __init__(
                 None the output is a single value.
""" pl.LightningModule.__init__(self) - PVNetModelHubMixin.__init__(self) + PVNetModelHubMixin.__init__(self) self._optimizer = optimizer @@ -74,20 +68,20 @@ def __init__( self._accumulated_y = PredAccumulator() self._accumulated_y_hat = PredAccumulator() self._accumulated_times = PredAccumulator() - + self.pvnet_model = PVNetBaseModel.from_pretrained( model_name, revision=model_version, ) self.pvnet_model.requires_grad_(False) - + def predict_pvnet_batch(self, batch): gsp_batches = [] for sample in batch: preds = self.pvnet_model(sample) gsp_batches += [preds] return torch.stack(gsp_batches) - + @property def pvnet_output_shape(self): if self.pvnet_model.use_quantile_regression: @@ -95,7 +89,6 @@ def pvnet_output_shape(self): else: return (317, self.pvnet_model.forecast_len_30) - def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times): """Internal function to accumulate training batches and log results. @@ -135,7 +128,7 @@ def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times): def training_step(self, batch, batch_idx): """Run training step""" - + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -150,10 +143,10 @@ def training_step(self, batch, batch_idx): else: opt_target = losses["MAE/train"] return opt_target - + def validation_step(self, batch: dict, batch_idx): """Run validation step""" - + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -203,7 +196,7 @@ def validation_step(self, batch: dict, batch_idx): def test_step(self, batch, batch_idx): """Run test step""" - + y_hat = self.forward(batch) y = batch["national_targets"] @@ -220,10 +213,9 @@ def test_step(self, batch, batch_idx): return logged_losses - def configure_optimizers(self): """Configure the optimizers using learning rate found with LR finder if used""" if self.lr is not None: # Use learning rate found by learning rate finder callback self._optimizer.lr = self.lr - return self._optimizer(self.parameters()) \ No newline at end of file + return self._optimizer(self.parameters()) diff --git a/pvnet_summation/models/model.py b/pvnet_summation/models/model.py index 901621e..9214508 100644 --- a/pvnet_summation/models/model.py +++ b/pvnet_summation/models/model.py @@ -3,21 +3,17 @@ from typing import Optional import numpy as np -import torch - import pvnet -from pvnet_summation.models.base_model import BaseModel -from pvnet.optimizers import AbstractOptimizer -from pvnet.models.multimodal.linear_networks.networks import DefaultFCNet +import torch from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork +from pvnet.models.multimodal.linear_networks.networks import DefaultFCNet +from pvnet.optimizers import AbstractOptimizer - +from pvnet_summation.models.base_model import BaseModel class Model(BaseModel): - """Neural network which combines GSP predictions from PVNet - - """ + """Neural network which combines GSP predictions from PVNet""" name = "pvnet_summation_model" @@ -29,8 +25,7 @@ def __init__( output_quantiles: Optional[list[float]] = None, output_network: AbstractLinearNetwork = DefaultFCNet, output_network_kwargs: dict = dict(), - optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), - + optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), ): """Neural network which combines GSP predictions from PVNet @@ -46,16 +41,10 @@ def __init__( optimizer (AbstractOptimizer): Optimizer """ - super().__init__( - forecast_minutes, - model_name, - model_version, - optimizer, - output_quantiles - ) + 
super().__init__(forecast_minutes, model_name, model_version, optimizer, output_quantiles) in_features = np.product(self.pvnet_output_shape) - + self.model = output_network( in_features=in_features, out_features=self.num_output_features, @@ -64,21 +53,19 @@ def __init__( self.save_hyperparameters() - def forward(self, x): """Run model forward""" - + if "pvnet_outputs" in x: pvnet_out = x["pvnet_outputs"] else: - pvnet_out = self.predict_pvnet_batch(x['pvnet_inputs']) - + pvnet_out = self.predict_pvnet_batch(x["pvnet_inputs"]) + pvnet_out = torch.flatten(pvnet_out, start_dim=1) out = self.model(pvnet_out) - + if self.use_quantile_regression: # Shape: batch_size, seq_length * num_quantiles out = out.reshape(out.shape[0], self.forecast_len_30, len(self.output_quantiles)) - - return out + return out diff --git a/pvnet_summation/training.py b/pvnet_summation/training.py index 5e65549..8388be7 100644 --- a/pvnet_summation/training.py +++ b/pvnet_summation/training.py @@ -15,9 +15,8 @@ from lightning.pytorch.loggers import Logger from lightning.pytorch.loggers.wandb import WandbLogger from omegaconf import DictConfig, OmegaConf -from tqdm import tqdm - from pvnet import utils +from tqdm import tqdm from pvnet_summation.data.datamodule import PVNetPresavedDataModule @@ -67,45 +66,42 @@ def train(config: DictConfig) -> Optional[float]: # Init lightning model log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate(config.model) - + # Presave batches if config.get("presave_pvnet_outputs", False): - - - save_dir = ( f"{config.datamodule.batch_dir}/" f"{config.model.model_name}/" f"{config.model.model_version}" ) - - - + if os.path.isdir(save_dir): log.info( f"PVNet output directory already exists: {save_dir}\n" "Skipping saving new outputs. The existing saved outputs will be loaded." 
             )
-
+
         else:
             log.info(f"Saving PVNet outputs to {save_dir}")
-
-            os.makedirs(f"{save_dir}/train")
+
+            os.makedirs(f"{save_dir}/train")
             os.makedirs(f"{save_dir}/val")
-
-            # Set batch size to None so batching is skipped
+
+            # Set batch size to None so batching is skipped
             datamodule.batch_size = None

             for dataloader_func, split in [
-                (datamodule.train_dataloader, "train"),
-                (datamodule.val_dataloader, "val")
+                (datamodule.train_dataloader, "train"),
+                (datamodule.val_dataloader, "val"),
             ]:
                 log.info(f"Saving {split} outputs")
                 dataloader = dataloader_func(shuffle=False, add_filename=True)

                 for concurrent_sample_dict in tqdm(dataloader):
                     # Run through model and remove the inputs
-                    pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[0]
+                    pvnet_out = model.predict_pvnet_batch([concurrent_sample_dict["pvnet_inputs"]])[
+                        0
+                    ]
                     del concurrent_sample_dict["pvnet_inputs"]
                     concurrent_sample_dict["pvnet_outputs"] = pvnet_out
@@ -114,14 +110,12 @@ def train(config: DictConfig) -> Optional[float]:
                     sample_rel_path = filepath.removeprefix(config.datamodule.batch_dir)
                     sample_path = f"{save_dir}{sample_rel_path}"
                     torch.save(concurrent_sample_dict, sample_path)
-
-
-
+
         datamodule = PVNetPresavedDataModule(
             batch_dir=save_dir,
-            batch_size=config.datamodule.batch_size,
+            batch_size=config.datamodule.batch_size,
             num_workers=config.datamodule.num_workers,
-            prefetch_factor=config.datamodule.prefetch_factor
+            prefetch_factor=config.datamodule.prefetch_factor,
         )

     # Init lightning loggers
@@ -163,7 +157,6 @@ def train(config: DictConfig) -> Optional[float]:
             OmegaConf.save(config.model, f"{callback.dirpath}/model_config.yaml")
             break

-
     trainer: Trainer = hydra.utils.instantiate(
         config.trainer,
         logger=loggers,
@@ -174,7 +167,6 @@ def train(config: DictConfig) -> Optional[float]:
     # Train the model completely
     trainer.fit(model=model, datamodule=datamodule)

-
     # Make sure everything closed properly
     log.info("Finalizing!")
     utils.finish(
diff --git a/pvnet_summation/utils.py b/pvnet_summation/utils.py
index 711dc3e..e9c7b79 100644
--- a/pvnet_summation/utils.py
+++ b/pvnet_summation/utils.py
@@ -8,7 +8,6 @@

 def plot_forecasts(y, y_hat, times, batch_idx=None, quantiles=None):
     """Plot a batch of data and the forecast from that batch"""
-
     times_utc = times.cpu().numpy().squeeze().astype("datetime64[s]")
     times_utc = [pd.to_datetime(t) for t in times_utc]
     y = y.cpu().numpy()
@@ -25,9 +24,7 @@ def plot_forecasts(y, y_hat, times, batch_idx=None, quantiles=None):
         ax.plot(times_utc[i], y[i], marker=".", color="k", label=r"$y$")

         if quantiles is None:
-            ax.plot(
-                times_utc[i], y_hat[i], marker=".", color="r", label=r"$\hat{y}$"
-            )
+            ax.plot(times_utc[i], y_hat[i], marker=".", color="r", label=r"$\hat{y}$")
         else:
             cm = pylab.get_cmap("twilight")
             for nq, q in enumerate(quantiles):
@@ -57,4 +54,4 @@ def plot_forecasts(y, y_hat, times, batch_idx=None, quantiles=None):
     plt.suptitle(title)
     plt.tight_layout()

-    return fig
\ No newline at end of file
+    return fig
diff --git a/requirements.txt b/requirements.txt
index d80d384..bf8a5e9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,4 +27,4 @@ tqdm
 rich
 omegaconf
 hydra-core
-python-dotenv
\ No newline at end of file
+python-dotenv
diff --git a/run.py b/run.py
index 0c2b3c4..dbcceb2 100644
--- a/run.py
+++ b/run.py
@@ -12,8 +12,8 @@
     pass

 import logging
-import sys
 import os
+import sys

 # Tired of seeing these warnings
 import warnings
@@ -34,9 +34,10 @@ def main(config: DictConfig):
     """Runs training"""

     # Imports should be nested inside @hydra.main to optimize tab completion
     # Read more here: https://github.com/facebookresearch/hydra/issues/934
-    from pvnet_summation.training import train
     from pvnet.utils import extras, print_config

+    from pvnet_summation.training import train
+
     # A couple of optional utilities:
     # - disabling python warnings
     # - easier access to debug mode
diff --git a/tests/conftest.py b/tests/conftest.py
index 1ffb578..dfdffa7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,44 +17,41 @@

 from pvnet_summation.data.datamodule import DataModule

-
 @pytest.fixture()
 def sample_data():
-
     # Copy small batches to fake 317 GSPs in each
     with tempfile.TemporaryDirectory() as tmpdirname:
         os.makedirs(f"{tmpdirname}/train")
         os.makedirs(f"{tmpdirname}/val")
-
+
         # Grab times from batch to make national output zarr
         times = []
-
+
         file_n = 0
         for file in glob.glob("tests/data/sample_batches/train/*.pt"):
-
             batch = torch.load(file)
-
+
             this_batch = {}
             for i in range(batch[BatchKey.gsp_time_utc].shape[0]):
-                # Duplicate sample to fake 317 GSPs
+                # Duplicate sample to fake 317 GSPs
                 for key in batch.keys():
                     if isinstance(batch[key], torch.Tensor):
                         n_dims = len(batch[key].shape)
-                        repeats = (317,) + tuple(1 for dim in range(n_dims-1))
-                        this_batch[key] = batch[key][i:i+1].repeat(repeats)[:317]
+                        repeats = (317,) + tuple(1 for dim in range(n_dims - 1))
+                        this_batch[key] = batch[key][i : i + 1].repeat(repeats)[:317]
                     else:
                         this_batch[key] = batch[key]
-
+
                 # Save for both train and val
                 torch.save(this_batch, f"{tmpdirname}/train/{file_n:06}.pt")
                 torch.save(this_batch, f"{tmpdirname}/val/{file_n:06}.pt")
-
+
                 file_n += 1
                 times += [batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")]
-
+
         times = np.unique(np.sort(np.concatenate(times)))
-
+
         da_output = xr.DataArray(
             data=np.random.uniform(size=(len(times), 1)),
             dims=["datetime_gmt", "gsp_id"],
@@ -63,7 +60,7 @@ def sample_data():
                 gsp_id=[0],
             ),
         )
-
+
         da_cap = xr.DataArray(
             data=np.ones((len(times), 1)),
             dims=["datetime_gmt", "gsp_id"],
@@ -72,7 +69,7 @@ def sample_data():
                 gsp_id=[0],
             ),
         )
-
+
         ds = xr.Dataset(
             data_vars=dict(
                 generation_mw=da_output,
@@ -80,9 +77,9 @@ def sample_data():
                 capacity_mwp=da_cap,
             ),
         )
-
+
         ds.to_zarr(f"{tmpdirname}/gsp.zarr")
-
+
         yield tmpdirname, f"{tmpdirname}/gsp.zarr"
@@ -97,7 +94,7 @@ def sample_datamodule(sample_data):
         num_workers=0,
         prefetch_factor=2,
     )
-
+
     return dm
@@ -111,8 +108,8 @@ def sample_batch(sample_datamodule):
 def model_kwargs():
     kwargs = dict(
         forecast_minutes=480,
-        model_name= "openclimatefix/pvnet_v2",
-        model_version= "898630f3f8cd4e8506525d813dd61c6d8de86144",
+        model_name="openclimatefix/pvnet_v2",
+        model_version="898630f3f8cd4e8506525d813dd61c6d8de86144",
     )
     return kwargs
@@ -126,4 +123,4 @@ def model(model_kwargs):
 @pytest.fixture()
 def quantile_model(model_kwargs):
     model = Model(output_quantiles=[0.1, 0.5, 0.9], **model_kwargs)
-    return model
\ No newline at end of file
+    return model
diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py
index b23c3bc..9aa2c27 100644
--- a/tests/data/test_datamodule.py
+++ b/tests/data/test_datamodule.py
@@ -13,6 +13,7 @@ def test_init(sample_data):
         prefetch_factor=2,
     )

+
 def test_iter(sample_data):
     batch_dir, gsp_zarr_dir = sample_data
@@ -23,20 +24,21 @@ def test_iter(sample_data):
         num_workers=0,
         prefetch_factor=2,
     )
-
+
     batch = next(iter(dm.train_dataloader()))
-
+
     # batch size is 2
-    assert len(batch['pvnet_inputs'])==2
-
+    assert len(batch["pvnet_inputs"]) == 2
+
     # 317 GSPs in each sample
     # 21 timestamps for each GSP from -120 mins to +480 mins
-    assert batch['pvnet_inputs'][0][BatchKey.gsp_time_utc].shape==(317,21)
-
-    assert batch['times'].shape==(2, 16)
-
-    assert batch['national_targets'].shape==(2, 16)
-
+    assert batch["pvnet_inputs"][0][BatchKey.gsp_time_utc].shape == (317, 21)
+
+    assert batch["times"].shape == (2, 16)
+
+    assert batch["national_targets"].shape == (2, 16)
+
+
 def test_iter_multiprocessing(sample_data):
     batch_dir, gsp_zarr_dir = sample_data
@@ -47,15 +49,15 @@ def test_iter_multiprocessing(sample_data):
         num_workers=2,
         prefetch_factor=2,
     )
-
+
     for batch in dm.train_dataloader():
         # batch size is 2
-        assert len(batch['pvnet_inputs'])==2
+        assert len(batch["pvnet_inputs"]) == 2

         # 317 GSPs in each sample
         # 21 timestamps for each GSP from -120 mins to +480 mins
-        assert batch['pvnet_inputs'][0][BatchKey.gsp_time_utc].shape==(317,21)
+        assert batch["pvnet_inputs"][0][BatchKey.gsp_time_utc].shape == (317, 21)

-        assert batch['times'].shape==(2, 16)
+        assert batch["times"].shape == (2, 16)

-        assert batch['national_targets'].shape==(2, 16)
\ No newline at end of file
+        assert batch["national_targets"].shape == (2, 16)
diff --git a/tests/test_end2end.py b/tests/test_end2end.py
index 856f883..b054d01 100644
--- a/tests/test_end2end.py
+++ b/tests/test_end2end.py
@@ -3,4 +3,4 @@

 def test_model_trainer_fit(model, sample_datamodule):
     trainer = lightning.pytorch.trainer.trainer.Trainer(fast_dev_run=True)
-    trainer.fit(model=model, datamodule=sample_datamodule)
\ No newline at end of file
+    trainer.fit(model=model, datamodule=sample_datamodule)
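
Usage note: the tests touched above double as a reference for driving the reformatted DataModule by hand. A minimal sketch follows, assuming presaved concurrent batches and a GSP zarr are available; both paths are placeholders, and the keyword arguments mirror tests/data/test_datamodule.py:

    # Illustrative sketch only: both paths below are placeholders, not real data locations.
    from pvnet_summation.data.datamodule import DataModule

    dm = DataModule(
        batch_dir="/path/to/concurrent_batches",  # must contain train/ and val/ subdirs of *.pt batches
        gsp_zarr_path="/path/to/pv_gsp.zarr",  # national GSP zarr used to build the targets
        batch_size=2,
        num_workers=0,  # in-process loading; raise for multiprocessing
        prefetch_factor=2,
    )

    # Each batch is a dict with "pvnet_inputs", "national_targets" and "times" keys
    batch = next(iter(dm.train_dataloader()))
    print(batch["national_targets"].shape)  # (2, 16) for batch_size=2 and a 480 minute horizon

As in test_iter, num_workers=0 keeps everything in one process, which is the easiest way to smoke-test a new batch directory before scaling up the worker count.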