diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index f647565386..280ff0e0d5 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -311,8 +311,10 @@ def train(cfg: DictConfig) -> Trainer: eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str # Optional parameters will be set to default values if not specified. - default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name + env_run_name: Optional[str] = os.environ.get('RUN_NAME', None) + run_name: str = ( + train_cfg.run_name if train_cfg.run_name else env_run_name + ) or 'llm' is_state_dict_sharded: bool = ( fsdp_config.get('state_dict_type', 'full') == 'sharded' ) if fsdp_config else False @@ -320,8 +322,10 @@ def train(cfg: DictConfig) -> Trainer: save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt' # Enable autoresume from model checkpoints if possible + is_user_set_run_name: bool = train_cfg.run_name is not None or env_run_name is not None autoresume_default: bool = False - if train_cfg.save_folder is not None \ + if is_user_set_run_name and \ + train_cfg.save_folder is not None \ and not train_cfg.save_overwrite \ and not train_cfg.save_weights_only: autoresume_default = True diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 1b9feb9a10..7f7442ab8f 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -310,7 +310,7 @@ def __init__(self, input: dict) -> None: ## Convert Delta to JSON exceptions -class ClusterDoesNotExistError(NetworkError): +class ClusterDoesNotExistError(UserError): """Error thrown when the cluster does not exist.""" def __init__(self, cluster_id: str) -> None: @@ -318,7 +318,7 @@ def __init__(self, cluster_id: str) -> None: super().__init__(message, cluster_id=cluster_id) -class ClusterInvalidAccessMode(NetworkError): +class ClusterInvalidAccessMode(UserError): """Error thrown when the cluster does not exist.""" def __init__(self, cluster_id: str, access_mode: str) -> None: