From f8ba1f7000f1b8d82e6f8f0e0a78ae2ec75a2f5e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 21 Nov 2023 15:51:42 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pvnet/data/datamodule.py           | 29 ++++++++++++-----------------
 scripts/save_batches.py            | 23 +++++++++++------------
 scripts/save_concurrent_batches.py | 13 ++++++-------
 3 files changed, 29 insertions(+), 36 deletions(-)

diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py
index 83f10487..bef05f6e 100644
--- a/pvnet/data/datamodule.py
+++ b/pvnet/data/datamodule.py
@@ -3,16 +3,13 @@
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader
-from torch.utils.data.datapipes.datapipe import IterDataPipe
-from torch.utils.data.datapipes._decorator import functional_datapipe
-
 from lightning.pytorch import LightningDataModule
-
 from ocf_datapipes.training.pvnet import pvnet_datapipe
 from ocf_datapipes.utils.consts import BatchKey
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
-
+from torch.utils.data import DataLoader
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
 
 
 def copy_batch_to_device(batch, device):
@@ -120,20 +117,20 @@ def __init__(
         self.test_period = [
             None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period
         ]
-        
+
         self._common_dataloader_kwargs = dict(
-            shuffle=False, # shuffled in datapipe step
-            batch_size=None, # batched in datapipe step
+            shuffle=False,  # shuffled in datapipe step
+            batch_size=None,  # batched in datapipe step
             sampler=None,
-            batch_sampler=None, 
-            num_workers=num_workers, 
+            batch_sampler=None,
+            num_workers=num_workers,
             collate_fn=None,
-            pin_memory=False, 
-            drop_last=False, 
+            pin_memory=False,
+            drop_last=False,
             timeout=0,
             worker_init_fn=None,
             prefetch_factor=prefetch_factor,
-            persistent_workers=False
+            persistent_workers=False,
         )
 
     def _get_datapipe(self, start_time, end_time):
@@ -184,7 +181,7 @@ def train_dataloader(self):
             datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
         else:
             datapipe = self._get_datapipe(*self.train_period)
-        rs = MultiProcessingReadingService(**self.readingservice_config)
+        MultiProcessingReadingService(**self.readingservice_config)
         return DataLoader(datapipe, **self._common_dataloader_kwargs)
 
     def val_dataloader(self):
@@ -195,7 +192,6 @@ def val_dataloader(self):
             datapipe = self._get_datapipe(*self.val_period)
         return DataLoader(datapipe, **self._common_dataloader_kwargs)
 
-
     def test_dataloader(self):
         """Construct test dataloader"""
         if self.batch_dir is not None:
@@ -203,4 +199,3 @@ def test_dataloader(self):
         else:
             datapipe = self._get_datapipe(*self.test_period)
         return DataLoader(datapipe, **self._common_dataloader_kwargs)
-
diff --git a/scripts/save_batches.py b/scripts/save_batches.py
index a3526a76..e5d5ba49 100644
--- a/scripts/save_batches.py
+++ b/scripts/save_batches.py
@@ -82,7 +82,7 @@ def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader
 def main(config: DictConfig):
     "Constructs and saves validation and training batches."
     config_dm = config.datamodule
-    
+
     print_config(config, resolve=False)
 
     # Set up directory
@@ -98,20 +98,20 @@ def main(config: DictConfig):
 
     dataloader_kwargs = dict(
         shuffle=False,
-        batch_size=None, # batched in datapipe step
+        batch_size=None,  # batched in datapipe step
         sampler=None,
-        batch_sampler=None, 
-        num_workers=config_dm.num_workers, 
+        batch_sampler=None,
+        num_workers=config_dm.num_workers,
         collate_fn=None,
-        pin_memory=False, 
-        drop_last=False, 
+        pin_memory=False,
+        drop_last=False,
         timeout=0,
         worker_init_fn=None,
         prefetch_factor=config_dm.prefetch_factor,
-        persistent_workers=False
+        persistent_workers=False,
     )
-    
-    if config.num_val_batches>0:
+
+    if config.num_val_batches > 0:
         print("----- Saving val batches -----")
 
         val_batch_pipe = _get_datapipe(
@@ -126,9 +126,8 @@ def main(config: DictConfig):
             num_batches=config.num_val_batches,
             dataloader_kwargs=dataloader_kwargs,
         )
 
-    
-    if config.num_train_batches>0:
+    if config.num_train_batches > 0:
         print("----- Saving train batches -----")
 
         train_batch_pipe = _get_datapipe(
@@ -148,4 +147,4 @@
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/scripts/save_concurrent_batches.py b/scripts/save_concurrent_batches.py
index 7fe22a38..72ded577 100644
--- a/scripts/save_concurrent_batches.py
+++ b/scripts/save_concurrent_batches.py
@@ -161,20 +161,19 @@ def main(config: DictConfig):
     os.mkdir(f"{config.batch_output_dir}/train")
     os.mkdir(f"{config.batch_output_dir}/val")
-    
 
     dataloader_kwargs = dict(
         shuffle=False,
-        batch_size=None, # batched in datapipe step
+        batch_size=None,  # batched in datapipe step
         sampler=None,
-        batch_sampler=None, 
-        num_workers=config_dm.num_workers, 
+        batch_sampler=None,
+        num_workers=config_dm.num_workers,
         collate_fn=None,
-        pin_memory=False, 
-        drop_last=False, 
+        pin_memory=False,
+        drop_last=False,
         timeout=0,
         worker_init_fn=None,
         prefetch_factor=config_dm.prefetch_factor,
-        persistent_workers=False
+        persistent_workers=False,
     )
 
     print("----- Saving val batches -----")
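
Note on the pattern these hunks reformat (an explanatory sketch, not part of the patch): all three files hand a datapipe to a plain DataLoader with shuffle=False and batch_size=None because shuffling and batching already happen inside the datapipe. A minimal, self-contained illustration of that division of labour follows; the IterableWrapper toy pipe stands in for the real pvnet datapipe, and num_workers is set to 0 here (so prefetch_factor must be left unset), both assumptions made for the sketch:

    from torch.utils.data import DataLoader
    from torch.utils.data.datapipes.iter import IterableWrapper

    # Toy stand-in for the pvnet datapipe: it already yields batches of 4
    # items, so the DataLoader must not batch or shuffle again.
    datapipe = IterableWrapper(range(16)).batch(4)

    dataloader_kwargs = dict(
        shuffle=False,  # shuffled in datapipe step
        batch_size=None,  # batched in datapipe step
        sampler=None,
        batch_sampler=None,
        num_workers=0,  # single-process for this sketch; prefetch_factor stays unset
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        persistent_workers=False,
    )

    for batch in DataLoader(datapipe, **dataloader_kwargs):
        print(batch)  # each element arrives as a ready-made batch of 4

With batch_size=None the DataLoader's automatic batching is disabled and each worker simply drains the pipeline, so multiprocessing parallelises the whole datapipe rather than just collation.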