Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Dec 20, 2023
1 parent 2847c8b commit 85139a0
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 39 deletions.
104 changes: 66 additions & 38 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR,
RUN_NAME_ENV_VAR)
from composer.utils import dist
from composer.utils.misc import create_interval_scheduler
from composer.utils.file_helpers import list_remote_objects
from composer.utils.misc import create_interval_scheduler

from mcli import ComputeConfig, Run, RunConfig, create_run, get_run

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,8 +71,10 @@ def get_run_name(training_run_name: str, current_interval: str) -> str:

return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}'


SUPPORTED_UNITS = {TimeUnit.EPOCH, TimeUnit.BATCH}


def get_interval_from_checkpoint(checkpoint: str, unit: TimeUnit) -> Time:
"""Get the interval from a checkpoint name.
Expand All @@ -88,7 +91,8 @@ def get_interval_from_checkpoint(checkpoint: str, unit: TimeUnit) -> Time:
elif unit == TimeUnit.BATCH:
val = checkpoint.split('-')[1].replace('ba', '')
else:
raise ValueError(f'Unsupported unit {unit}. Must be in {" ".join(SUPPORTED_UNITS)}')
raise ValueError(
f'Unsupported unit {unit}. Must be in {SUPPORTED_UNITS}')

return Time(int(val), unit)

Expand Down Expand Up @@ -168,8 +172,8 @@ def validate_interval(interval: Union[str, int, Time],

if result.unit not in SUPPORTED_UNITS:
raise ValueError(
f'Async eval interval must be in units {", ".join(SUPPORTED_UNITS)}')
f'Async eval interval must be in units {SUPPORTED_UNITS}')

if new_save_interval.unit != result.unit:
raise ValueError(
'Save interval and async eval interval must be in the same unit')
Expand All @@ -182,8 +186,9 @@ def validate_interval(interval: Union[str, int, Time],
'Async eval interval must be a multiple of save interval')
return result


def validate_check_interval(check_interval: Union[str, int, Time],
interval: Time) -> Time:
interval: Time) -> Time:
if isinstance(check_interval, str):
result: Time = Time.from_timestring(check_interval)
elif isinstance(check_interval, int):
Expand All @@ -193,10 +198,9 @@ def validate_check_interval(check_interval: Union[str, int, Time],

if result.unit not in SUPPORTED_UNITS:
raise ValueError(
f'Async eval check interval must be in units {", ".join(SUPPORTED_UNITS)}')
f'Async eval check interval must be in units {SUPPORTED_UNITS}')
if interval.unit != result.unit:
raise ValueError(
'Check interval and interval must be in the same unit')
raise ValueError('Check interval and interval must be in the same unit')
if result > interval:
raise ValueError(
'Async eval interval must be equal or greater (less frequent) than async check interval'
Expand Down Expand Up @@ -231,7 +235,7 @@ def __init__(
):

self.compute = compute

# Run these during init to fail fast in any of the error cases
for required in ('save_interval', 'save_folder'):
if required not in training_config:
Expand All @@ -249,16 +253,18 @@ def __init__(
# Validate the interval (how often to launch eval runs)
self.interval = validate_interval(interval,
self.training_config['save_interval'])

# Validate and compute the check interval (how often to check for new checkpoints)
if check_interval is None:
unit = self.interval.value // 5
if unit == 0:
unit = 1
check_interval = Time(unit, self.interval.unit)
log.info(f'No check interval provided, defaulting to {check_interval}')
log.info(
f'No check interval provided, defaulting to {check_interval}')

self.check_interval = validate_check_interval(check_interval, self.interval)
self.check_interval = validate_check_interval(check_interval,
self.interval)

# Keep track of checkpoints by interval that have already been evaled
# Format: {interval: (checkpoint, run_name)}
Expand All @@ -273,10 +279,8 @@ def __init__(
include_end_of_training=False,
)

log.info(
'Initialized AsyncEval callback. Will generate runs at ' +
f'interval {interval}, checking at {check_interval}'
)
log.info('Initialized AsyncEval callback. Will generate runs at ' +
f'interval {interval}, checking at {check_interval}')

def _get_checkpoints_and_launch_runs(self, state: State):
"""Get the latest checkpoint from the training run.
Expand All @@ -298,10 +302,12 @@ def _get_checkpoints_and_launch_runs(self, state: State):
return

if not checkpointer.saved_checkpoints:
log.debug('No saved checkpoints found on the checkpointer. Skipping eval')
log.debug(
'No saved checkpoints found on the checkpointer. Skipping eval')
return

found_checkpoints = set(list_remote_objects(self.checkpoint_save_folder))
found_checkpoints = set(list_remote_objects(
self.checkpoint_save_folder))
# self.checkpoint_save_folder s3://.../anna/asyncsharded
# found_checkpoints {'anna/asyncsharded/ep0-ba2/__6_0.distcp', 'anna/asyncsharded/ep0-ba2/__2_0.distcp', 'anna/asyncsharded/ep0-ba2/__1_0.distcp', 'anna/asyncsharded/ep0-ba2/__4_0.distcp', 'anna/asyncsharded/latest-rank0.pt.symlink', 'anna/asyncsharded/ep0-ba4/.metadata', 'anna/asyncsharded/ep0-ba2/__3_0.distcp', 'anna/asyncsharded/ep0-ba2/.metadata', 'anna/asyncsharded/ep0-ba2/__5_0.distcp', 'anna/asyncsharded/ep0-ba2/__7_0.distcp', 'anna/asyncsharded/ep0-ba2/__0_0.distcp'}
# saved_checkpoints ['anna/asyncsharded/ep0-ba2/__0_0.distcp']
Expand All @@ -313,9 +319,9 @@ def _get_checkpoints_and_launch_runs(self, state: State):
if not found_checkpoints:
log.debug('No saved checkpoints found yet on remote. Skipping eval')
return

for checkpoint in checkpointer.saved_checkpoints:
# Get the part of the path that contains the interval. This is
# Get the part of the path that contains the interval. This is
# different for sharded checkpoints (which are saved in a folder)
if state.fsdp_elastic_sharded_enabled:
# eg {save_folder}/ep0-ba1/.
Expand All @@ -324,29 +330,48 @@ def _get_checkpoints_and_launch_runs(self, state: State):
# eg {save_folder}/ep0-ba1-rank0.pt
interval_path = Path(checkpoint).parts[-1]

interval = get_interval_from_checkpoint(interval_path, self.interval.unit)
interval = get_interval_from_checkpoint(interval_path,
self.interval.unit)
if interval.value % self.interval.value != 0:
log.debug(f'Checkpoint {checkpoint} ({interval}) is not at an eval interval ({self.interval}), skipping')
continue # Skip checkpoints when save interval is more frequent than eval interval

log.debug(
f'Checkpoint {checkpoint} ({interval}) is not at an eval interval ({self.interval}), skipping'
)
continue

if interval in self.checkpoints_evaled:
continue # Skip checkpoints that have already been evaled

# Check if the checkpoint is fully uploaded. If not, skip it until upload is complete
if state.fsdp_elastic_sharded_enabled:
log.error('todo')
# 8 (or N gpus if partial) and a metadata file
# __0_0.distcp - __7_0.distcp
# checkpoint is something like folder/ep0-ba4/__0_0.distcp
checkpoint_folder = '/'.join(checkpoint.split('/')[:-1])

if f'{checkpoint_folder}/.metadata' not in found_checkpoints:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded (missing metadata), skipping'
)
continue

for i in range(dist.get_world_size()):
proc = i % 8
rank = i // 8
shard_name = f'__{proc}_{rank}.distcp'
if f'{checkpoint_folder}/{shard_name}' not in found_checkpoints:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded (missing shard {shard_name}), skipping'
)
continue

else:
if checkpoint not in found_checkpoints:
log.debug(f'Checkpoint {checkpoint} not fully uploaded, skipping')
log.debug(
f'Checkpoint {checkpoint} not fully uploaded, skipping')
continue

# TODO: load_path looks like anna/asyncsharded/ep0-ba2/__0_0.distcp.symlink and not s3://.../anna/asyncsharded/ep0-ba2
# anna/async/ep0-ba2-rank0.pt -> s3://.../anna/async/ep0-ba2-rank0.pt

eval_run = self.launch_run(checkpoint, interval)
self.checkpoints_evaled[interval] = (checkpoint, eval_run.name)
full_checkpoint_path = f'{self.checkpoint_save_folder}/{interval_path}'
eval_run = self.launch_run(full_checkpoint_path, interval)
self.checkpoints_evaled[interval] = (full_checkpoint_path,
eval_run.name)

def run_event(self, event: Event, state: State, logger: Logger) -> None:
del logger
Expand All @@ -367,15 +392,15 @@ def close(self, state: State, logger: Logger) -> None:

if dist.get_global_rank() != 0:
return

# Eval any remaining checkpoints
self._get_checkpoints_and_launch_runs(state)

# Eval the latest checkpoint
latest_timestamp = state.timestamp.get(self.interval.unit)
if latest_timestamp not in self.checkpoints_evaled:
save_latest_filename = self.training_config.get('save_latest_filename',
None)
save_latest_filename = self.training_config.get(
'save_latest_filename', None)

if not save_latest_filename:
rank = dist.get_global_rank()
Expand All @@ -384,9 +409,12 @@ def close(self, state: State, logger: Logger) -> None:
checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'

eval_run = self.launch_run(checkpoint, latest_timestamp)
self.checkpoints_evaled[latest_timestamp] = (checkpoint, eval_run.name)
self.checkpoints_evaled[latest_timestamp] = (checkpoint,
eval_run.name)

log.info(f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:')
log.info(
f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:'
)
for interval, (checkpoint, run_name) in self.checkpoints_evaled.items():
log.info(f' {interval}: {checkpoint}, {run_name}')

Expand Down
1 change: 0 additions & 1 deletion scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,6 @@ def main(cfg: DictConfig) -> Trainer:
logging.getLogger('llmfoundry').setLevel(python_log_level.upper())

# Initialize context
print('fsdp_config', fsdp_config)
init_context = process_init_device(model_config, fsdp_config)
logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)

Expand Down

0 comments on commit 85139a0

Please sign in to comment.