Skip to content

Commit

Permalink
Comment
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Oct 16, 2023
1 parent 1f8752d commit 3426f26
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def prepare_data_loader(
sampler = dataloader.sampler.sampler
else:
sampler = dataloader.batch_sampler.sampler
if isinstance(sampler, RandomSampler):
if isinstance(sampler, RandomSampler) and num_processes > 1:
sampler = SeedableRandomSampler(
data_source=sampler.data_source,
replacement=sampler.replacement,
Expand Down
18 changes: 11 additions & 7 deletions src/accelerate/test_utils/scripts/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,17 @@ def mock_training(length, batch_size, generator):
set_seed(42)
generator.manual_seed(42)
train_set = RegressionDataset(length=length, seed=42)
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
if AcceleratorState().num_processes > 1:
# The SeedableRandomSampler is needed during distributed setups
# for full reproducibility across processes with the `DataLoader`
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
else:
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(3):
Expand Down Expand Up @@ -330,7 +335,6 @@ def training_check():
generator.manual_seed(42)
for epoch in range(3):
for batch in train_dl:
accelerator.print(f"Batch during accelerate training: {batch}")
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
Expand Down

0 comments on commit 3426f26

Please sign in to comment.