From 1ba591d84454ddf2425880d853b847ee2f1ac6e2 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Mon, 16 Oct 2023 19:39:57 +0000
Subject: [PATCH] Fix tests

---
 src/accelerate/checkpointing.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index 862182a71d5..a83c19d8dad 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -20,7 +20,9 @@
 import numpy as np
 import torch
 from torch.cuda.amp import GradScaler
+from torch.utils.data import BatchSampler
 
+from .data_loader import SeedableRandomSampler
 from .utils import (
     MODEL_NAME,
     OPTIMIZER_NAME,
@@ -100,7 +102,14 @@ def save_accelerator_state(
     for i, dataloader in enumerate(dataloaders):
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
         output_sampler_file = os.path.join(output_dir, sampler_name)
-        save(dataloader.sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+        # Only save if we have our custom sampler
+        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+        if sampler_is_batch_sampler:
+            sampler = dataloader.sampler.sampler
+        else:
+            sampler = dataloader.batch_sampler.sampler
+        if isinstance(sampler, SeedableRandomSampler):
+            save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
         logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
 
     # GradScaler state
@@ -192,7 +201,17 @@ def load_accelerator_state(
     for i, dataloader in enumerate(dataloaders):
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
         input_sampler_file = os.path.join(input_dir, sampler_name)
-        dataloader.sampler = torch.load(input_sampler_file)
+        # Only load if we have our custom sampler
+        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+        if sampler_is_batch_sampler:
+            sampler = dataloader.sampler.sampler
+        else:
+            sampler = dataloader.batch_sampler.sampler
+        if isinstance(sampler, SeedableRandomSampler):
+            if sampler_is_batch_sampler:
+                dataloader.sampler.sampler = torch.load(input_sampler_file)
+            else:
+                dataloader.batch_sampler.sampler = torch.load(input_sampler_file)
     logger.info("All dataloader sampler states loaded successfully")
 
     # GradScaler state
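
Below is a minimal standalone sketch of the sampler-dispatch logic this patch
introduces, written against a plain PyTorch DataLoader. The helper name
get_underlying_sampler is hypothetical and not part of accelerate's API; only
BatchSampler and the attribute lookups come from the diff above.

    # Sketch of the dispatch logic added in checkpointing.py, assuming a
    # plain PyTorch DataLoader. `get_underlying_sampler` is a hypothetical
    # helper, not part of accelerate's API.
    import torch
    from torch.utils.data import BatchSampler, DataLoader, TensorDataset

    def get_underlying_sampler(dataloader):
        # accelerate may hand a BatchSampler in as `sampler=`; the true
        # index sampler is then nested one level deeper.
        if isinstance(dataloader.sampler, BatchSampler):
            return dataloader.sampler.sampler
        # For a stock DataLoader the index sampler hangs off batch_sampler.
        return dataloader.batch_sampler.sampler

    dataset = TensorDataset(torch.arange(10))
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    sampler = get_underlying_sampler(dataloader)
    print(type(sampler).__name__)  # RandomSampler

    # The patched save path persists this object only when it is a
    # SeedableRandomSampler, and the load path writes the deserialized
    # sampler back to the same nested attribute it was read from, so a
    # resumed run reproduces the original shuffle order.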