diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index a9e976d4a71..cd00cf91b79 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2795,7 +2795,7 @@ def _inner(folder):
                 schedulers.append(scheduler)
         elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
             schedulers = self._schedulers
-
+
         # Save the samplers of the dataloaders
         dataloaders = self._dataloaders
 
@@ -2938,7 +2938,7 @@ def _inner(folder):
                 schedulers.append(scheduler)
         elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
             schedulers = self._schedulers
-
+
         dataloaders = self._dataloaders
 
         # Call model loading hooks that might have been registered with
diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index d50aff4c5c5..862182a71d5 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -188,12 +188,12 @@ def load_accelerator_state(
         input_scheduler_file = os.path.join(input_dir, scheduler_name)
         scheduler.load_state_dict(torch.load(input_scheduler_file))
     logger.info("All scheduler states loaded successfully")
-
+
     for i, dataloader in enumerate(dataloaders):
         sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
         input_sampler_file = os.path.join(input_dir, sampler_name)
         dataloader.sampler = torch.load(input_sampler_file)
-    logger.info(f"All dataloader sampler states loaded successfully")
+    logger.info("All dataloader sampler states loaded successfully")
 
     # GradScaler state
     if scaler is not None:
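
For context, the code paths touched here are exercised through Accelerator.save_state and Accelerator.load_state, which write and read the per-dataloader sampler files handled in checkpointing.py. The following is a minimal sketch of that round trip; the toy dataset, batch size, and checkpoint directory name are illustrative assumptions, not part of this patch.

    # Illustrative sketch only: drives the save/load paths that serialize
    # and restore dataloader sampler state alongside optimizers/schedulers.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Hypothetical toy dataset; shuffle=True gives the dataloader a sampler to checkpoint.
    dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))
    dataloader = accelerator.prepare(DataLoader(dataset, batch_size=8, shuffle=True))

    # save_state() writes the tracked states (including samplers) into the folder;
    # load_state() reads them back via load_accelerator_state in checkpointing.py.
    accelerator.save_state("checkpoint_dir")
    accelerator.load_state("checkpoint_dir")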