From c6b7453888106e6f479ab5e620c4ed97fb856cf7 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Fri, 11 Oct 2024 15:41:49 -0700 Subject: [PATCH] Do not autoresume if a default name is set, only on user defined ones (#1588) --- llmfoundry/command_utils/train.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index f647565386..280ff0e0d5 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -311,8 +311,10 @@ def train(cfg: DictConfig) -> Trainer: eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str # Optional parameters will be set to default values if not specified. - default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name + env_run_name: Optional[str] = os.environ.get('RUN_NAME', None) + run_name: str = ( + train_cfg.run_name if train_cfg.run_name else env_run_name + ) or 'llm' is_state_dict_sharded: bool = ( fsdp_config.get('state_dict_type', 'full') == 'sharded' ) if fsdp_config else False @@ -320,8 +322,10 @@ def train(cfg: DictConfig) -> Trainer: save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt' # Enable autoresume from model checkpoints if possible + is_user_set_run_name: bool = train_cfg.run_name is not None or env_run_name is not None autoresume_default: bool = False - if train_cfg.save_folder is not None \ + if is_user_set_run_name and \ + train_cfg.save_folder is not None \ and not train_cfg.save_overwrite \ and not train_cfg.save_weights_only: autoresume_default = True