diff --git a/llmfoundry/train/train.py b/llmfoundry/train/train.py
index 273372e1cd..df0472a775 100644
--- a/llmfoundry/train/train.py
+++ b/llmfoundry/train/train.py
@@ -537,6 +537,12 @@ def train(cfg: DictConfig) -> Trainer:
         hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger)
         return trainer
 
+    if train_cfg.only_composer_checkpoint:
+        log.info('Not training. Only saving composer checkpoint.')
+        trainer.save_checkpoint_to_save_folder()
+        log.info('Done saving checkpoint.')
+        return trainer
+
     if train_cfg.log_config:
         log.info('Logging config')
         log_config(logged_cfg)
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 2667fceb67..4b86de99b8 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -162,6 +162,7 @@ class TrainConfig:
    load_ignore_keys: Optional[List[str]] = None
    save_ignore_keys: Optional[List[str]] = None
    only_hf_checkpoint: bool = False
+    only_composer_checkpoint: bool = False
 
    # Dataloader
    device_train_microbatch_size: Union[str, int, float] = 'auto'
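
Usage sketch (not part of the diff): with the new `only_composer_checkpoint` flag set, `train()` writes a Composer checkpoint to the configured `save_folder` and returns the trainer without running any training steps, mirroring the existing `only_hf_checkpoint` short-circuit. The config path and `save_folder` value below are placeholders; a real llm-foundry config also carries model, tokenizer, and dataloader sections.

    # Hypothetical invocation; assumes the train() entry point shown in the
    # diff and a complete llm-foundry training config loaded from YAML.
    from omegaconf import OmegaConf

    from llmfoundry.train.train import train

    cfg = OmegaConf.load('train_config.yaml')  # placeholder config path
    cfg.save_folder = './checkpoints'          # where the Composer checkpoint lands
    cfg.only_composer_checkpoint = True        # new flag: save checkpoint, skip training

    trainer = train(cfg)  # returns right after trainer.save_checkpoint_to_save_folder()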