Skip to content

Commit

Permalink
Do not autoresume if a default name is set, only on user defined ones (
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Oct 11, 2024
1 parent aad13c4 commit c6b7453
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions llmfoundry/command_utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,17 +311,21 @@ def train(cfg: DictConfig) -> Trainer:
eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str

# Optional parameters will be set to default values if not specified.
default_run_name: str = os.environ.get('RUN_NAME', 'llm')
run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name
env_run_name: Optional[str] = os.environ.get('RUN_NAME', None)
run_name: str = (
train_cfg.run_name if train_cfg.run_name else env_run_name
) or 'llm'
is_state_dict_sharded: bool = (
fsdp_config.get('state_dict_type', 'full') == 'sharded'
) if fsdp_config else False
save_latest_filename: str = train_cfg.save_latest_filename if train_cfg.save_latest_filename else 'latest-sharded-rank{rank}' if is_state_dict_sharded else 'latest-rank{rank}.pt'
save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt'

# Enable autoresume from model checkpoints if possible
is_user_set_run_name: bool = train_cfg.run_name is not None or env_run_name is not None
autoresume_default: bool = False
if train_cfg.save_folder is not None \
if is_user_set_run_name and \
train_cfg.save_folder is not None \
and not train_cfg.save_overwrite \
and not train_cfg.save_weights_only:
autoresume_default = True
Expand Down

0 comments on commit c6b7453

Please sign in to comment.