Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix skip first batch for deepspeed example #2001

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions examples/by_feature/deepspeed_with_config_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,15 +602,22 @@ def group_texts(examples):
resume_step -= starting_epoch * num_update_steps_per_epoch
completed_steps = resume_step

# update progress bar if resumed from checkpoint
progress_bar.update(completed_steps)

for epoch in range(starting_epoch, args.num_train_epochs):
model.train()
if args.with_tracking:
total_loss = 0

# skip new `skip_first_batches` to skip the batches when resuming from ckpt
if args.resume_from_checkpoint:
train_dataloader = accelerator.skip_first_batches(train_dataloader, num_batches=resume_step)
for step, batch in enumerate(train_dataloader):
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
# After the first iteration though, we need to go back to the original dataloader
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
# In particular, DeepSpeed handles `gradient_accumulation` via `DeepSpeedEngine`.
# Below, we use `accelerator.accumulate` if the user
# wants to switch to other approaches such as plain DDP, PyTorch FSDP ...
Expand Down
Loading