Allow for samplers to be seedable and reproducible (#2057)
* bookmark

* Works!

* Working!

* Fully working now

* Cover dataset

* Needed for dispatch

* Check both

* Bring back pop, fix hang

* Fully working

* Change back to epoch

* Adjust for new methods

* Clean

* Fix tests

* Avoid circular import

* Clean

* Fix test

* Comment

* Add a comment

* Comment

* Use yield from instead
muellerzr authored Oct 24, 2023
1 parent 11e2e99 commit 7843286
Showing 6 changed files with 158 additions and 15 deletions.
7 changes: 7 additions & 0 deletions src/accelerate/accelerator.py
@@ -2796,6 +2796,9 @@ def _inner(folder):
elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
schedulers = self._schedulers

# Save the samplers of the dataloaders
dataloaders = self._dataloaders

# Call model loading hooks that might have been registered with
# accelerator.register_model_state_hook
for hook in self._save_model_state_pre_hook.values():
@@ -2806,6 +2809,7 @@ def _inner(folder):
weights,
optimizers,
schedulers,
dataloaders,
self.state.process_index,
self.scaler,
save_on_each_node=self.project_configuration.save_on_each_node,
@@ -2935,6 +2939,8 @@ def _inner(folder):
elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
schedulers = self._schedulers

dataloaders = self._dataloaders

# Call model loading hooks that might have been registered with
# accelerator.register_model_state_hook
for hook in self._load_model_state_pre_hook.values():
@@ -2955,6 +2961,7 @@ def _inner(folder):
models,
optimizers,
schedulers,
dataloaders,
self.state.process_index,
self.scaler,
map_location,
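The accelerator.py change above simply threads the prepared dataloaders into the save and load hooks. Below is a minimal usage sketch of the workflow this enables (the dataset, model, and "ckpt" directory are placeholders; note the sampler file is only written when a shuffled dataloader has actually been swapped to the new SeedableRandomSampler, which happens when preparing with more than one process):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
dataset = TensorDataset(torch.randn(64, 2), torch.randn(64, 1))
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

accelerator.save_state("ckpt")  # now also saves each prepared dataloader's sampler
# ... in a resumed run, load_state restores that sampler, so the shuffling
# order (seed + epoch) continues from where the checkpoint left off
accelerator.load_state("ckpt")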
40 changes: 40 additions & 0 deletions src/accelerate/checkpointing.py
@@ -20,11 +20,13 @@
import numpy as np
import torch
from torch.cuda.amp import GradScaler
from torch.utils.data import BatchSampler

from .utils import (
MODEL_NAME,
OPTIMIZER_NAME,
RNG_STATE_NAME,
SAMPLER_NAME,
SCALER_NAME,
SCHEDULER_NAME,
get_pretty_name,
@@ -49,6 +51,7 @@ def save_accelerator_state(
model_states: List[dict],
optimizers: list,
schedulers: list,
dataloaders: list,
process_index: int,
scaler: GradScaler = None,
save_on_each_node: bool = False,
@@ -65,6 +68,8 @@
A list of optimizer instances
schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
A list of learning rate schedulers
dataloaders (`List[torch.utils.data.DataLoader]`):
A list of dataloader instances to save their sampler states
process_index (`int`):
The current process index in the Accelerator state
scaler (`torch.cuda.amp.GradScaler`, *optional*):
@@ -92,6 +97,22 @@ def save_accelerator_state(
output_scheduler_file = os.path.join(output_dir, scheduler_name)
save(state, output_scheduler_file, save_on_each_node=save_on_each_node)
logger.info(f"Scheduler state saved in {output_scheduler_file}")
# DataLoader states
for i, dataloader in enumerate(dataloaders):
sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
output_sampler_file = os.path.join(output_dir, sampler_name)
# Only save if we have our custom sampler
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
if sampler_is_batch_sampler:
sampler = dataloader.sampler.sampler
else:
sampler = dataloader.batch_sampler.sampler
from .data_loader import SeedableRandomSampler

if isinstance(sampler, SeedableRandomSampler):
save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

# GradScaler state
if scaler is not None:
state = scaler.state_dict()
@@ -121,6 +142,7 @@ def load_accelerator_state(
models,
optimizers,
schedulers,
dataloaders,
process_index,
scaler=None,
map_location=None,
@@ -177,6 +199,24 @@
scheduler.load_state_dict(torch.load(input_scheduler_file))
logger.info("All scheduler states loaded successfully")

for i, dataloader in enumerate(dataloaders):
sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
input_sampler_file = os.path.join(input_dir, sampler_name)
# Only load if we have our custom sampler
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
if sampler_is_batch_sampler:
sampler = dataloader.sampler.sampler
else:
sampler = dataloader.batch_sampler.sampler
from .data_loader import SeedableRandomSampler

if isinstance(sampler, SeedableRandomSampler):
if sampler_is_batch_sampler:
dataloader.sampler.sampler = torch.load(input_sampler_file)
else:
dataloader.batch_sampler.sampler = torch.load(input_sampler_file)
logger.info("All dataloader sampler states loaded successfully")

# GradScaler state
if scaler is not None:
input_scaler_file = os.path.join(input_dir, SCALER_NAME)
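For reference, a minimal standalone sketch of the round trip the two loops above perform for one dataloader. It simplifies in two ways: it calls torch.save/torch.load directly instead of the save helper (which handles save_on_each_node), and it leaves the generator unset so the draw order depends only on the epoch. File names follow the SAMPLER_NAME constant: sampler.bin for the first dataloader, then sampler_1.bin, sampler_2.bin, and so on.

import os
import torch
from accelerate.data_loader import SeedableRandomSampler

output_dir = "checkpoint"  # placeholder directory
os.makedirs(output_dir, exist_ok=True)

sampler = SeedableRandomSampler(data_source=list(range(10)))
sampler.set_epoch(3)

# i == 0 -> "sampler.bin"; i > 0 -> f"sampler_{i}.bin"
sampler_file = os.path.join(output_dir, "sampler.bin")
torch.save(sampler, sampler_file)  # the whole sampler object is pickled

# (newer torch releases may require weights_only=False here)
restored = torch.load(sampler_file)
assert restored.epoch == 3
assert list(restored) == list(sampler)  # same epoch, no generator -> same order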
110 changes: 97 additions & 13 deletions src/accelerate/data_loader.py
@@ -17,7 +17,7 @@
from typing import Callable, List, Optional, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler

from .logging import get_logger
from .state import AcceleratorState, DistributedType, GradientState, is_tpu_available
@@ -64,6 +64,41 @@
_PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)


class SeedableRandomSampler(RandomSampler):
"""
Same as a random sampler, except that in `__iter__` a seed can be used.
Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
and be fully reproducible on multiple iterations.
If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
(stored in `self.epoch`).
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.epoch = 0

def __iter__(self):
g = torch.Generator()
if self.generator is not None:
seed = self.epoch + self.generator.initial_seed()
else:
seed = self.epoch
g.manual_seed(seed)
n = len(self.data_source)
# Taken 1:1 from torch.utils.data.sampler.RandomSampler.__iter__
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
else:
yield from torch.randperm(n, generator=g).tolist()

def set_epoch(self, epoch: int):
"Sets the current iteration of the sampler."
self.epoch = epoch


class BatchSamplerShard(BatchSampler):
"""
Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
@@ -271,6 +306,11 @@ def __init__(
self.process_index = process_index
self.split_batches = split_batches

def set_epoch(self, epoch):
self.epoch = epoch
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)

def __len__(self):
# We will just raise the downstream error if the underlying dataset is not sized
if self.drop_last:
@@ -279,6 +319,12 @@ def __len__(self):
return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size

def __iter__(self):
if (
not hasattr(self.dataset, "set_epoch")
and hasattr(self.dataset, "generator")
and isinstance(self.dataset.generator, torch.Generator)
):
self.dataset.generator.manual_seed(self.epoch)
real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)
@@ -391,11 +437,14 @@ def __init__(
self.skip_batches = skip_batches
self.gradient_state = GradientState()
self._drop_last = _drop_last
self.iteration = 0

def __iter__(self):
if self.rng_types is not None:
synchronize_rng_states(self.rng_types, self.synchronized_generator)
self.begin()

self.set_epoch(self.iteration)
dataloader_iter = super().__iter__()
# We iterate one batch ahead to check when we are at the end
try:
@@ -419,8 +468,21 @@
if batch_index >= self.skip_batches:
yield current_batch
break

self.iteration += 1
self.end()

def set_epoch(self, epoch: int):
# In case it is manually passed in, the user can set it to what they like
if self.iteration != epoch:
self.iteration = epoch
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
self.batch_sampler.sampler.set_epoch(epoch)
# We support if a custom `Dataset` implementation has `set_epoch`
# or in general HF datasets `Datasets`
elif hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)

@property
def total_batch_size(self):
batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
@@ -524,6 +586,7 @@ def __init__(
self.skip_batches = skip_batches

self.slice_fn = slice_tensors if slice_fn is None else slice_fn
self.iteration = 0

def _fetch_batches(self, iterator):
batches, batch = None, None
@@ -564,6 +627,7 @@ def _fetch_batches(self, iterator):

def __iter__(self):
self.begin()
self.set_epoch(self.iteration)
main_iterator = None
if is_torch_version(">=", "2.0.1"):
# NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
@@ -633,8 +697,18 @@ def __iter__(self):
if batch_index >= self.skip_batches:
yield batch
batch_index += 1
self.iteration += 1
self.end()

def set_epoch(self, epoch: int):
# In case it is manually passed in, the user can set it to what they like
if self.iteration != epoch:
self.iteration = epoch
if hasattr(self.batch_sampler.sampler, "set_epoch"):
self.batch_sampler.sampler.set_epoch(epoch)
elif hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)

def __len__(self):
whole_length = super().__len__()
if self.split_batches:
@@ -757,6 +831,23 @@ def prepare_data_loader(
new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
sampler_is_batch_sampler = False
synchronized_generator = None
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
if sampler_is_batch_sampler:
sampler = dataloader.sampler.sampler
else:
sampler = dataloader.batch_sampler.sampler
if isinstance(sampler, RandomSampler) and num_processes > 1:
# When iterating through the dataloader during distributed processes
# we want to ensure that on each process we are iterating through the same
# samples in the same order if a seed is set. This requires a tweak
# to the `torch.utils.data.RandomSampler` class (if used).
sampler = SeedableRandomSampler(
data_source=sampler.data_source,
replacement=sampler.replacement,
num_samples=sampler._num_samples,
generator=getattr(sampler, "generator", torch.Generator()),
)

# No change if no multiprocess
if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
if isinstance(new_dataset, IterableDataset):
@@ -771,17 +862,6 @@
split_batches=split_batches,
)
else:
# New batch sampler for the current process.
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
if sampler_is_batch_sampler:
sampler = dataloader.sampler.sampler
else:
sampler = dataloader.batch_sampler.sampler
if hasattr(sampler, "generator"):
if sampler.generator is None:
sampler.generator = torch.Generator()
synchronized_generator = sampler.generator

batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
new_batch_sampler = BatchSamplerShard(
batch_sampler,
@@ -815,7 +895,11 @@
kwargs["batch_size"] = (
dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
)

if isinstance(sampler, SeedableRandomSampler):
if sampler_is_batch_sampler:
dataloader.sampler.sampler = sampler
else:
dataloader.batch_sampler.sampler = sampler
if dispatch_batches:
kwargs.pop("generator")
dataloader = DataLoaderDispatcher(
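To make the contract of SeedableRandomSampler and the set_epoch plumbing above concrete, here is a minimal single-process sketch in which two sampler instances stand in for two ranks. Given the same initial generator seed and the same epoch, every process draws the same permutation, which keeps the shards handed out by BatchSamplerShard aligned across processes and across restarts; bumping the epoch reseeds the shuffle for the next pass over the data.

import torch
from accelerate.data_loader import SeedableRandomSampler

def rank_order(initial_seed, epoch):
    # what one rank would iterate over for a 16-sample dataset
    sampler = SeedableRandomSampler(
        data_source=list(range(16)),
        generator=torch.Generator().manual_seed(initial_seed),
    )
    sampler.set_epoch(epoch)
    return list(sampler)

# "rank 0" and "rank 1" with the same seed and epoch see the same order
assert rank_order(42, epoch=0) == rank_order(42, epoch=0)

# advancing the epoch derives a new seed (epoch + initial_seed), so the next
# pass is shuffled differently but still deterministically
print(rank_order(42, epoch=0))
print(rank_order(42, epoch=1))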
14 changes: 12 additions & 2 deletions src/accelerate/test_utils/scripts/test_script.py
@@ -25,7 +25,7 @@
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.data_loader import prepare_data_loader
from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
from accelerate.state import AcceleratorState
from accelerate.test_utils import RegressionDataset, are_the_same_tensors
from accelerate.utils import (
@@ -292,7 +292,17 @@ def mock_training(length, batch_size, generator):
set_seed(42)
generator.manual_seed(42)
train_set = RegressionDataset(length=length, seed=42)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
if AcceleratorState().num_processes > 1:
# The SeedableRandomSampler is needed during distributed setups
# for full reproducibility across processes with the `DataLoader`
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
else:
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(3):
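One mechanical detail behind the two branches above: when a custom sampler is supplied, shuffle=True has to be dropped, because torch's DataLoader treats the two options as mutually exclusive. A minimal sketch of that constraint (the dataset is a placeholder):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate.data_loader import SeedableRandomSampler

dataset = TensorDataset(torch.arange(8, dtype=torch.float32).unsqueeze(1))
sampler = SeedableRandomSampler(data_source=dataset)

# shuffling now comes from the sampler itself; passing shuffle=True as well
# would raise a ValueError (sampler is mutually exclusive with shuffle)
train_dl = DataLoader(dataset, batch_size=4, sampler=sampler)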
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -4,6 +4,7 @@
RNG_STATE_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAMPLER_NAME,
SCALER_NAME,
SCHEDULER_NAME,
TORCH_DISTRIBUTED_OPERATION_TYPES,
1 change: 1 addition & 0 deletions src/accelerate/utils/constants.py
@@ -20,6 +20,7 @@
RNG_STATE_NAME = "random_states"
OPTIMIZER_NAME = "optimizer"
SCHEDULER_NAME = "scheduler"
SAMPLER_NAME = "sampler"
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
