Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async callback: Don't skip checkpoints, reliably only launch async eval when the checkpoint is ready #813

Merged
merged 29 commits into from
Feb 16, 2024
Merged
Changes from 9 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 192 additions & 61 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union

from composer.callbacks import CheckpointSaver
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.loggers import Logger
from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR,
RUN_NAME_ENV_VAR)
from composer.utils import dist
from composer.utils.file_helpers import list_remote_objects
from composer.utils.misc import create_interval_scheduler

from mcli import ComputeConfig, Run, RunConfig, create_run, get_run
Expand Down Expand Up @@ -71,32 +72,28 @@ def get_run_name(training_run_name: str, current_interval: str) -> str:
return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}'


def get_latest_checkpoint(event: Event, state: State) -> Optional[str]:
"""Get the latest checkpoint from the training run.
SUPPORTED_UNITS = {TimeUnit.EPOCH, TimeUnit.BATCH}


def get_interval_from_checkpoint(checkpoint: str, unit: TimeUnit) -> Time:
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
"""Get the interval from a checkpoint name.

Args:
event: The current run event
state: The current state of the training run
checkpoint: The name of the checkpoint
unit: The unit of the interval

Returns:
The path to the latest checkpoint, or None if there is not a latest checkpoint
The interval time
"""
checkpointer = None
for callback in state.callbacks:
if isinstance(callback, CheckpointSaver):
checkpointer = callback
break

if not checkpointer:
log.warning('No checkpoint saver callback found')
return None

if not checkpointer.saved_checkpoints:
log.warning('No saved checkpoints found on the checkpointer')
return None
if unit == TimeUnit.EPOCH:
val = checkpoint.split('-')[0].replace('ep', '')
elif unit == TimeUnit.BATCH:
val = checkpoint.split('-')[1].replace('ba', '')
else:
raise ValueError(
f'Unsupported unit {unit}. Must be in {SUPPORTED_UNITS}')

latest = checkpointer.saved_checkpoints[-1]
return str(Path(latest).parts[-1])
return Time(int(val), unit)


def get_eval_parameters(
Expand Down Expand Up @@ -172,6 +169,10 @@ def validate_interval(interval: Union[str, int, Time],
else:
result: Time = interval

if result.unit not in SUPPORTED_UNITS:
raise ValueError(
f'Async eval interval must be in units {SUPPORTED_UNITS}')

if new_save_interval.unit != result.unit:
raise ValueError(
'Save interval and async eval interval must be in the same unit')
Expand All @@ -185,6 +186,27 @@ def validate_interval(interval: Union[str, int, Time],
return result


def validate_check_interval(check_interval: Union[str, int, Time],
interval: Time) -> Time:
if isinstance(check_interval, str):
result: Time = Time.from_timestring(check_interval)
elif isinstance(check_interval, int):
result: Time = Time(check_interval, TimeUnit.EPOCH)
else:
result: Time = check_interval

if result.unit not in SUPPORTED_UNITS:
raise ValueError(
f'Async eval check interval must be in units {SUPPORTED_UNITS}')
if interval.unit != result.unit:
raise ValueError('Check interval and interval must be in the same unit')
if result > interval:
raise ValueError(
'Async eval interval must be equal or greater (less frequent) than async check interval'
)
return result


class AsyncEval(Callback):
"""Run the eval loop asynchronously as part of a MosaicML platform run.

Expand All @@ -194,8 +216,10 @@ class AsyncEval(Callback):
training_config: Dict[str, Any]: The config from the training run
interval: Union[str, int, Time]: The interval describing how often eval runs should be
launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH` or :attr:`.TimeUnit.BATCH`.
check_interval: Optional[Union[str, int, Time]]: The interval describing how often
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
to check if an eval run should be launched. If not provided, it will be set to
check 5 times each :attr:`interval` (floored to the nearest integer)
compute: Optional[Union[ComputeConfig, Dict[str, Any]]]: The compute configuration to
use for the eval run. If not provided, the same cluster as the current run and a
single, full GPU node will be used.
Expand All @@ -205,79 +229,186 @@ def __init__(
self,
training_config: Dict[str, Any],
interval: Union[str, int, Time],
check_interval: Optional[Union[str, int, Time]] = None,
compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None,
):

self.compute = compute

# Run these during init to fail fast in any of the error cases
for required in ('save_interval', 'save_folder'):
if required not in training_config:
raise ValueError(f'{required} required for async eval')

self.checkpoint_save_folder = training_config['save_folder']
self.training_config = training_config
self.interval = validate_interval(interval,
self.training_config['save_interval'])
self.check_interval = create_interval_scheduler(
interval,
# There is a custom close to ensure that the final checkpoint
# (which is the most important) is evaled after it is written
include_end_of_training=False,
)
self.compute = compute
self.last_checkpoint: Optional[str] = None

# Run these during init to fail fast in any of the error cases
self.current_run = self._get_current_run()
get_eval_parameters(
parameters=training_config,
checkpoint='test',
training_run_name=self.current_run.name,
)
log.info(
f'Initialized AsyncEval callback. Will generate runs at interval {interval}'

# Validate the interval (how often to launch eval runs)
self.interval = validate_interval(interval,
self.training_config['save_interval'])

# Validate and compute the check interval (how often to check for new checkpoints)
if check_interval is None:
unit = self.interval.value // 5
if unit == 0:
unit = 1
check_interval = Time(unit, self.interval.unit)
log.info(
f'No check interval provided, defaulting to {check_interval}')

self.check_interval = validate_check_interval(check_interval,
self.interval)

# Keep track of checkpoints by interval that have already been evaled
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
# Format: {interval: (checkpoint, run_name)}
self.checkpoints_evaled: Dict[Time, Tuple[str, str]] = {}

# Scheduling is based on the check interval, while _get_checkpoints_and_launch_runs
# will only launch runs at the interval
self.is_at_check_interval = create_interval_scheduler(
self.check_interval,
# There is a custom close to ensure that the final checkpoint
# (which is the most important) is evaled after it is written
include_end_of_training=False,
)

log.info('Initialized AsyncEval callback. Will generate runs at ' +
f'interval {interval}, checking at {check_interval}')

def _get_checkpoints_and_launch_runs(self, state: State):
"""Get the latest checkpoint from the training run.

Args:
state: The current state of the training run

Returns:
Returns checkpoints that have not been evaled
"""
checkpointer = None
for callback in state.callbacks:
if isinstance(callback, CheckpointSaver):
checkpointer = callback
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
break

if not checkpointer:
log.warning('No checkpoint saver callback found. Skipping eval')
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
return

if not checkpointer.saved_checkpoints:
log.debug(
'No saved checkpoints found on the checkpointer. Skipping eval')
return

found_checkpoints = set(list_remote_objects(
self.checkpoint_save_folder))

if not found_checkpoints:
log.debug('No saved checkpoints found yet on remote. Skipping eval')
return

for checkpoint in checkpointer.saved_checkpoints:
# Get the part of the path that contains the interval. This is
# different for sharded checkpoints (which are saved in a folder)
if state.fsdp_elastic_sharded_enabled:
# eg {save_folder}/ep0-ba1/.
interval_path = Path(checkpoint).parts[-2]
else:
# eg {save_folder}/ep0-ba1-rank0.pt
interval_path = Path(checkpoint).parts[-1]

interval = get_interval_from_checkpoint(interval_path,
self.interval.unit)
if interval.value % self.interval.value != 0:
log.debug(
f'Checkpoint {checkpoint} ({interval}) is not at an eval interval ({self.interval}), skipping'
)
continue

if interval in self.checkpoints_evaled:
continue # Skip checkpoints that have already been evaled

# Check if the checkpoint is fully uploaded. If not, skip it until upload is complete
if state.fsdp_elastic_sharded_enabled:
# checkpoint is something like folder/ep0-ba4/__0_0.distcp
checkpoint_folder = '/'.join(checkpoint.split('/')[:-1])

if f'{checkpoint_folder}/.metadata' not in found_checkpoints:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded (missing metadata), skipping'
)
continue

for i in range(dist.get_world_size()):
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
proc = i % 8
rank = i // 8
shard_name = f'__{proc}_{rank}.distcp'
if f'{checkpoint_folder}/{shard_name}' not in found_checkpoints:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded (missing shard {shard_name}), skipping'
)
continue

else:
if checkpoint not in found_checkpoints:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded, skipping')
continue

full_checkpoint_path = f'{self.checkpoint_save_folder}/{interval_path}'
eval_run = self.launch_run(full_checkpoint_path, interval)
self.checkpoints_evaled[interval] = (full_checkpoint_path,
eval_run.name)

def run_event(self, event: Event, state: State, logger: Logger) -> None:
del logger

should_launch_run = all([
state.get_elapsed_duration() is not None,
self.check_interval(state, event),
# could also skip check intervals before the first async eval interval,
# but this may make the scheduler more complicated
self.is_at_check_interval(state, event),
dist.get_global_rank() == 0,
])

if should_launch_run:
current_interval = state.timestamp.get(self.interval.unit)
checkpoint = get_latest_checkpoint(event, state)
if not checkpoint:
return # warnings logged in get_latest_checkpoint

# TODO: ensure the checkpoint is fully written before launching the eval run
full_checkpoint = f'{self.checkpoint_save_folder}/{checkpoint}'
if full_checkpoint == self.last_checkpoint:
# Do not eval a checkpoint that has already been evaluated.
log.info(
'Skipping async eval because the checkpoint has not changed'
)
return

self.launch_run(full_checkpoint, current_interval)
self.last_checkpoint = full_checkpoint
self._get_checkpoints_and_launch_runs(state)

def close(self, state: State, logger: Logger) -> None:
del logger

if dist.get_global_rank() != 0:
return

save_latest_filename = self.training_config.get('save_latest_filename',
None)
# Eval any remaining checkpoints
self._get_checkpoints_and_launch_runs(state)

# Eval the latest checkpoint
latest_timestamp = state.timestamp.get(self.interval.unit)
if latest_timestamp not in self.checkpoints_evaled:
save_latest_filename = self.training_config.get(
'save_latest_filename', None)

if not save_latest_filename:
rank = dist.get_global_rank()
save_latest_filename = f'latest-rank{rank}.pt'

if not save_latest_filename:
rank = dist.get_global_rank()
save_latest_filename = f'latest-rank{rank}.pt'
checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'

checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'
self.launch_run(checkpoint, state.timestamp.get(self.interval.unit))
eval_run = self.launch_run(checkpoint, latest_timestamp)
self.checkpoints_evaled[latest_timestamp] = (checkpoint,
eval_run.name)

log.info(
f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:'
)
for interval, (checkpoint, run_name) in self.checkpoints_evaled.items():
log.info(f' {interval}: {checkpoint}, {run_name}')

def _get_current_run(self) -> Run:
if os.environ.get(MOSAICML_PLATFORM_ENV_VAR,
Expand Down
Loading