diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index 14d85ca..498ab0c 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -145,7 +145,6 @@ def __init__( self.batch_dir = batch_dir self._common_dataloader_kwargs = dict( - shuffle=False, # shuffled in datapipe step batch_size=None, # batched in datapipe step sampler=None, batch_sampler=None, @@ -164,7 +163,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False) if shuffle: - file_pipeline = file_pipeline.shuffle(buffer_size=1000) + file_pipeline = file_pipeline.shuffle(buffer_size=10_000) file_pipeline = file_pipeline.sharding_filter() @@ -228,14 +227,14 @@ def train_dataloader(self, shuffle=True, add_filename=False): datapipe = self._get_premade_batches_datapipe( "train", shuffle=shuffle, add_filename=add_filename ) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=shuffle, **self._common_dataloader_kwargs) def val_dataloader(self, shuffle=False, add_filename=False): """Construct val dataloader""" datapipe = self._get_premade_batches_datapipe( "val", shuffle=shuffle, add_filename=add_filename ) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=shuffle, **self._common_dataloader_kwargs) def test_dataloader(self): """Construct test dataloader""" @@ -265,7 +264,6 @@ def __init__( self.batch_dir = batch_dir self._common_dataloader_kwargs = dict( - shuffle=False, # shuffled in datapipe step batch_size=None, # batched in datapipe step sampler=None, batch_sampler=None, @@ -284,7 +282,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False): file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False) if shuffle: - file_pipeline = file_pipeline.shuffle(buffer_size=1000) + file_pipeline = file_pipeline.shuffle(buffer_size=10_000) sample_pipeline = file_pipeline.sharding_filter().map(torch.load) @@ -300,21 +298,21 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False): return batch_pipeline - def train_dataloader(self, shuffle=True): + def train_dataloader(self): """Construct train dataloader""" datapipe = self._get_premade_batches_datapipe( "train", - shuffle=shuffle, + shuffle=True, ) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs) - def val_dataloader(self, shuffle=False): + def val_dataloader(self): """Construct val dataloader""" datapipe = self._get_premade_batches_datapipe( "val", - shuffle=shuffle, + shuffle=False, ) - return DataLoader(datapipe, **self._common_dataloader_kwargs) + return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs) def test_dataloader(self): """Construct test dataloader"""