Skip to content

Commit

Permalink
[data_loader] Optionally also propagate set_epoch to batch sampler (#3246)
Browse files (browse the repository at this point in the history)

* Optionally also propagate set_epoch to batch sampler

* Add simple batch sampler set_epoch test
  • Loading branch information
tomaarsen authored Nov 20, 2024
1 parent d7b1b36 commit 77f2b62
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,8 @@ def set_epoch(self, epoch: int):
# In case it is manually passed in, the user can set it to what they like
if self.iteration != epoch:
self.iteration = epoch
if hasattr(self.batch_sampler, "set_epoch"):
self.batch_sampler.set_epoch(epoch)
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
self.batch_sampler.sampler.set_epoch(epoch)
# We support if a custom `Dataset` implementation has `set_epoch`
Expand Down
29 changes: 29 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,21 @@ def set_epoch(self, epoch):
self.epoch = epoch


class SimpleBatchSampler(BatchSampler):
    """Batch sampler test double that records the epoch handed to `set_epoch`.

    Besides the regular `BatchSampler` arguments it keeps a generator and a
    base seed; every call to `__iter__` reseeds the generator with
    `seed + epoch` before delegating to `BatchSampler.__iter__`.

    Args:
        sampler: Passed through to `BatchSampler`.
        batch_size: Passed through to `BatchSampler`.
        drop_last: Passed through to `BatchSampler`.
        generator: Object exposing `manual_seed(int)`; reseeded on every
            iteration.
        seed: Base seed; the current epoch is added to it when reseeding.
    """

    def __init__(self, sampler, batch_size, drop_last, generator, seed):
        super().__init__(sampler, batch_size, drop_last)
        self.generator = generator
        self.seed = seed
        # Starts at 0 and only changes through `set_epoch`.
        self.epoch = 0

    def __iter__(self):
        # Reseed first so the generator state depends on the current epoch.
        self.generator.manual_seed(self.seed + self.epoch)
        return super().__iter__()

    def set_epoch(self, epoch: int) -> None:
        """Store `epoch`; it is used the next time `__iter__` reseeds."""
        self.epoch = epoch


class DataLoaderTester(unittest.TestCase):
def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):
batch_sampler_shards = [
Expand Down Expand Up @@ -469,6 +484,20 @@ def test_end_of_dataloader_dispatcher(self):
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)

def test_set_epoch_in_batch_sampler(self):
    """Check that `set_epoch` on a prepared dataloader reaches the batch sampler."""
    # Ensure that set_epoch gets propagated to custom batch samplers that accept it
    dataset = list(range(16))
    generator = torch.Generator()
    batch_sampler = SimpleBatchSampler(dataset, batch_size=4, drop_last=False, generator=generator, seed=12)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

    accelerator = Accelerator()
    dataloader = accelerator.prepare_data_loader(dataloader)

    # The sampler starts at epoch 0; only the call on the prepared dataloader
    # below should change it.
    assert batch_sampler.epoch == 0
    dataloader.set_epoch(1)
    assert batch_sampler.epoch == 1


class StatefulDataLoaderTester(unittest.TestCase):
@require_torchdata_stateful_dataloader
Expand Down

0 comments on commit 77f2b62

Please sign in to comment.