From 55747318a0f47cdfbc281e11269ba96214e4092d Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 30 Oct 2023 09:57:28 -0400 Subject: [PATCH] Fix batch sampler (#2097) * Fix batch sampler * Clean * Fix tests * Fix * Better comment * Base case --- src/accelerate/data_loader.py | 4 +- .../test_utils/scripts/test_script.py | 50 ++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index b80916d4dba..8337d399a34 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -833,9 +833,9 @@ def prepare_data_loader( synchronized_generator = None sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) if sampler_is_batch_sampler: - sampler = dataloader.sampler.sampler + sampler = getattr(dataloader.sampler, "sampler", None) else: - sampler = dataloader.batch_sampler.sampler + sampler = getattr(dataloader.batch_sampler, "sampler", None) if isinstance(sampler, RandomSampler) and num_processes > 1: # When iterating through the dataloader during distributed processes # we want to ensure that on each process we are iterating through the same diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index 7b4f20ccfd2..e5acff0e7ed 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -21,8 +21,9 @@ from copy import deepcopy from pathlib import Path +import numpy as np import torch -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from accelerate import Accelerator from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader @@ -288,6 +289,52 @@ def central_dl_preparation_check(): print("Shuffled central dataloader passing.") +def custom_sampler_check(): + state = AcceleratorState() + + class CustomDataset(Dataset): + def __init__(self, data): + self.data = data + + def 
__len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.data[index] + + class CustomBatchSampler: + def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True): + self.batch_size = batch_size + self.data_index = np.arange(dataset_length) + self.shuffle = shuffle + + def __iter__(self): + num_batches = len(self) + if self.shuffle: + index = np.random.permutation(self.data_index) + else: + index = self.data_index + output = np.array_split(index, num_batches) + yield from output + + def __len__(self): + return math.ceil(len(self.data_index) / self.batch_size) + + dataset = CustomDataset(range(32 * state.num_processes)) + sampler = CustomBatchSampler(len(dataset), batch_size=8) + dl = DataLoader(dataset, batch_sampler=sampler) + dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index) + # We just need to ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler`) is still an instance of `CustomBatchSampler` + if hasattr(dl.batch_sampler, "batch_sampler"): + assert isinstance( + dl.batch_sampler.batch_sampler, CustomBatchSampler + ), "Custom sampler was changed after calling `prepare_data_loader`" + else: + assert isinstance( + dl.batch_sampler, CustomBatchSampler + ), "Custom sampler was changed after calling `prepare_data_loader`" + + def mock_training(length, batch_size, generator): set_seed(42) generator.manual_seed(42) @@ -608,6 +655,7 @@ def main(): dl_preparation_check() if state.distributed_type != DistributedType.TPU: central_dl_preparation_check() + custom_sampler_check() # Trainings are not exactly the same in DeepSpeed and CPU mode if state.distributed_type == DistributedType.DEEPSPEED: