
Commit 3efbb09: Clean

muellerzr committed Oct 16, 2023
1 parent c5e2d11

Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2795,7 +2795,7 @@ def _inner(folder):
                schedulers.append(scheduler)
        elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
            schedulers = self._schedulers

        # Save the samplers of the dataloaders
        dataloaders = self._dataloaders

@@ -2938,7 +2938,7 @@ def _inner(folder):
                schedulers.append(scheduler)
        elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
            schedulers = self._schedulers

        dataloaders = self._dataloaders

        # Call model loading hooks that might have been registered with
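
Both hunks above collect the prepared dataloaders (dataloaders = self._dataloaders) in save_state and load_state so that their samplers can be checkpointed alongside the model, optimizer, and scheduler states. As a rough sketch of what the save side of that bookkeeping could look like, mirroring the file-naming scheme visible in the loading loop below (the helper function and the SAMPLER_NAME value are illustrative assumptions, not the library's exact code):

import os

import torch

SAMPLER_NAME = "sampler"  # assumed to match the constant used in checkpointing.py

def save_sampler_states(dataloaders, output_dir):
    # Illustrative sketch: pickle each dataloader's sampler to sampler.bin,
    # sampler_1.bin, ..., the same names the loading loop below expects.
    for i, dataloader in enumerate(dataloaders):
        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
        output_sampler_file = os.path.join(output_dir, sampler_name)
        torch.save(dataloader.sampler, output_sampler_file)
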
4 changes: 2 additions & 2 deletions src/accelerate/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,12 @@ def load_accelerator_state(
        input_scheduler_file = os.path.join(input_dir, scheduler_name)
        scheduler.load_state_dict(torch.load(input_scheduler_file))
    logger.info("All scheduler states loaded successfully")

    for i, dataloader in enumerate(dataloaders):
        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
        input_sampler_file = os.path.join(input_dir, sampler_name)
        dataloader.sampler = torch.load(input_sampler_file)
-    logger.info(f"All dataloader sampler states loaded successfully")
+    logger.info("All dataloader sampler states loaded successfully")

    # GradScaler state
    if scaler is not None:
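
The only substantive change in this file drops a pointless f-string prefix: the logger.info call has no placeholders, so the plain string literal is equivalent (the pattern flake8 flags as F541). End to end, the sampler files written at save time are read back by the loop above, so a checkpoint round trip through the public API looks roughly like this (the toy model and directory name are hypothetical):

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

# Hypothetical toy setup to exercise the checkpoint round trip.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataloader = DataLoader(TensorDataset(torch.randn(16, 4)), batch_size=4, shuffle=True)

accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# save_state() writes model, optimizer, sampler, and RNG states into a folder;
# load_state() restores them via load_accelerator_state(), patched above.
accelerator.save_state("ckpt")  # hypothetical directory
accelerator.load_state("ckpt")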
