Fix batch sampler (#2097)
* Fix batch sampler

* Clean

* Fix tests

* Fix

* Better comment

* Base case
muellerzr authored Oct 30, 2023
1 parent 217faaf commit 5574731
Showing 2 changed files with 51 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/accelerate/data_loader.py
@@ -833,9 +833,9 @@ def prepare_data_loader(
     synchronized_generator = None
     sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
     if sampler_is_batch_sampler:
-        sampler = dataloader.sampler.sampler
+        sampler = getattr(dataloader.sampler, "sampler", None)
     else:
-        sampler = dataloader.batch_sampler.sampler
+        sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler) and num_processes > 1:
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
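Why the `getattr` guard matters, as a minimal sketch that is not part of the commit (the `MinimalBatchSampler` name is made up for illustration): a DataLoader built around a user-defined batch sampler has no `.sampler` attribute to read, so the old attribute access raised `AttributeError` before `prepare_data_loader` could do anything, whereas `getattr(..., "sampler", None)` degrades to `None` and the `RandomSampler`-specific seeding branch is simply skipped.

from torch.utils.data import DataLoader

class MinimalBatchSampler:
    # Deliberately has no `.sampler` attribute, unlike torch.utils.data.BatchSampler.
    def __init__(self, length, batch_size):
        self.length, self.batch_size = length, batch_size

    def __iter__(self):
        for start in range(0, self.length, self.batch_size):
            yield list(range(start, min(start + self.batch_size, self.length)))

    def __len__(self):
        return (self.length + self.batch_size - 1) // self.batch_size

dl = DataLoader(list(range(32)), batch_sampler=MinimalBatchSampler(32, 8))
# Before the fix: dl.batch_sampler.sampler -> AttributeError.
# After the fix: the lookup falls back gracefully.
print(getattr(dl.batch_sampler, "sampler", None))  # None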
50 changes: 49 additions & 1 deletion src/accelerate/test_utils/scripts/test_script.py
@@ -21,8 +21,9 @@
 from copy import deepcopy
 from pathlib import Path

+import numpy as np
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

 from accelerate import Accelerator
 from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
@@ -288,6 +289,52 @@ def central_dl_preparation_check():
print("Shuffled central dataloader passing.")


def custom_sampler_check():
state = AcceleratorState()

class CustomDataset(Dataset):
def __init__(self, data):
self.data = data

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index]

class CustomBatchSampler:
def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True):
self.batch_size = batch_size
self.data_index = np.arange(dataset_length)
self.shuffle = shuffle

def __iter__(self):
num_batches = len(self)
if self.shuffle:
index = np.random.permutation(self.data_index)
else:
index = self.data_index
output = np.array_split(index, num_batches)
yield from output

def __len__(self):
return math.ceil(len(self.data_index) / self.batch_size)

dataset = CustomDataset(range(32 * state.num_processes))
sampler = CustomBatchSampler(len(dataset), batch_size=8)
dl = DataLoader(dataset, batch_sampler=sampler)
dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index)
+    # We just need to ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler`) is indeed the old batch sampler
+    if hasattr(dl.batch_sampler, "batch_sampler"):
+        assert isinstance(
+            dl.batch_sampler.batch_sampler, CustomBatchSampler
+        ), "Custom sampler was changed after calling `prepare_data_loader`"
+    else:
+        assert isinstance(
+            dl.batch_sampler, CustomBatchSampler
+        ), "Custom sampler was changed after calling `prepare_data_loader`"
+
+
 def mock_training(length, batch_size, generator):
     set_seed(42)
     generator.manual_seed(42)
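A note on the two assertion branches above (my reading of the diff, not stated in it): with more than one process, `prepare_data_loader` wraps the user's batch sampler in accelerate's `BatchSamplerShard`, which keeps the original object on its `.batch_sampler` attribute, while in the single-process base case (the "Base case" commit bullet, as I read it) the original batch sampler is passed through untouched. A minimal sketch of that unwrapping, reusing the names defined in `custom_sampler_check`:

# Hedged sketch: recover the user's batch sampler from the prepared DataLoader.
# `dl` and `CustomBatchSampler` are the names from `custom_sampler_check` above.
wrapped = dl.batch_sampler                             # BatchSamplerShard, or the original sampler
original = getattr(wrapped, "batch_sampler", wrapped)  # unwrap only if it was sharded
assert isinstance(original, CustomBatchSampler)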
@@ -608,6 +655,7 @@ def main():
     dl_preparation_check()
     if state.distributed_type != DistributedType.TPU:
         central_dl_preparation_check()
+        custom_sampler_check()

     # Trainings are not exactly the same in DeepSpeed and CPU mode
     if state.distributed_type == DistributedType.DEEPSPEED:
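A hedged usage note (not part of the commit): this file is the script that `accelerate test` launches, so once `accelerate config` has been run the new check can be exercised end to end with that command; whether the sharded branch or the single-process base case is hit depends on the configured number of processes.

accelerate test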
