Skip to content

Commit

Permalink
Merge branch 'main' into callback-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Oct 11, 2024
2 parents 0da1114 + c6b7453 commit ce016af
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
10 changes: 7 additions & 3 deletions llmfoundry/command_utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,17 +311,21 @@ def train(cfg: DictConfig) -> Trainer:
eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str

# Optional parameters will be set to default values if not specified.
default_run_name: str = os.environ.get('RUN_NAME', 'llm')
run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name
env_run_name: Optional[str] = os.environ.get('RUN_NAME', None)
run_name: str = (
train_cfg.run_name if train_cfg.run_name else env_run_name
) or 'llm'
is_state_dict_sharded: bool = (
fsdp_config.get('state_dict_type', 'full') == 'sharded'
) if fsdp_config else False
save_latest_filename: str = train_cfg.save_latest_filename if train_cfg.save_latest_filename else 'latest-sharded-rank{rank}' if is_state_dict_sharded else 'latest-rank{rank}.pt'
save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt'

# Enable autoresume from model checkpoints if possible
is_user_set_run_name: bool = train_cfg.run_name is not None or env_run_name is not None
autoresume_default: bool = False
if train_cfg.save_folder is not None \
if is_user_set_run_name and \
train_cfg.save_folder is not None \
and not train_cfg.save_overwrite \
and not train_cfg.save_weights_only:
autoresume_default = True
Expand Down
4 changes: 2 additions & 2 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,15 @@ def __init__(self, input: dict) -> None:


## Convert Delta to JSON exceptions
class ClusterDoesNotExistError(NetworkError):
class ClusterDoesNotExistError(UserError):
"""Error thrown when the cluster does not exist."""

def __init__(self, cluster_id: str) -> None:
message = f'Cluster with id {cluster_id} does not exist. Check cluster id and try again!'
super().__init__(message, cluster_id=cluster_id)


class ClusterInvalidAccessMode(NetworkError):
class ClusterInvalidAccessMode(UserError):
"""Error thrown when the cluster does not exist."""

def __init__(self, cluster_id: str, access_mode: str) -> None:
Expand Down

0 comments on commit ce016af

Please sign in to comment.