Always use SeedableRandomSampler (#2110)
* Fix tests fully

* Change comment

* Further comments

* Clean

* CPU specific

* Just use device

* Rewrite differently

* Rewrite
muellerzr authored Nov 1, 2023
1 parent 5b3f3b9 commit d8e1285
Showing 1 changed file with 17 additions and 12 deletions.
src/accelerate/data_loader.py (29 changes: 17 additions & 12 deletions)
@@ -476,7 +476,10 @@ def set_epoch(self, epoch: int):
         # In case it is manually passed in, the user can set it to what they like
         if self.iteration != epoch:
             self.iteration = epoch
-        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+        if hasattr(self.batch_sampler, "set_epoch"):
+            # Case: `SkipBatchSampler`
+            self.batch_sampler.set_epoch(epoch)
+        elif hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
             self.batch_sampler.sampler.set_epoch(epoch)
         # We support if a custom `Dataset` implementation has `set_epoch`
         # or in general HF datasets `Datasets`
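
The new branch prefers a `set_epoch` method on the batch sampler itself (the `SkipBatchSampler` case) and only falls back to reaching into `.sampler` when the wrapper does not expose one. Below is a minimal sketch of that duck-typed dispatch; `EpochAwareBatchSampler` and `propagate_epoch` are hypothetical names used for illustration only, not part of accelerate.

# Hypothetical illustration (not accelerate's actual SkipBatchSampler): a wrapper
# batch sampler that exposes `set_epoch` and forwards it to the sampler it wraps.
from torch.utils.data import BatchSampler


class EpochAwareBatchSampler(BatchSampler):
    def set_epoch(self, epoch: int):
        # Forward the epoch to the wrapped sampler if it supports it.
        if hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)


def propagate_epoch(batch_sampler, epoch: int):
    # Same dispatch order as the patched `set_epoch` above: try the batch sampler
    # first, then fall back to its inner `.sampler`.
    if hasattr(batch_sampler, "set_epoch"):
        batch_sampler.set_epoch(epoch)
    elif hasattr(batch_sampler, "sampler") and hasattr(batch_sampler.sampler, "set_epoch"):
        batch_sampler.sampler.set_epoch(epoch)
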
@@ -836,17 +839,19 @@ def prepare_data_loader(
         sampler = getattr(dataloader.sampler, "sampler", None)
     else:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
-    if isinstance(sampler, RandomSampler) and num_processes > 1:
-        # When iterating through the dataloader during distributed processes
-        # we want to ensure that on each process we are iterating through the same
-        # samples in the same order if a seed is set. This requires a tweak
-        # to the `torch.utils.data.RandomSampler` class (if used).
-        sampler = SeedableRandomSampler(
-            data_source=sampler.data_source,
-            replacement=sampler.replacement,
-            num_samples=sampler._num_samples,
-            generator=getattr(sampler, "generator", torch.Generator()),
-        )
+    if isinstance(sampler, RandomSampler):
+        # CPU's specifically do not require this workaround
+        if not ((num_processes == 1) and (device.type == "cpu")):
+            # When iterating through the dataloader we want to ensure that
+            # on each process we are iterating through the same
+            # samples in the same order if a seed is set. This requires a tweak
+            # to the `torch.utils.data.RandomSampler` class (if used).
+            sampler = SeedableRandomSampler(
+                data_source=sampler.data_source,
+                replacement=sampler.replacement,
+                num_samples=sampler._num_samples,
+                generator=getattr(sampler, "generator", torch.Generator()),
+            )
 
     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
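
With this change a plain `RandomSampler` is swapped for a `SeedableRandomSampler` whenever one is found, except for single-process CPU runs, so that each process draws the same shuffled order for a given epoch once a seed is set. The sketch below illustrates the seeding idea only; it is a hypothetical stand-in, not accelerate's actual `SeedableRandomSampler` implementation, and assumes all processes start from the same initial torch seed (e.g. via `set_seed`).

# Hypothetical sketch of an epoch-seeded random sampler: re-seed the generator
# deterministically every epoch so identically-seeded processes shuffle alike.
import torch
from torch.utils.data import RandomSampler


class SketchSeedableRandomSampler(RandomSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epoch = 0
        self.initial_seed = torch.random.initial_seed()

    def set_epoch(self, epoch: int):
        # Called once per epoch (see the `set_epoch` hunk above).
        self.epoch = epoch

    def __iter__(self):
        # Deterministic seed per epoch: processes that started from the same
        # torch seed produce the same permutation of the dataset.
        self.generator = torch.Generator()
        self.generator.manual_seed(self.initial_seed + self.epoch)
        yield from super().__iter__()

The call site in the hunk above then only has to copy `data_source`, `replacement`, `num_samples`, and `generator` over from the original sampler.
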
