From b7c18e2d067d3025676d2b5abb23e501ff811a8d Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 22 Sep 2023 13:22:29 -0700 Subject: [PATCH] fix resuming from checkpoint --- .../by_feature/deepspeed_with_config_support.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/by_feature/deepspeed_with_config_support.py b/examples/by_feature/deepspeed_with_config_support.py index 21b216b31f3..15e810c4a2e 100755 --- a/examples/by_feature/deepspeed_with_config_support.py +++ b/examples/by_feature/deepspeed_with_config_support.py @@ -602,15 +602,22 @@ def group_texts(examples): resume_step -= starting_epoch * num_update_steps_per_epoch completed_steps = resume_step + # update progress bar if resumed from checkpoint + progress_bar.update(completed_steps) + for epoch in range(starting_epoch, args.num_train_epochs): model.train() if args.with_tracking: total_loss = 0 # skip new `skip_first_batches` to skip the batches when resuming from ckpt - if args.resume_from_checkpoint: - train_dataloader = accelerator.skip_first_batches(train_dataloader, num_batches=resume_step) - for step, batch in enumerate(train_dataloader): + if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We need to skip steps until we reach the resumed step + active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) + else: + # After the first iteration though, we need to go back to the original dataloader + active_dataloader = train_dataloader + for step, batch in enumerate(active_dataloader): # In particular, DeepSpeed handles `gradient_accumulation` via `DeepSpeedEngine`. # Below, we use `accelerator.accumulate` if the user # wants to switch to other approaches such as plain DDP, PyTorch FSDP ...