diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index 64ba4e8..23279c0 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -152,10 +152,13 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals if shuffle: file_pipeline = file_pipeline.shuffle(buffer_size=1000) + + file_pipeline = file_pipeline.sharding_filter() + if add_filename: - file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=50) + file_pipeline, file_pipeline_copy = file_pipeline.fork(2, buffer_size=5) - sample_pipeline = file_pipeline.sharding_filter().map(torch.load) + sample_pipeline = file_pipeline.map(torch.load) # Find national outout simultaneous to concurrent samples gsp_data = (