diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 1fc4a0e96e..36a4d75fb8 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -161,6 +161,7 @@ class TrainConfig:
     load_strict_model_weights: bool = True
     load_ignore_keys: Optional[List[str]] = None
     save_ignore_keys: Optional[List[str]] = None
+    only_hf_checkpoint: bool = False

     # Dataloader
     device_train_microbatch_size: Union[str, int, float] = 'auto'
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 134058a595..7945e9854a 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -22,7 +22,7 @@
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om

-from llmfoundry.callbacks import AsyncEval
+from llmfoundry.callbacks import AsyncEval, HuggingFaceCheckpointer
 from llmfoundry.data.dataloader import build_dataloader
 from llmfoundry.eval.metrics.nlp import InContextLearningMetric
 from llmfoundry.layers_registry import ffns_with_megablocks
@@ -527,6 +527,24 @@ def main(cfg: DictConfig) -> Trainer:
         compile_config=compile_config,
     )

+    # Optionally just save an HF checkpoint
+    if train_cfg.only_hf_checkpoint:
+        hf_checkpointer_callbacks = [
+            c for c in callbacks if isinstance(c, HuggingFaceCheckpointer)
+        ]
+        if len(hf_checkpointer_callbacks) == 0:
+            raise ValueError(
+                'No HuggingFaceCheckpointer callback found, but only_hf_checkpoint was set to True. Please add a HuggingFaceCheckpointer.',
+            )
+        if len(hf_checkpointer_callbacks) > 1:
+            raise ValueError(
+                'Multiple HuggingFaceCheckpointer callbacks found, but only_hf_checkpoint was set to True. Please remove all but one HuggingFaceCheckpointer.',
+            )
+
+        hf_checkpointer_callback = hf_checkpointer_callbacks[0]
+        hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger)
+        return trainer
+
     if train_cfg.log_config:
         log.info('Logging config')
         log_config(logged_cfg)
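
For reference, a minimal sketch of a training YAML that would exercise the new flag: the trainer is constructed, exactly one HuggingFaceCheckpointer is resolved from the callbacks, its _save_checkpoint is invoked once, and main returns without fitting. The callback registry key (hf_checkpointer) and its fields (save_folder, save_interval, precision) follow the usual HuggingFaceCheckpointer arguments, and the load_path value is a placeholder; treat all of these as assumptions and adjust to your setup.

    # Hypothetical config excerpt: convert an existing Composer checkpoint
    # to HF format without running any training steps.
    only_hf_checkpoint: true  # new TrainConfig flag added in this diff
    load_path: s3://my-bucket/checkpoints/latest-rank0.pt  # placeholder path
    callbacks:
      hf_checkpointer:  # exactly one HuggingFaceCheckpointer is required
        save_folder: ./hf_checkpoint  # where the HF-format weights land
        save_interval: 1ba  # unused here; _save_checkpoint is called directly
        precision: bfloat16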