diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py index 7c1b4a45..c0f3476a 100644 --- a/pvnet/data/wind_datamodule.py +++ b/pvnet/data/wind_datamodule.py @@ -95,7 +95,9 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False): filenames=list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")), ) data_pipeline = ( - data_pipeline.batch(self.batch_size).map(stack_np_examples_into_batch).map(batch_to_tensor) + data_pipeline.batch(self.batch_size) + .map(stack_np_examples_into_batch) + .map(batch_to_tensor) ) if shuffle: data_pipeline = (