Skip to content

Commit

Permalink
Add a config arg to just save an hf checkpoint (#1335)
Browse files (browse the repository at this point in the history)
  • Loading branch information
dakinggg authored Jul 3, 2024
1 parent 73f267c commit a542565
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
1 change: 1 addition & 0 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class TrainConfig:
load_strict_model_weights: bool = True
load_ignore_keys: Optional[List[str]] = None
save_ignore_keys: Optional[List[str]] = None
only_hf_checkpoint: bool = False

# Dataloader
device_train_microbatch_size: Union[str, int, float] = 'auto'
Expand Down
20 changes: 19 additions & 1 deletion scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

from llmfoundry.callbacks import AsyncEval
from llmfoundry.callbacks import AsyncEval, HuggingFaceCheckpointer
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.eval.metrics.nlp import InContextLearningMetric
from llmfoundry.layers_registry import ffns_with_megablocks
Expand Down Expand Up @@ -527,6 +527,24 @@ def main(cfg: DictConfig) -> Trainer:
compile_config=compile_config,
)

# Optionally just save an HF checkpoint
if train_cfg.only_hf_checkpoint:
hf_checkpointer_callbacks = [
c for c in callbacks if isinstance(c, HuggingFaceCheckpointer)
]
if len(hf_checkpointer_callbacks) == 0:
raise ValueError(
'No HuggingFaceCheckpointer callback found, but only_hf_checkpoint was set to True. Please add a HuggingFaceCheckpointer.',
)
if len(hf_checkpointer_callbacks) > 1:
raise ValueError(
'Multiple HuggingFaceCheckpointer callbacks found, but only_hf_checkpoint was set to True. Please remove all but one HuggingFaceCheckpointer.',
)

hf_checkpointer_callback = hf_checkpointer_callbacks[0]
hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger)
return trainer

if train_cfg.log_config:
log.info('Logging config')
log_config(logged_cfg)
Expand Down

0 comments on commit a542565

Please sign in to comment.