Skip to content

Commit

Permalink
Async callback: Don't skip checkpoints, reliably only launch async ev…
Browse files Browse the repository at this point in the history
…al when the checkpoint is ready (#813)

* working without sharded checkpointing..

* add more debugs

* try this

* more debugging

* yikes dumb bug

* add notes

* fixes

* remove prints

* small updates

* fix typo

* refactor

* fix docstring formatting

* fighting with docstrings

* try this

* add unit tests

* point to composer update

* values -> items

* serialize time

* fix merge

* nits

* warning, small comment update

* add error

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
aspfohl and dakinggg authored Feb 16, 2024
1 parent 3a99270 commit 6e3842b
Show file tree
Hide file tree
Showing 2 changed files with 345 additions and 67 deletions.
306 changes: 240 additions & 66 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@

import logging
import os
import warnings
from collections import Counter
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from composer.callbacks import CheckpointSaver
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core import Callback, Event, State, Time, Timestamp, TimeUnit
from composer.loggers import Logger
from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR,
RUN_NAME_ENV_VAR)
from composer.utils import dist
from composer.utils.file_helpers import list_remote_objects
from composer.utils.misc import create_interval_scheduler

from mcli import Run, RunConfig, create_run, get_run
Expand Down Expand Up @@ -73,34 +76,6 @@ def get_run_name(training_run_name: str, current_interval: str) -> str:
return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}'


def get_latest_checkpoint(event: Event, state: State) -> Optional[str]:
"""Get the latest checkpoint from the training run.
Args:
event: The current run event
state: The current state of the training run
Returns:
The path to the latest checkpoint, or None if there is not a latest checkpoint
"""
checkpointer = None
for callback in state.callbacks:
if isinstance(callback, CheckpointSaver):
checkpointer = callback
break

if not checkpointer:
log.warning('No checkpoint saver callback found')
return None

if not checkpointer.saved_checkpoints:
log.warning('No saved checkpoints found on the checkpointer')
return None

latest = checkpointer.saved_checkpoints[-1]
return str(Path(latest).parts[-1])


def get_eval_parameters(
parameters: Dict[str, Any],
checkpoint: str,
Expand Down Expand Up @@ -199,6 +174,9 @@ def validate_eval_run_config(
return run_config


CHECKS_PER_INTERVAL = 4


class AsyncEval(Callback):
"""Run the eval loop asynchronously as part of a MosaicML platform run.
Expand Down Expand Up @@ -234,76 +212,263 @@ def __init__(
eval_run_config: Optional[Dict[str, Any]] = None,
):

# Run these during init to fail fast in any of the error cases
for required in ('save_interval', 'save_folder'):
if required not in training_params:
raise ValueError(f'{required} required for async eval')

if '/' in training_params.get('save_filename', ''):
raise ValueError(
'AsyncEval not supported for save_filename that includes a path'
)

self.checkpoint_save_folder = training_params['save_folder']
self.training_params = training_params
self.eval_run_config = validate_eval_run_config(eval_run_config)
self.interval = validate_interval(interval,
self.training_params['save_interval'])
self.check_interval = create_interval_scheduler(
interval,
# There is a custom close to ensure that the final checkpoint
# (which is the most important) is evaled after it is written
include_end_of_training=False,
)
self.last_checkpoint: Optional[str] = None

# Run these during init to fail fast in any of the error cases
self.current_run = self._get_current_run()
get_eval_parameters(
parameters=training_params,
checkpoint='test',
training_run_name=self.current_run.name,
)
log.info(
f'Initialized AsyncEval callback. Will generate runs at interval {interval}'

# Validate the interval (how often to launch eval runs)
self.interval = validate_interval(interval,
self.training_params['save_interval'])

# Configures how often to check for new checkpoints. This is semi-arbitrary;
# really we just want to check often enough to pull relevant checkpoints
# but not so often that we're constantly checking
check_interval_value = max(self.interval.value // CHECKS_PER_INTERVAL,
1)
self.check_interval = Time(check_interval_value, self.interval.unit)

# Keep track of checkpoints that have already been evaled
# Format: {eval_timestamp: (checkpoint, run_name)}
self.checkpoints_evaled: Dict[Time, Tuple[str, str]] = {}

# Scheduling is based on the check interval, while _get_checkpoints_and_launch_runs
# will only launch runs at the interval
self.is_at_check_interval = create_interval_scheduler(
self.check_interval,
# There is a custom close to ensure that the final checkpoint
# (which is the most important) is evaled after it is written
include_end_of_training=False,
)

log.info('Initialized AsyncEval callback. Will generate runs at ' +
f'interval {interval}, checking at {self.check_interval}')

def state_dict(self) -> Dict[str, Any]:
checkpoints_evaled = []
for eval_ts, (checkpoint, run_name) in self.checkpoints_evaled.items():
eval_ts_dict = {
'value': eval_ts.value,
'unit': eval_ts.unit.value,
}
checkpoints_evaled.append((eval_ts_dict, checkpoint, run_name))

return {
'checkpoints_evaled': checkpoints_evaled,
}

def load_state_dict(self, state_dict: Dict[str, Any]):
previous_checkpoints_evaled = state_dict.get('checkpoints_evaled', [])
if previous_checkpoints_evaled:
for (eval_ts, checkpoint, run_name) in previous_checkpoints_evaled:
eval_ts = Time(eval_ts['value'], TimeUnit(eval_ts['unit']))
self.checkpoints_evaled[eval_ts] = (checkpoint, run_name)

log.info(
f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}'
)

@staticmethod
def _get_ready_sharded_checkpoints(
checkpointer_checkpoints: Dict[str, Timestamp],
remote_files: List[str],
) -> Dict[str, Timestamp]:
"""Identify checkpoints ready to be evaled based on remote files.
This has special logic for sharded checkpoints to consider checkpoints composed
of multiple shards (one per gpu) and metadata
Args:
checkpointer_checkpoints: All checkpoints from the checkpointer state
remote_files: List of remote files in the save folder
Returns:
Dict of checkpoints that are complete and ready to be evaled
"""
# Count the number of shards for each checkpoint group
remote_file_group_counts = Counter()
for f in remote_files:
checkpoint_ts_path = Path(f).parts[-2]
remote_file_group_counts[checkpoint_ts_path] += 1

# Check if all shards are present for each checkpoint group
checkpoints_to_eval = {}
for checkpoint, checkpoint_ts in checkpointer_checkpoints.items():
# eg {save_folder}/ep0-ba1/file.blah.
checkpoint_ts_path = Path(checkpoint).parts[-2]

# expecting one shard per gpu + 1 for metadata
expected_shard_count = dist.get_world_size() + 1
if remote_file_group_counts[
checkpoint_ts_path] != expected_shard_count:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded (missing shards '
+
f'{remote_file_group_counts[checkpoint_ts_path]}/{expected_shard_count}), skipping'
)
continue

checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts

return checkpoints_to_eval

@staticmethod
def _get_ready_single_checkpoints(
checkpointer_checkpoints: Dict[str, Timestamp],
remote_checkpoints: List[str],
) -> Dict[str, Timestamp]:
"""Identify checkpoints ready to be evaled based on remote checkpoints.
This is much simpler than the sharded case, because there is only one file
Args:
checkpointer_checkpoints: All checkpoints from the checkpointer state
remote_checkpoints: List of remote checkpoints in the save folder
Returns:
Dict of checkpoints that are complete and ready to be evaled
"""
unique_remote_checkpoints = set(remote_checkpoints)

checkpoints_to_eval = {}
for checkpoint, checkpoint_ts in checkpointer_checkpoints.items():
# This assumes checkpoint_ts_path is unique per checkpoint,
# eg the default {save_folder}/ep0-ba1-rank0.pt
checkpoint_ts_path = Path(checkpoint).parts[-1]

if checkpoint not in unique_remote_checkpoints:
log.debug(
f'Checkpoint {checkpoint} not fully uploaded, skipping')
continue

checkpoints_to_eval[checkpoint_ts_path] = checkpoint_ts
return checkpoints_to_eval

def _get_checkpoints_and_launch_runs(self, state: State):
"""Get the latest checkpoint from the training run.
Args:
state: The current state of the training run
Returns:
Returns checkpoints that have not been evaled
"""
checkpointer = None
for callback in state.callbacks:
if isinstance(callback, CheckpointSaver):
if checkpointer is None:
checkpointer = callback
else:
log.warning(
'Multiple checkpoint savers found. Using the first one')

if not checkpointer:
warnings.warn('No checkpoint saver callback found. Skipping eval')
return

if not checkpointer.all_saved_checkpoints_to_timestamp:
log.debug(
'No saved checkpoints found on the checkpointer. Skipping eval')
return

log.debug(
f'Found {len(checkpointer.all_saved_checkpoints_to_timestamp)} ' +
f'checkpoints: {checkpointer.all_saved_checkpoints_to_timestamp}')

remote_checkpoints = list_remote_objects(self.checkpoint_save_folder)

if not remote_checkpoints:
log.debug('No saved checkpoints found yet on remote. Skipping eval')
return

if state.fsdp_elastic_sharded_enabled:
checkpoints_to_eval = self._get_ready_sharded_checkpoints(
checkpointer.all_saved_checkpoints_to_timestamp,
remote_checkpoints)
else:
checkpoints_to_eval = self._get_ready_single_checkpoints(
checkpointer.all_saved_checkpoints_to_timestamp,
remote_checkpoints)

for checkpoint_interval_path, checkpoint_timestamp in checkpoints_to_eval.items(
):
checkpoint_ts = checkpoint_timestamp.get(self.interval.unit)
if checkpoint_ts.value % self.interval.value != 0:
log.debug(
f'Checkpoint {checkpoint_interval_path} ({checkpoint_ts}) is '
+ f'not at an eval interval ({self.interval}), skipping')
continue
if checkpoint_ts in self.checkpoints_evaled:
continue # Skip checkpoints that have already been evaled

full_checkpoint_path = f'{self.checkpoint_save_folder}/{checkpoint_interval_path}'
eval_run = self.launch_run(full_checkpoint_path, checkpoint_ts)
self.checkpoints_evaled[checkpoint_ts] = (
full_checkpoint_path,
eval_run.name,
)

def run_event(self, event: Event, state: State, logger: Logger) -> None:
del logger

should_launch_run = all([
state.get_elapsed_duration() is not None,
self.check_interval(state, event),
# could also skip check intervals before the first async eval interval,
# but this may make the scheduler more complicated
self.is_at_check_interval(state, event),
dist.get_global_rank() == 0,
])

if should_launch_run:
current_interval = state.timestamp.get(self.interval.unit)
checkpoint = get_latest_checkpoint(event, state)
if not checkpoint:
return # warnings logged in get_latest_checkpoint

# TODO: ensure the checkpoint is fully written before launching the eval run
full_checkpoint = f'{self.checkpoint_save_folder}/{checkpoint}'
if full_checkpoint == self.last_checkpoint:
# Do not eval a checkpoint that has already been evaluated.
log.info(
'Skipping async eval because the checkpoint has not changed'
)
return

self.launch_run(full_checkpoint, current_interval)
self.last_checkpoint = full_checkpoint
self._get_checkpoints_and_launch_runs(state)

def close(self, state: State, logger: Logger) -> None:
del logger

if dist.get_global_rank() != 0:
return

save_latest_filename = self.training_params.get('save_latest_filename',
None)
# Eval any remaining checkpoints
self._get_checkpoints_and_launch_runs(state)

# Eval the latest checkpoint
latest_timestamp = state.timestamp.get(self.interval.unit)
if latest_timestamp not in self.checkpoints_evaled:
save_latest_filename = self.training_params.get(
'save_latest_filename', None)

if not save_latest_filename:
rank = dist.get_global_rank()
save_latest_filename = f'latest-rank{rank}.pt'

checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'

if not save_latest_filename:
rank = dist.get_global_rank()
save_latest_filename = f'latest-rank{rank}.pt'
eval_run = self.launch_run(checkpoint, latest_timestamp)
self.checkpoints_evaled[latest_timestamp] = (checkpoint,
eval_run.name)

checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'
self.launch_run(checkpoint, state.timestamp.get(self.interval.unit))
log.info(
f'AsyncEval callback finished. Launched {len(self.checkpoints_evaled)} eval runs:'
)
for checkpoint_ts, (checkpoint,
run_name) in self.checkpoints_evaled.items():
log.info(f' {checkpoint_ts}: {checkpoint}, {run_name}')

def _get_current_run(self) -> Run:
if os.environ.get(MOSAICML_PLATFORM_ENV_VAR,
Expand All @@ -322,6 +487,15 @@ def _get_current_run(self) -> Run:
return get_run(run_name, include_details=True)

def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
"""Launch a new eval run.
Args:
checkpoint: The checkpoint to eval
current_interval: The interval of the checkpoint
Returns:
The launched run (mcli.Run type)
"""
log.info(f'Launching eval run for {checkpoint} at {current_interval}')

cfg = self.current_run.submitted_config
Expand Down
Loading

0 comments on commit 6e3842b

Please sign in to comment.