Skip to content

Commit

Permalink
fix shuffling
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Jun 20, 2024
1 parent 061a1ed commit 87187b1
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions pvnet_summation/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def __init__(
self.batch_dir = batch_dir

self._common_dataloader_kwargs = dict(
shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
Expand All @@ -164,7 +163,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False, add_filename=Fals
file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

if shuffle:
file_pipeline = file_pipeline.shuffle(buffer_size=1000)
file_pipeline = file_pipeline.shuffle(buffer_size=10_000)

file_pipeline = file_pipeline.sharding_filter()

Expand Down Expand Up @@ -228,14 +227,14 @@ def train_dataloader(self, shuffle=True, add_filename=False):
datapipe = self._get_premade_batches_datapipe(
"train", shuffle=shuffle, add_filename=add_filename
)
return DataLoader(datapipe, **self._common_dataloader_kwargs)
return DataLoader(datapipe, shuffle=shuffle, **self._common_dataloader_kwargs)

def val_dataloader(self, shuffle=False, add_filename=False):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe(
"val", shuffle=shuffle, add_filename=add_filename
)
return DataLoader(datapipe, **self._common_dataloader_kwargs)
return DataLoader(datapipe, shuffle=shuffle, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
Expand Down Expand Up @@ -265,7 +264,6 @@ def __init__(
self.batch_dir = batch_dir

self._common_dataloader_kwargs = dict(
shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
Expand All @@ -284,7 +282,7 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
file_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)

if shuffle:
file_pipeline = file_pipeline.shuffle(buffer_size=1000)
file_pipeline = file_pipeline.shuffle(buffer_size=10_000)

sample_pipeline = file_pipeline.sharding_filter().map(torch.load)

Expand All @@ -300,21 +298,21 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):

return batch_pipeline

def train_dataloader(self, shuffle=True):
def train_dataloader(self):
"""Construct train dataloader"""
datapipe = self._get_premade_batches_datapipe(
"train",
shuffle=shuffle,
shuffle=True,
)
return DataLoader(datapipe, **self._common_dataloader_kwargs)
return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs)

def val_dataloader(self, shuffle=False):
def val_dataloader(self):
"""Construct val dataloader"""
datapipe = self._get_premade_batches_datapipe(
"val",
shuffle=shuffle,
shuffle=False,
)
return DataLoader(datapipe, **self._common_dataloader_kwargs)
return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
Expand Down

0 comments on commit 87187b1

Please sign in to comment.