diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index 8fb833ddde..2765c56470 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -640,7 +640,42 @@ def train(cfg: DictConfig) -> Trainer:
         trainer.eval()
 
     log.info('Starting training...')
-    trainer.fit()
+    try:
+        trainer.fit()
+    except ValueError as e:
+        msg = str(e)
+        # A retried run that already reached max_duration makes trainer.fit()
+        # raise this ValueError; treat it as "training is already complete".
+        if 'The max_duration' in msg and 'is less than or equal to the elapsed training duration' in msg and train_cfg.run_is_retry:
+            log.info(
+                'Training is already complete and this run is a retry. Skipping training and saving a checkpoint.'
+            )
+            trainer.save_checkpoint_to_save_folder()
+
+            hf_checkpointer_callbacks = [
+                c for c in callbacks if isinstance(c, HuggingFaceCheckpointer)
+            ]
+            if len(hf_checkpointer_callbacks) == 0:
+                log.info(
+                    'No HuggingFaceCheckpointer callback found. Skipping HF checkpoint.'
+                )
+                return trainer
+            if len(hf_checkpointer_callbacks) > 1:
+                raise ValueError(
+                    'Multiple HuggingFaceCheckpointer callbacks found; only one is supported when saving a checkpoint on retry. Please remove all but one HuggingFaceCheckpointer.',
+                )
+
+            hf_checkpointer_callback = hf_checkpointer_callbacks[0]
+            hf_checkpointer_callback._save_checkpoint(
+                trainer.state,
+                trainer.logger,
+                upload_to_save_folder=True,
+                register_to_mlflow=True,
+            )
+            return trainer
+        else:
+            # Any other ValueError is a genuine failure; re-raise it.
+            raise
 
     log.info('Done.')
     return trainer
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 997273de7f..2221039d34 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -187,6 +187,7 @@ class TrainConfig:
 
     # Resumption
    autoresume: bool = False
+    run_is_retry: bool = False
 
     # Profiling
     profiler: Optional[dict[str, Any]] = None
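
A minimal sketch of how the new run_is_retry flag might be exercised, assuming train() is importable from llmfoundry.command_utils and that 'train_config.yaml' is a hypothetical placeholder for an existing run config; in practice, retry orchestration would typically set the flag in the run config rather than in code:

    # Minimal sketch; `train_config.yaml` is a hypothetical placeholder.
    from omegaconf import OmegaConf

    from llmfoundry.command_utils import train

    cfg = OmegaConf.load('train_config.yaml')
    # Mark this invocation as a retry: if trainer.fit() reports that
    # max_duration has already elapsed, train() now saves checkpoints
    # and returns the trainer instead of raising.
    cfg.run_is_retry = True

    trainer = train(cfg)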