From 3426f26ba910aff632d04523d2c5985f7a6e1e26 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 16 Oct 2023 22:12:22 +0000 Subject: [PATCH] Comment --- src/accelerate/data_loader.py | 2 +- .../test_utils/scripts/test_script.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index ffa0cddb84d..6f4eeece6b1 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -817,7 +817,7 @@ def prepare_data_loader( sampler = dataloader.sampler.sampler else: sampler = dataloader.batch_sampler.sampler - if isinstance(sampler, RandomSampler): + if isinstance(sampler, RandomSampler) and num_processes > 1: sampler = SeedableRandomSampler( data_source=sampler.data_source, replacement=sampler.replacement, diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index adcc2c55271..7b4f20ccfd2 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -292,12 +292,17 @@ def mock_training(length, batch_size, generator): set_seed(42) generator.manual_seed(42) train_set = RegressionDataset(length=length, seed=42) - sampler = SeedableRandomSampler( - generator=generator, - data_source=train_set, - num_samples=len(train_set), - ) - train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler) + if AcceleratorState().num_processes > 1: + # The SeedableRandomSampler is needed during distributed setups + # for full reproducibility across processes with the `DataLoader` + sampler = SeedableRandomSampler( + generator=generator, + data_source=train_set, + num_samples=len(train_set), + ) + train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler) + else: + train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.1) for 
epoch in range(3): @@ -330,7 +335,6 @@ def training_check(): generator.manual_seed(42) for epoch in range(3): for batch in train_dl: - accelerator.print(f"Batch during accelerate training: {batch}") model.zero_grad() output = model(batch["x"]) loss = torch.nn.functional.mse_loss(output, batch["y"])