Propagate `set_epoch` to custom batch samplers.

The data loader's `set_epoch` already forwards the epoch to
`batch_sampler.sampler` when that object defines `set_epoch`. This patch
also forwards it to the batch sampler itself when the batch sampler
defines `set_epoch`.

The test changes add a `SimpleBatchSampler` helper that records the epoch
it is given, plus `test_set_epoch_in_batch_sampler`, which checks that
calling `set_epoch(1)` on the prepared data loader updates the helper's
`epoch` attribute from 0 to 1.

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index bf3f35fb7e8..36be770dbef 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -588,6 +588,8 @@ def set_epoch(self, epoch: int):
         # In case it is manually passed in, the user can set it to what they like
         if self.iteration != epoch:
             self.iteration = epoch
+        if hasattr(self.batch_sampler, "set_epoch"):
+            self.batch_sampler.set_epoch(epoch)
         if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
             self.batch_sampler.sampler.set_epoch(epoch)
         # We support if a custom `Dataset` implementation has `set_epoch`
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py
index 4ffbf29c134..658c9806cbf 100644
--- a/tests/test_data_loader.py
+++ b/tests/test_data_loader.py
@@ -80,6 +80,21 @@ def set_epoch(self, epoch):
         self.epoch = epoch
 
 
+class SimpleBatchSampler(BatchSampler):
+    def __init__(self, sampler, batch_size, drop_last, generator, seed):
+        super().__init__(sampler, batch_size, drop_last)
+        self.generator = generator
+        self.seed = seed
+        self.epoch = 0
+
+    def __iter__(self):
+        self.generator.manual_seed(self.seed + self.epoch)
+        return super().__iter__()
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
 class DataLoaderTester(unittest.TestCase):
     def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):
         batch_sampler_shards = [
@@ -469,6 +484,20 @@ def test_end_of_dataloader_dispatcher(self):
         for idx, _ in enumerate(dataloader):
             assert dataloader.end_of_dataloader == (idx == 3)
 
+    def test_set_epoch_in_batch_sampler(self):
+        # Ensure that set_epoch gets propagated to custom batch samplers that accept it
+        dataset = list(range(16))
+        generator = torch.Generator()
+        batch_sampler = SimpleBatchSampler(dataset, batch_size=4, drop_last=False, generator=generator, seed=12)
+        dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+
+        accelerator = Accelerator()
+        dataloader = accelerator.prepare_data_loader(dataloader)
+
+        assert batch_sampler.epoch == 0
+        dataloader.set_epoch(1)
+        assert batch_sampler.epoch == 1
+
 
 class StatefulDataLoaderTester(unittest.TestCase):
     @require_torchdata_stateful_dataloader