Fix batch sampler #2097

Merged · 7 commits · Oct 30, 2023
Changes from 1 commit
4 changes: 2 additions & 2 deletions src/accelerate/data_loader.py
@@ -833,9 +833,9 @@ def prepare_data_loader(
     synchronized_generator = None
     sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
     if sampler_is_batch_sampler:
-        sampler = dataloader.sampler.sampler
+        sampler = getattr(dataloader.sampler, "sampler", None)
     else:
-        sampler = dataloader.batch_sampler.sampler
+        sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler) and num_processes > 1:
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
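The guard above matters when a DataLoader is built around a custom batch sampler that yields index batches itself and therefore has no inner `.sampler` attribute. A minimal sketch of the failure mode the change avoids; the `IndexBatchSampler` name and toy data are illustrative assumptions, not code from this PR:

```python
# Minimal sketch, assuming a custom batch sampler with no `.sampler` attribute.
from torch.utils.data import DataLoader


class IndexBatchSampler:
    """Yields lists of indices directly instead of wrapping an inner sampler."""

    def __init__(self, length, batch_size):
        self.batches = [
            list(range(start, min(start + batch_size, length)))
            for start in range(0, length, batch_size)
        ]

    def __iter__(self):
        yield from self.batches

    def __len__(self):
        return len(self.batches)


dl = DataLoader(list(range(32)), batch_sampler=IndexBatchSampler(32, batch_size=8))

# Old code: `dl.batch_sampler.sampler` raises AttributeError, because
# IndexBatchSampler never defines a `sampler` attribute.
# New code: the getattr call falls back to None, so the later
# `isinstance(sampler, RandomSampler)` check is simply False.
sampler = getattr(dl.batch_sampler, "sampler", None)
assert sampler is None
```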
45 changes: 44 additions & 1 deletion src/accelerate/test_utils/scripts/test_script.py
@@ -21,8 +21,9 @@
 from copy import deepcopy
 from pathlib import Path
 
+import numpy as np
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from accelerate import Accelerator
 from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
@@ -288,6 +289,47 @@ def central_dl_preparation_check():
print("Shuffled central dataloader passing.")


def custom_sampler_check():
state = AcceleratorState()

class CustomDataset(Dataset):
def __init__(self, data):
self.data = data

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index]

class CustomBatchSampler:
def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True):
self.batch_size = batch_size
self.data_index = np.arange(dataset_length)
self.shuffle = shuffle

def __iter__(self):
num_batches = len(self)
if self.shuffle:
index = np.random.permutation(self.data_index)
else:
index = self.data_index
output = np.array_split(index, num_batches)
yield from output

def __len__(self):
return math.ceil(len(self.data_index) / self.batch_size)

dataset = CustomDataset(range(32 * state.num_processes))
sampler = CustomBatchSampler(len(dataset), batch_size=8)
dl = DataLoader(dataset, batch_sampler=sampler)
dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True)
# We need just ensure that `dl.batch_sampler` remains unchanged
assert isinstance(
dl.batch_sampler, CustomBatchSampler
), "Custom sampler was changed after calling `prepare_data_loader`"


def mock_training(length, batch_size, generator):
set_seed(42)
generator.manual_seed(42)
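For context, the user-facing path this check exercises goes through `Accelerator.prepare`, which delegates to `prepare_data_loader`. Below is a hedged sketch of that flow under a single-process run; `SimpleBatchSampler` and the toy dataset are assumptions for illustration, not code from this PR:

```python
# Hedged sketch of the user-facing flow; names below are illustrative assumptions.
import numpy as np
from torch.utils.data import DataLoader

from accelerate import Accelerator


class SimpleBatchSampler:
    """Groups indices itself, so it exposes no `.sampler` attribute."""

    def __init__(self, length, batch_size):
        self.batches = np.array_split(np.arange(length), max(1, length // batch_size))

    def __iter__(self):
        yield from (batch.tolist() for batch in self.batches)

    def __len__(self):
        return len(self.batches)


accelerator = Accelerator()
dataset = list(range(64))
dl = DataLoader(dataset, batch_sampler=SimpleBatchSampler(len(dataset), batch_size=8))

# Before the fix, this call hit the unguarded `.sampler` access inside
# `prepare_data_loader` and failed with an AttributeError; with the getattr
# guard the custom batching logic is left alone.
dl = accelerator.prepare(dl)

for batch in dl:
    pass  # in a single-process run, batches match what SimpleBatchSampler produced
```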
@@ -608,6 +650,7 @@ def main():
     dl_preparation_check()
     if state.distributed_type != DistributedType.TPU:
         central_dl_preparation_check()
+        custom_sampler_check()
 
     # Trainings are not exactly the same in DeepSpeed and CPU mode
     if state.distributed_type == DistributedType.DEEPSPEED: