Fix (#2080)
muellerzr authored Oct 24, 2023
1 parent b7686cc commit eb8c535
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions src/accelerate/checkpointing.py
@@ -20,7 +20,6 @@
 import numpy as np
 import torch
 from torch.cuda.amp import GradScaler
-from torch.utils.data import BatchSampler

 from .utils import (
     MODEL_NAME,
@@ -102,15 +101,13 @@ def save_accelerator_state(
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
         output_sampler_file = os.path.join(output_dir, sampler_name)
         # Only save if we have our custom sampler
-        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+        if isinstance(dataloader.dataset, IterableDatasetShard):
             sampler = dataloader.sampler.sampler
-        else:
-            sampler = dataloader.batch_sampler.sampler
-        from .data_loader import SeedableRandomSampler

-        if isinstance(sampler, SeedableRandomSampler):
-            save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+            if isinstance(sampler, SeedableRandomSampler):
+                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
         logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

     # GradScaler state
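Why saving the raw sampler object works: a seedable random sampler derives its shuffle order purely from stored state (seed and epoch), so pickling the whole object with torch.save captures everything needed to resume with an identical data order. A minimal sketch of that idea follows; ToySeedableSampler is an illustrative stand-in, not accelerate's actual SeedableRandomSampler.

import torch
from torch.utils.data import Sampler


class ToySeedableSampler(Sampler):
    """Illustrative stand-in for a seedable sampler: the shuffle order is a
    pure function of (seed, epoch), so pickling the object preserves it."""

    def __init__(self, num_samples: int, seed: int = 42):
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # order depends only on stored state
        yield from torch.randperm(self.num_samples, generator=g).tolist()

    def __len__(self):
        return self.num_samples


sampler = ToySeedableSampler(8)
sampler.set_epoch(3)
torch.save(sampler, "sampler.bin")  # what save_accelerator_state does via save(...)
# weights_only=False because a full Python object is pickled, not just tensors
restored = torch.load("sampler.bin", weights_only=False)
assert list(sampler) == list(restored)  # resumes with the identical shuffle order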
@@ -203,18 +200,13 @@ def load_accelerator_state(
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
         input_sampler_file = os.path.join(input_dir, sampler_name)
         # Only load if we have our custom sampler
-        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+        if isinstance(dataloader.dataset, IterableDatasetShard):
             sampler = dataloader.sampler.sampler
-        else:
-            sampler = dataloader.batch_sampler.sampler
-        from .data_loader import SeedableRandomSampler

-        if isinstance(sampler, SeedableRandomSampler):
-            if sampler_is_batch_sampler:
+            if isinstance(sampler, SeedableRandomSampler):
                 dataloader.sampler.sampler = torch.load(input_sampler_file)
-            else:
-                dataloader.batch_sampler.sampler = torch.load(input_sampler_file)
     logger.info("All dataloader sampler states loaded successfully")

     # GradScaler state
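In user code, both functions are reached through Accelerator.save_state and Accelerator.load_state. Below is a hedged sketch of a checkpoint round trip; the model, data, and "ckpt" directory are illustrative choices, and whether sampler state is actually written depends on the accelerate version and dataloader configuration.

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 2))
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

accelerator.save_state("ckpt")  # calls save_accelerator_state, shown above
# ... later, to resume training with the same data order:
accelerator.load_state("ckpt")  # calls load_accelerator_state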
