Skip to content

Commit

Permalink
Fix tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
muellerzr committed Oct 16, 2023
1 parent 3efbb09 commit 1ba591d
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions src/accelerate/checkpointing.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,7 +20,9 @@
import numpy as np
import torch
from torch.cuda.amp import GradScaler
from torch.utils.data import BatchSampler

from .data_loader import SeedableRandomSampler
from .utils import (
MODEL_NAME,
OPTIMIZER_NAME,
Expand Down Expand Up @@ -100,7 +102,14 @@ def save_accelerator_state(
for i, dataloader in enumerate(dataloaders):
sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
output_sampler_file = os.path.join(output_dir, sampler_name)
save(dataloader.sampler, output_sampler_file, save_on_each_node=save_on_each_node)
# Only save if we have our custom sampler
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
if sampler_is_batch_sampler:
sampler = dataloader.sampler.sampler
else:
sampler = dataloader.batch_sampler.sampler
if isinstance(sampler, SeedableRandomSampler):
save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

# GradScaler state
Expand Down Expand Up @@ -192,7 +201,17 @@ def load_accelerator_state(
for i, dataloader in enumerate(dataloaders):
sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
input_sampler_file = os.path.join(input_dir, sampler_name)
dataloader.sampler = torch.load(input_sampler_file)
# Only load if we have our custom sampler
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
if sampler_is_batch_sampler:
sampler = dataloader.sampler.sampler
else:
sampler = dataloader.batch_sampler.sampler
if isinstance(sampler, SeedableRandomSampler):
if sampler_is_batch_sampler:
dataloader.sampler.sampler = torch.load(input_sampler_file)
else:
dataloader.batch_sampler.sampler = torch.load(input_sampler_file)
logger.info("All dataloader sampler states loaded successfully")

# GradScaler state
Expand Down

0 comments on commit 1ba591d

Please sign in to comment.