From 7843286f2e1c50735d259fbc0084a7f1c85e00e3 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Tue, 24 Oct 2023 06:41:06 -0400
Subject: [PATCH] Allow for samplers to be seedable and reproducible (#2057)

* bookmark
* Works!
* Working!
* Fully working now
* Cover dataset
* Needed for dispatch
* Check both
* Bring back pop, fix hang
* Fully working
* Change back to epoch
* Adjust for new methods
* Clean
* Fix tests
* Avoid circular import
* Clean
* Fix test
* Comment
* Add a comment
* Comment
* Use yield from instead
---
 src/accelerate/accelerator.py                 |   7 ++
 src/accelerate/checkpointing.py               |  40 +++++++
 src/accelerate/data_loader.py                 | 110 +++++++++++++++---
 .../test_utils/scripts/test_script.py         |  14 ++-
 src/accelerate/utils/__init__.py              |   1 +
 src/accelerate/utils/constants.py             |   1 +
 6 files changed, 158 insertions(+), 15 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 8fee491e1ea..cd00cf91b79 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2796,6 +2796,9 @@ def _inner(folder):
         elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
             schedulers = self._schedulers

+        # Save the samplers of the dataloaders
+        dataloaders = self._dataloaders
+
         # Call model loading hooks that might have been registered with
         # accelerator.register_model_state_hook
         for hook in self._save_model_state_pre_hook.values():
@@ -2806,6 +2809,7 @@ def _inner(folder):
             weights,
             optimizers,
             schedulers,
+            dataloaders,
             self.state.process_index,
             self.scaler,
             save_on_each_node=self.project_configuration.save_on_each_node,
@@ -2935,6 +2939,8 @@ def _inner(folder):
         elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
             schedulers = self._schedulers

+        dataloaders = self._dataloaders
+
         # Call model loading hooks that might have been registered with
         # accelerator.register_model_state_hook
         for hook in self._load_model_state_pre_hook.values():
@@ -2955,6 +2961,7 @@ def _inner(folder):
             models,
             optimizers,
             schedulers,
+            dataloaders,
             self.state.process_index,
             self.scaler,
             map_location,
diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index a3eda8a1d70..12f9aee55ab 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -20,11 +20,13 @@
 import numpy as np
 import torch
 from torch.cuda.amp import GradScaler
+from torch.utils.data import BatchSampler

 from .utils import (
     MODEL_NAME,
     OPTIMIZER_NAME,
     RNG_STATE_NAME,
+    SAMPLER_NAME,
     SCALER_NAME,
     SCHEDULER_NAME,
     get_pretty_name,
@@ -49,6 +51,7 @@ def save_accelerator_state(
     model_states: List[dict],
     optimizers: list,
     schedulers: list,
+    dataloaders: list,
     process_index: int,
     scaler: GradScaler = None,
     save_on_each_node: bool = False,
@@ -65,6 +68,8 @@ def save_accelerator_state(
             A list of optimizer instances
         schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
             A list of learning rate schedulers
+        dataloaders (`List[torch.utils.data.DataLoader]`):
+            A list of dataloader instances to save their sampler states
         process_index (`int`):
             The current process index in the Accelerator state
         scaler (`torch.cuda.amp.GradScaler`, *optional*):
@@ -92,6 +97,22 @@ def save_accelerator_state(
         output_scheduler_file = os.path.join(output_dir, scheduler_name)
         save(state, output_scheduler_file, save_on_each_node=save_on_each_node)
         logger.info(f"Scheduler state saved in {output_scheduler_file}")
+    # DataLoader states
+    for i, dataloader in enumerate(dataloaders):
+        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+        output_sampler_file = os.path.join(output_dir, sampler_name)
+        # Only save if we have our custom sampler
+        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+        if sampler_is_batch_sampler:
+            sampler = dataloader.sampler.sampler
+        else:
+            sampler = dataloader.batch_sampler.sampler
+        from .data_loader import SeedableRandomSampler
+
+        if isinstance(sampler, SeedableRandomSampler):
+            save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+            logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
+
     # GradScaler state
     if scaler is not None:
         state = scaler.state_dict()
@@ -121,6 +142,7 @@ def load_accelerator_state(
     models,
     optimizers,
     schedulers,
+    dataloaders,
     process_index,
     scaler=None,
     map_location=None,
@@ -177,6 +199,24 @@ def load_accelerator_state(
         scheduler.load_state_dict(torch.load(input_scheduler_file))
     logger.info("All scheduler states loaded successfully")

+    for i, dataloader in enumerate(dataloaders):
+        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+        input_sampler_file = os.path.join(input_dir, sampler_name)
+        # Only load if we have our custom sampler
+        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+        if sampler_is_batch_sampler:
+            sampler = dataloader.sampler.sampler
+        else:
+            sampler = dataloader.batch_sampler.sampler
+        from .data_loader import SeedableRandomSampler
+
+        if isinstance(sampler, SeedableRandomSampler):
+            if sampler_is_batch_sampler:
+                dataloader.sampler.sampler = torch.load(input_sampler_file)
+            else:
+                dataloader.batch_sampler.sampler = torch.load(input_sampler_file)
+    logger.info("All dataloader sampler states loaded successfully")
+
     # GradScaler state
     if scaler is not None:
         input_scaler_file = os.path.join(input_dir, SCALER_NAME)
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 2e9b0906ca1..aede8b62f4d 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -17,7 +17,7 @@
 from typing import Callable, List, Optional, Union

 import torch
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler

 from .logging import get_logger
 from .state import AcceleratorState, DistributedType, GradientState, is_tpu_available
@@ -64,6 +64,41 @@
     _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)


+class SeedableRandomSampler(RandomSampler):
+    """
+    Same as a random sampler, except that in `__iter__` a seed can be used.
+
+    Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
+    and be fully reproducible on multiple iterations.
+
+    If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
+    (stored in `self.epoch`).
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.epoch = 0
+
+    def __iter__(self):
+        g = torch.Generator()
+        if self.generator is not None:
+            seed = self.epoch + self.generator.initial_seed()
+        else:
+            seed = self.epoch
+        g.manual_seed(seed)
+        n = len(self.data_source)
+        # Taken 1:1 from torch.utils.data.sampler.RandomSampler.__iter__
+        if self.replacement:
+            for _ in range(self.num_samples // 32):
+                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
+        else:
+            yield from torch.randperm(n, generator=g).tolist()
+
+    def set_epoch(self, epoch: int):
+        "Sets the current iteration of the sampler."
+        self.epoch = epoch
+
+
 class BatchSamplerShard(BatchSampler):
     """
     Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
@@ -271,6 +306,11 @@ def __init__(
         self.process_index = process_index
         self.split_batches = split_batches

+    def set_epoch(self, epoch):
+        self.epoch = epoch
+        if hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
     def __len__(self):
         # We will just raise the downstream error if the underlying dataset is not sized
         if self.drop_last:
@@ -279,6 +319,12 @@ def __len__(self):
         return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size

     def __iter__(self):
+        if (
+            not hasattr(self.dataset, "set_epoch")
+            and hasattr(self.dataset, "generator")
+            and isinstance(self.dataset.generator, torch.Generator)
+        ):
+            self.dataset.generator.manual_seed(self.epoch)
         real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
         process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
         process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)
@@ -391,11 +437,14 @@ def __init__(
         self.skip_batches = skip_batches
         self.gradient_state = GradientState()
         self._drop_last = _drop_last
+        self.iteration = 0

     def __iter__(self):
         if self.rng_types is not None:
             synchronize_rng_states(self.rng_types, self.synchronized_generator)
         self.begin()
+
+        self.set_epoch(self.iteration)
         dataloader_iter = super().__iter__()
         # We iterate one batch ahead to check when we are at the end
         try:
@@ -419,8 +468,21 @@ def __iter__(self):
                 if batch_index >= self.skip_batches:
                     yield current_batch
                 break
+
+        self.iteration += 1
         self.end()

+    def set_epoch(self, epoch: int):
+        # In case it is manually passed in, the user can set it to what they like
+        if self.iteration != epoch:
+            self.iteration = epoch
+        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+            self.batch_sampler.sampler.set_epoch(epoch)
+        # We support if a custom `Dataset` implementation has `set_epoch`
+        # or in general HF datasets `Datasets`
+        elif hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
     @property
     def total_batch_size(self):
         batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
@@ -524,6 +586,7 @@ def __init__(

         self.skip_batches = skip_batches
         self.slice_fn = slice_tensors if slice_fn is None else slice_fn
+        self.iteration = 0

     def _fetch_batches(self, iterator):
         batches, batch = None, None
@@ -564,6 +627,7 @@ def __iter__(self):
         self.begin()
+        self.set_epoch(self.iteration)
         main_iterator = None
         if is_torch_version(">=", "2.0.1"):
             # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
@@ -633,8 +697,18 @@ def __iter__(self):
             if batch_index >= self.skip_batches:
                 yield batch
             batch_index += 1
+        self.iteration += 1
         self.end()

+    def set_epoch(self, epoch: int):
+        # In case it is manually passed in, the user can set it to what they like
+        if self.iteration != epoch:
+            self.iteration = epoch
+        if hasattr(self.batch_sampler.sampler, "set_epoch"):
+            self.batch_sampler.sampler.set_epoch(epoch)
+        elif hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
     def __len__(self):
         whole_length = super().__len__()
         if self.split_batches:
@@ -757,6 +831,23 @@ def prepare_data_loader(
     new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
     sampler_is_batch_sampler = False
     synchronized_generator = None
+    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+    if sampler_is_batch_sampler:
+        sampler = dataloader.sampler.sampler
+    else:
+        sampler = dataloader.batch_sampler.sampler
+    if isinstance(sampler, RandomSampler) and num_processes > 1:
+        # When iterating through the dataloader during distributed processes
+        # we want to ensure that on each process we are iterating through the same
+        # samples in the same order if a seed is set. This requires a tweak
+        # to the `torch.utils.data.RandomSampler` class (if used).
+        sampler = SeedableRandomSampler(
+            data_source=sampler.data_source,
+            replacement=sampler.replacement,
+            num_samples=sampler._num_samples,
+            generator=getattr(sampler, "generator", torch.Generator()),
+        )
+
     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
         if isinstance(new_dataset, IterableDataset):
@@ -771,17 +862,6 @@ def prepare_data_loader(
                 split_batches=split_batches,
             )
         else:
-            # New batch sampler for the current process.
-            sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-            if sampler_is_batch_sampler:
-                sampler = dataloader.sampler.sampler
-            else:
-                sampler = dataloader.batch_sampler.sampler
-            if hasattr(sampler, "generator"):
-                if sampler.generator is None:
-                    sampler.generator = torch.Generator()
-                synchronized_generator = sampler.generator
-
             batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
                 batch_sampler,
@@ -815,7 +895,11 @@ def prepare_data_loader(
         kwargs["batch_size"] = (
             dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
         )
-
+    if isinstance(sampler, SeedableRandomSampler):
+        if sampler_is_batch_sampler:
+            dataloader.sampler.sampler = sampler
+        else:
+            dataloader.batch_sampler.sampler = sampler
     if dispatch_batches:
         kwargs.pop("generator")
         dataloader = DataLoaderDispatcher(
diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py
index 9ee508f8be4..7b4f20ccfd2 100644
--- a/src/accelerate/test_utils/scripts/test_script.py
+++ b/src/accelerate/test_utils/scripts/test_script.py
@@ -25,7 +25,7 @@
 from torch.utils.data import DataLoader

 from accelerate import Accelerator
-from accelerate.data_loader import prepare_data_loader
+from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
 from accelerate.state import AcceleratorState
 from accelerate.test_utils import RegressionDataset, are_the_same_tensors
 from accelerate.utils import (
@@ -292,7 +292,17 @@ def mock_training(length, batch_size, generator):
     set_seed(42)
     generator.manual_seed(42)
     train_set = RegressionDataset(length=length, seed=42)
-    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
+    if AcceleratorState().num_processes > 1:
+        # The SeedableRandomSampler is needed during distributed setups
+        # for full reproducibility across processes with the `DataLoader`
+        sampler = SeedableRandomSampler(
+            generator=generator,
+            data_source=train_set,
+            num_samples=len(train_set),
+        )
+        train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
+    else:
+        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
     model = RegressionModel()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     for epoch in range(3):
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 09a5ef42a7a..04ad753db61 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -4,6 +4,7 @@
     RNG_STATE_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
+    SAMPLER_NAME,
     SCALER_NAME,
     SCHEDULER_NAME,
     TORCH_DISTRIBUTED_OPERATION_TYPES,
diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py
index bed979e456a..638a8ea4529 100644
--- a/src/accelerate/utils/constants.py
+++ b/src/accelerate/utils/constants.py
@@ -20,6 +20,7 @@
 RNG_STATE_NAME = "random_states"
 OPTIMIZER_NAME = "optimizer"
 SCHEDULER_NAME = "scheduler"
+SAMPLER_NAME = "sampler"
 WEIGHTS_NAME = "pytorch_model.bin"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 SAFE_WEIGHTS_NAME = "model.safetensors"
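
For illustration, here is a minimal sketch (assuming an accelerate install that already includes this patch) of the property the new sampler provides: the shuffle order is a pure function of the generator's initial seed plus the epoch passed to `set_epoch`, so every process sharing the same seed derives the same per-epoch permutation. The `order_for` helper, the toy dataset, and the seed value are illustrative only.

# Minimal sketch, assuming accelerate ships the SeedableRandomSampler added above.
import torch
from torch.utils.data import TensorDataset

from accelerate.data_loader import SeedableRandomSampler

dataset = TensorDataset(torch.arange(10))


def order_for(epoch: int, seed: int = 42):
    generator = torch.Generator()
    generator.manual_seed(seed)
    sampler = SeedableRandomSampler(data_source=dataset, generator=generator)
    # Prepared dataloaders call set_epoch() on each iteration (see DataLoaderShard above)
    sampler.set_epoch(epoch)
    return list(sampler)


# Same seed and same epoch -> identical order on every process
assert order_for(epoch=0) == order_for(epoch=0)
# Advancing the epoch reshuffles deterministically instead of repeating epoch 0
print(order_for(epoch=1) == order_for(epoch=0))  # almost surely False

In a prepared run, `Accelerator.save_state`/`load_state` now also serialize this sampler (as `sampler.bin`), so a resumed job continues with the same per-epoch ordering.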