From 34541f137c6c2f960c7e912e00006e6a5e70530a Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Wed, 23 Oct 2024 18:05:37 +0000
Subject: [PATCH] simplify

---
 llmfoundry/command_utils/train.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index e599961998..dbfe89618a 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -311,10 +311,11 @@ def train(cfg: DictConfig) -> Trainer:
     eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str
 
     # Optional parameters will be set to default values if not specified.
-    env_run_name: Optional[str] = os.environ.get('RUN_NAME', None)
-    run_name: Optional[str] = (
-        train_cfg.run_name if train_cfg.run_name else env_run_name
-    )
+    run_name: Optional[
+        str] = train_cfg.run_name if train_cfg.run_name else os.environ.get(
+            'RUN_NAME',
+            None,
+        )
     is_state_dict_sharded: bool = (
         fsdp_config.get('state_dict_type', 'full') == 'sharded'
     ) if fsdp_config else False
@@ -322,7 +323,7 @@
     save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt'
 
     # Enable autoresume from model checkpoints if possible
-    is_user_set_run_name: bool = train_cfg.run_name is not None or run_name is not None
+    is_user_set_run_name: bool = train_cfg.run_name is not None or run_name is not None
     autoresume_default: bool = False
     if is_user_set_run_name and \
         train_cfg.save_folder is not None \
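Note (not part of the patch): the sketch below restates the post-patch run_name resolution as a minimal, self-contained Python snippet, to make the behavior the diff arrives at easier to follow. The helper name resolve_run_name and the bare cfg_run_name variable are illustrative stand-ins for train_cfg.run_name in the real config object, not llm-foundry APIs.

    import os
    from typing import Optional


    def resolve_run_name(cfg_run_name: Optional[str]) -> Optional[str]:
        # Mirrors the patched expression: the config value wins; otherwise
        # fall back to the RUN_NAME environment variable in a single step.
        return cfg_run_name if cfg_run_name else os.environ.get('RUN_NAME', None)


    # cfg_run_name stands in for train_cfg.run_name from the real config.
    cfg_run_name: Optional[str] = None
    run_name = resolve_run_name(cfg_run_name)

    # The patch drops the intermediate env_run_name variable; the autoresume
    # check reuses the merged run_name instead, which is equivalent because
    # run_name is non-None whenever env_run_name was.
    is_user_set_run_name: bool = cfg_run_name is not None or run_name is not None

The first disjunct is still checked separately so that a falsy-but-set config value (e.g. an empty string) continues to count as user-set, matching the pre-patch behavior.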