Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Oct 16, 2023
1 parent de4c6a1 commit 1f8752d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ def __init__(self, *args, **kwargs):

def __iter__(self):
g = torch.Generator()
g.manual_seed(self.epoch)
if self.generator is not None:
seed = self.epoch + self.generator.initial_seed()
else:
seed = self.epoch
g.manual_seed(seed)
n = len(self.data_source)
if not self.replacement:
items = torch.randperm(n, generator=g).tolist()
Expand Down
10 changes: 8 additions & 2 deletions src/accelerate/test_utils/scripts/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.data_loader import prepare_data_loader
from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
from accelerate.state import AcceleratorState
from accelerate.test_utils import RegressionDataset, are_the_same_tensors
from accelerate.utils import (
Expand Down Expand Up @@ -292,7 +292,12 @@ def mock_training(length, batch_size, generator):
set_seed(42)
generator.manual_seed(42)
train_set = RegressionDataset(length=length, seed=42)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(3):
Expand Down Expand Up @@ -325,6 +330,7 @@ def training_check():
generator.manual_seed(42)
for epoch in range(3):
for batch in train_dl:
accelerator.print(f"Batch during accelerate training: {batch}")
model.zero_grad()
output = model(batch["x"])
loss = torch.nn.functional.mse_loss(output, batch["y"])
Expand Down

0 comments on commit 1f8752d

Please sign in to comment.