diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py index 425143be..c4773e27 100644 --- a/pvnet/data/wind_datamodule.py +++ b/pvnet/data/wind_datamodule.py @@ -1,4 +1,5 @@ """ Data module for pytorch lightning """ +import glob from datetime import datetime import numpy as np @@ -10,7 +11,6 @@ from torch.utils.data import DataLoader from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe -import glob def copy_batch_to_device(batch, device): @@ -98,7 +98,6 @@ def __init__( self.batch_size = batch_size self.batch_dir = batch_dir - if batch_dir is not None: if any([period != [None, None] for period in [train_period, val_period, test_period]]): raise ValueError("Cannot set `(train/val/test)_period` with presaved batches") @@ -131,7 +130,7 @@ def __init__( def _get_datapipe(self, start_time, end_time): data_pipeline = windnet_netcdf_datapipe( self.configuration, - keys=["sensor","nwp"], + keys=["sensor", "nwp"], ) data_pipeline = ( @@ -142,9 +141,11 @@ def _get_datapipe(self, start_time, end_time): return data_pipeline def _get_premade_batches_datapipe(self, subdir, shuffle=False): - data_pipeline = windnet_netcdf_datapipe(config_filename=self.configuration, - keys=["sensor","nwp"], - filenames=list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))) + data_pipeline = windnet_netcdf_datapipe( + config_filename=self.configuration, + keys=["sensor", "nwp"], + filenames=list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")), + ) if shuffle: data_pipeline = ( data_pipeline.shuffle(buffer_size=100)