Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
milocress committed Dec 9, 2024
1 parent 7b8bf5f commit b489a9a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
32 changes: 31 additions & 1 deletion llmfoundry/command_utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,37 @@ def train(cfg: DictConfig) -> Trainer:
trainer.eval()

log.info('Starting training...')
trainer.fit()
try:
trainer.fit()
except ValueError as e:
msg = str(e)
if 'The max_duration' in msg and 'is less than or equal to the elapsed training duration' in msg and train_cfg.run_is_retry:
log.info(
'Training is already complete and detected retry. Skipping training and saving checkpoint.'
)
trainer.save_checkpoint_to_save_folder()

hf_checkpointer_callbacks = [
c for c in callbacks if isinstance(c, HuggingFaceCheckpointer)
]
if len(hf_checkpointer_callbacks) == 0:
log.info(
'No HuggingFaceCheckpointer callback found. Skipping HF checkpoint.'
)
return trainer
if len(hf_checkpointer_callbacks) > 1:
raise ValueError(
'Multiple HuggingFaceCheckpointer callbacks found, but only_hf_checkpoint was set to True. Please remove all but one HuggingFaceCheckpointer.',
)

hf_checkpointer_callback = hf_checkpointer_callbacks[0]
hf_checkpointer_callback._save_checkpoint(
trainer.state,
trainer.logger,
upload_to_save_folder=True,
register_to_mlflow=True,
)
return trainer

log.info('Done.')
return trainer
Expand Down
1 change: 1 addition & 0 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class TrainConfig:

# Resumption
autoresume: bool = False
run_is_retry: bool = False

# Profiling
profiler: Optional[dict[str, Any]] = None
Expand Down

0 comments on commit b489a9a

Please sign in to comment.