Make SeedableRandomSampler the default always (#2117)
* Fix tests

* Simplify logic a ~lot~
muellerzr authored Nov 3, 2023
1 parent bd72a5f commit 820fc4c
Showing 2 changed files with 17 additions and 25 deletions.
22 changes: 8 additions & 14 deletions src/accelerate/data_loader.py
@@ -80,19 +80,13 @@ def __init__(self, *args, **kwargs):
         self.epoch = 0
 
     def __iter__(self):
-        g = torch.Generator()
-        if self.generator is not None:
-            seed = self.epoch + self.generator.initial_seed()
-        else:
-            seed = self.epoch
-        g.manual_seed(seed)
-        n = len(self.data_source)
-        # Taken 1:1 from torch.utils.data.sampler.RandomSampler.__iter__
-        if self.replacement:
-            for _ in range(self.num_samples // 32):
-                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
-        else:
-            yield from torch.randperm(n, generator=g).tolist()
+        if self.generator is None:
+            self.generator = torch.Generator()
+        # Allow `self.epoch` to modify the seed of the generator
+        seed = self.epoch + self.generator.initial_seed()
+        self.generator.manual_seed(seed)
+        yield from super().__iter__()
+        self.set_epoch(self.epoch + 1)
 
     def set_epoch(self, epoch: int):
         "Sets the current iteration of the sampler."
@@ -836,7 +830,7 @@ def prepare_data_loader(
         sampler = getattr(dataloader.sampler, "sampler", None)
     else:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
-    if isinstance(sampler, RandomSampler) and num_processes > 1:
+    if isinstance(sampler, RandomSampler):
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
         # samples in the same order if a seed is set. This requires a tweak
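
Dropping the `num_processes > 1` guard is what makes the seedable sampler the default everywhere: any `DataLoader` built with `shuffle=True` carries a plain `RandomSampler`, and `prepare_data_loader` now swaps it unconditionally. A hedged sketch of how that sampler is reached on a stock `DataLoader`, using a toy dataset and standard PyTorch only:

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# shuffle=True installs a RandomSampler under the BatchSampler; this mirrors the
# getattr lookup above and is the sampler the swap now always targets.
sampler = getattr(loader.batch_sampler, "sampler", None)
print(isinstance(sampler, RandomSampler))  # True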
20 changes: 9 additions & 11 deletions src/accelerate/test_utils/scripts/test_script.py
@@ -339,17 +339,15 @@ def mock_training(length, batch_size, generator):
     set_seed(42)
     generator.manual_seed(42)
     train_set = RegressionDataset(length=length, seed=42)
-    if AcceleratorState().num_processes > 1:
-        # The SeedableRandomSampler is needed during distributed setups
-        # for full reproducibility across processes with the `DataLoader`
-        sampler = SeedableRandomSampler(
-            generator=generator,
-            data_source=train_set,
-            num_samples=len(train_set),
-        )
-        train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
-    else:
-        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
+
+    # The SeedableRandomSampler is needed during distributed setups
+    # for full reproducibility across processes with the `DataLoader`
+    sampler = SeedableRandomSampler(
+        generator=generator,
+        data_source=train_set,
+        num_samples=len(train_set),
+    )
+    train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
     model = RegressionModel()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     for epoch in range(3):
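
With the branch gone, the mock training helper builds a `SeedableRandomSampler` even in single-process runs. The underlying issue is that a plain `RandomSampler` keeps drawing from its generator across epochs, so its ordering depends on how many passes have already happened; a quick illustration with standard PyTorch (an assumption about the motivation, showing the drift that per-epoch reseeding avoids):

import torch
from torch.utils.data import RandomSampler, TensorDataset

data = TensorDataset(torch.arange(6))
g = torch.Generator()
g.manual_seed(42)
sampler = RandomSampler(data, generator=g)

# Each pass advances the shared generator, so the second epoch's order differs
# and depends on everything drawn before it; per-epoch reseeding removes that.
print(list(sampler))  # epoch 0 ordering
print(list(sampler))  # epoch 1 ordering, different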