Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

fix: ensure last checkpoint is always saved, refactor training stop conditions to be computed in single location #729

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

mattmazzola
Copy link
Contributor

@mattmazzola mattmazzola commented Jun 8, 2023

Issues

1 Inconsistent checkpoint filenames saved by trainer

In our pipeline we often have sequence of steps such as (train, reshard/unflatten, evaluate). The output files of the training become inputs to the resharding scripts. In order for the execution to work reliably the output files need to have consistent filenames, such as checkpoint_last-model_part-0-shard0.pt

When running metaseq.cli.train with tasks such as streaming_finetune_language_modeling there are two different stopping conditions set by --max-epochs and --max-updates. Whichever limit is hit first will cause the model stop training.

The issue is that checkpoint_last-* file is ONLY written the epoch stop condition or update stop conditions were false.
This couples the checkpoint filename with the stopping conditions

Notice checkpoints[0] only uses the FIRST true filename/condition

if len(checkpoints) > 0:
if PathManager.islink(checkpoints[0]):
PathManager.rm(checkpoints[0])
trainer.save_checkpoint(
checkpoints[0],
extra_state,
training_finished=training_finished,
async_callback_fn=async_callback_fn if save_to_NFS else None,
files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None,
)

Goal

We want to be able to run the jobs/pipeline and change the stopping conditions without implicitly changing the output file that will be given to the subsequent commands / scripts

2 Training Stop was Handled in Multiple Locations

Loop condition:

while epoch_itr.next_epoch_idx <= max_epoch:

Loop break:
if should_stop:
break

This makes it harder to reason about which condition will cause training to stop.

Solution

  • Consolidate all training stop conditions in to validate_and_save and should_stop
    • Use > instead of >= conditions
  • Change to always save checkpoint_last* file
    • This means there are cases it will save multiple checkpoints
      • Epoch AND last checkpoint
      • Updates AND last checkpoint

Testing

I wasn't able to test since this is merging with metaseq main instead of our fork's main.
I wanted to at least share the ideas. Although changing training stop conditions can be serious, so maybe someone else can submit a small jobs to test. One with max-epochs, other with max-updates, and in both cases it saves checkpoint_last files

Related to #726

train_meter = meters.StopwatchMeter()
train_meter.start()
while epoch_itr.next_epoch_idx <= max_epoch:

while True:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually it is practice to avoid while True anywhere since it relies on other code to stop the loop and it's easy to make mistakes. However, the alternative of splitting logic between loop and validate function is more complex and thus more likely for us to have issues in future.

Also, another reason the while true is not be so bad is because the original could potentially be the same condition when max-epochs was not provided / defined.

If not defined

while epoch_itr.next_epoch_idx <= max_epoch
while epoch_itr.next_epoch_idx <= cfg.optimization.max_epoch or math.inf
while epoch_itr.next_epoch_idx <= math.inf
while true

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants