Fix seedable sampler logic and expound docs (#2434)
* Fix and add more docs

* Add tests + ensure working

* Fixup all tests!
muellerzr authored Feb 13, 2024
1 parent ecebfa1 commit ef68b46
Showing 3 changed files with 62 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/accelerate/accelerator.py
@@ -221,7 +221,7 @@ class Accelerator:
Whether or not to use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
also be run with [`~utils.set_seed`] each time for the best results.
step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):
Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
done under certain circumstances (at the end of each epoch, for instance).
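The docstring above recommends pairing `use_seedable_sampler=True` with `set_seed`; a minimal sketch of that pattern (illustrative only, with a toy dataset and model, not code from this commit) also calls `set_epoch` each epoch, as the tests added below do:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from accelerate.utils import set_seed

# Illustrative only: toy data and model, not from this commit.
set_seed(42)
dataset = TensorDataset(torch.randn(64, 3), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accelerator = Accelerator(use_seedable_sampler=True)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for epoch in range(3):
    set_seed(42)                 # re-seed each epoch, as the docstring recommends
    dataloader.set_epoch(epoch)  # the seedable sampler shuffles with `epoch + initial_seed`
    for x, y in dataloader:
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()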
22 changes: 12 additions & 10 deletions src/accelerate/data_loader.py
@@ -78,15 +78,16 @@ class SeedableRandomSampler(RandomSampler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.epoch = 0
self.seed = torch.random.initial_seed()
self.initial_seed = torch.random.initial_seed()

def __iter__(self):
if self.generator is None:
self.generator = torch.Generator()
else:
self.seed = self.generator.initial_seed()
self.generator.manual_seed(self.initial_seed)

# Allow `self.epoch` to modify the seed of the generator
seed = self.epoch + self.seed
seed = self.epoch + self.initial_seed
# print("Setting seed at epoch", self.epoch, seed)
self.generator.manual_seed(seed)
yield from super().__iter__()
self.set_epoch(self.epoch + 1)
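The fix above keys the generator off `self.epoch + self.initial_seed`, so each epoch gets a different but deterministic shuffle for a given seed. A standalone sketch of that idea in plain torch (not the sampler itself):

import torch

INITIAL_SEED = 42

def epoch_permutation(epoch, n=8):
    # Same idea as the sampler above: derive the epoch's shuffle from a fixed base seed.
    g = torch.Generator()
    g.manual_seed(INITIAL_SEED + epoch)
    return torch.randperm(n, generator=g).tolist()

print(epoch_permutation(0))  # identical on every run for a given INITIAL_SEED
print(epoch_permutation(1))  # also deterministic, but in general a different order
assert epoch_permutation(0) == epoch_permutation(0)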
@@ -809,7 +810,8 @@ def prepare_data_loader(
use_seedable_sampler (`bool`, *optional*, defaults to `False`):
Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
reproducibility. Comes at a cost of potentially different performance due to different shuffling
algorithms but ensures results will be the *exact* same.
algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
`self.set_epoch` call.
Returns:
`torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
@@ -927,11 +929,6 @@ def prepare_data_loader(
kwargs["batch_size"] = (
dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
)
if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
if sampler_is_batch_sampler:
dataloader.sampler.sampler = sampler
else:
dataloader.batch_sampler.sampler = sampler
if dispatch_batches:
kwargs.pop("generator")
dataloader = DataLoaderDispatcher(
@@ -964,6 +961,11 @@
**kwargs,
)

if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
if sampler_is_batch_sampler:
dataloader.sampler.sampler = sampler
else:
dataloader.batch_sampler.sampler = sampler
if state.distributed_type == DistributedType.TPU:
return MpDeviceLoaderWrapper(dataloader, device)
return dataloader
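In the hunk above, the seedable sampler is now attached after the new `DataLoaderShard`/`DataLoaderDispatcher` is constructed, rather than to the original dataloader beforehand. The reassignment works because a `DataLoader`'s batch sampler keeps a live reference to its sampler; a small standalone illustration (not code from this commit):

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# A DataLoader built without an explicit sampler defaults to sequential order.
ds = TensorDataset(torch.arange(8))
dl = DataLoader(ds, batch_size=4)

# Swapping `batch_sampler.sampler` after construction changes which indices
# later iterations draw; the same kind of reassignment the hunk above performs.
dl.batch_sampler.sampler = RandomSampler(ds, generator=torch.Generator().manual_seed(0))
print([batch[0].tolist() for batch in dl])  # shuffled, reproducibly, via the swapped-in sampler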
76 changes: 49 additions & 27 deletions src/accelerate/test_utils/scripts/test_script.py
@@ -48,6 +48,21 @@
from accelerate.test_utils import RegressionModel


def generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler=False):
"Creates a dataloader that can also use the `SeedableRandomSampler`"
if use_seedable_sampler:
# The SeedableRandomSampler is needed during distributed setups
# for full reproducibility across processes with the `DataLoader`
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
return DataLoader(train_set, batch_size=batch_size, sampler=sampler)
else:
return DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)


def print_main(state):
print(f"Printing from the main process {state.process_index}")

@@ -335,22 +350,36 @@ def __len__(self):
), "Custom sampler was changed after calling `prepare_data_loader`"


def check_seedable_sampler():
# Set seed
set_seed(42)
train_set = RegressionDataset(length=10, seed=42)
train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
accelerator = Accelerator(use_seedable_sampler=True)
train_dl = accelerator.prepare(train_dl)
original_items = []
for _ in range(3):
for batch in train_dl:
original_items.append(batch["x"])
original_items = torch.cat(original_items)

# Set seed again and the epoch
set_seed(42)
train_dl.set_epoch(0)
new_items = []
for _ in range(3):
for batch in train_dl:
new_items.append(batch["x"])
new_items = torch.cat(new_items)
assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch."


def mock_training(length, batch_size, generator, use_seedable_sampler=False):
set_seed(42)
generator.manual_seed(42)
train_set = RegressionDataset(length=length, seed=42)

if use_seedable_sampler:
# The SeedableRandomSampler is needed during distributed setups
# for full reproducibility across processes with the `DataLoader`
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
else:
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(3):
@@ -374,17 +403,7 @@ def training_check(use_seedable_sampler=False):
assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."

accelerator = Accelerator()
if use_seedable_sampler:
# The SeedableRandomSampler is needed during distributed setups
# for full reproducibility across processes with the `DataLoader`
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
else:
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

@@ -406,7 +425,9 @@ def training_check(use_seedable_sampler=False):
accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

accelerator = Accelerator(split_batches=True, use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size * state.num_processes, shuffle=True, generator=generator)
train_dl = generate_baseline_dataloader(
train_set, generator, batch_size * state.num_processes, use_seedable_sampler
)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

@@ -432,7 +453,7 @@ def training_check(use_seedable_sampler=False):
print("FP16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="fp16", use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

@@ -472,7 +493,7 @@ def training_check(use_seedable_sampler=False):
print("BF16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="bf16", use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

@@ -496,7 +517,7 @@ def training_check(use_seedable_sampler=False):
print("ipex BF16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="bf16", cpu=True, use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

@@ -520,7 +541,7 @@ def training_check(use_seedable_sampler=False):
print("xpu BF16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="bf16", cpu=False, use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

@@ -667,6 +688,7 @@ def main():
if state.distributed_type != DistributedType.TPU:
central_dl_preparation_check()
custom_sampler_check()
check_seedable_sampler()

# Trainings are not exactly the same in DeepSpeed and CPU mode
if state.distributed_type == DistributedType.DEEPSPEED:
