diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index 5c348940..35aad131 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -7,6 +7,7 @@ from ocf_datapipes.utils.utils import stack_np_examples_into_batch from torch.utils.data import DataLoader from torch.utils.data.datapipes.iter import FileLister + from pvnet.data.utils import batch_to_tensor diff --git a/pvnet/data/utils.py b/pvnet/data/utils.py index c85ce40c..00cdeb78 100644 --- a/pvnet/data/utils.py +++ b/pvnet/data/utils.py @@ -1,7 +1,7 @@ import numpy as np import torch from ocf_datapipes.utils.consts import BatchKey -from torch.utils.data import functional_datapipe, IterDataPipe +from torch.utils.data import IterDataPipe, functional_datapipe def copy_batch_to_device(batch, device): diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py index 58ad42da..bb3896f7 100644 --- a/pvnet/data/wind_datamodule.py +++ b/pvnet/data/wind_datamodule.py @@ -1,11 +1,11 @@ """ Data module for pytorch lightning """ +import glob from datetime import datetime from lightning.pytorch import LightningDataModule from ocf_datapipes.training.windnet import windnet_netcdf_datapipe from ocf_datapipes.utils.utils import stack_np_examples_into_batch from torch.utils.data import DataLoader -import glob from pvnet.data.utils import batch_to_tensor @@ -46,7 +46,6 @@ def __init__( self.batch_size = batch_size self.batch_dir = batch_dir - if batch_dir is not None: if any([period != [None, None] for period in [train_period, val_period, test_period]]): raise ValueError("Cannot set `(train/val/test)_period` with presaved batches") @@ -79,7 +78,7 @@ def __init__( def _get_datapipe(self, start_time, end_time): data_pipeline = windnet_netcdf_datapipe( self.configuration, - keys=["sensor","nwp"], + keys=["sensor", "nwp"], ) data_pipeline = ( @@ -90,9 +89,11 @@ def _get_datapipe(self, start_time, end_time): return data_pipeline def _get_premade_batches_datapipe(self, subdir, shuffle=False): - data_pipeline = windnet_netcdf_datapipe(config_filename=self.configuration, - keys=["sensor","nwp"], - filenames=list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))) + data_pipeline = windnet_netcdf_datapipe( + config_filename=self.configuration, + keys=["sensor", "nwp"], + filenames=list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")), + ) if shuffle: data_pipeline = ( data_pipeline.shuffle(buffer_size=100) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index f5290517..eea38c94 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -2,6 +2,7 @@ from pvnet.data.wind_datamodule import WindDataModule import os + def test_init(): dm = DataModule( configuration=None,