Skip to content

Commit

Permalink
last checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Dec 5, 2023
1 parent 99f48cb commit 1184531
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,12 @@ def __init__(
else:
self.interval = interval

self.check_interval = create_interval_scheduler(interval)
self.check_interval = create_interval_scheduler(
interval,
# There is a custom post_close to ensure that the final checkpoint
# (which is the most important) is evaled after it is written
include_end_of_training=False,
)
self.compute = compute
self.last_checkpoint: Optional[str] = None

Expand Down Expand Up @@ -219,6 +224,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
checkpoint = get_latest_checkpoint(event, state)
if not checkpoint:
return # warnings logged in get_latest_checkpoint

if checkpoint == self.last_checkpoint:
# Do not eval a checkpoint that has already been evaluated.
log.info(
Expand All @@ -229,6 +235,22 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
self.launch_run(checkpoint, current_interval)
self.last_checkpoint = checkpoint

def post_close(self) -> None:
    """Launch an eval run for the final ("latest") checkpoint.

    Only global rank 0 launches the run; every other rank returns
    immediately. The checkpoint path is built as
    ``{save_folder}/{save_latest_filename}`` from ``self.training_config``.
    When ``save_latest_filename`` is missing or empty, the name
    ``latest-rank{rank}.pt`` is used, with ``rank`` being the global rank.

    NOTE(review): ``self.training_config['save_folder']`` is a plain
    subscript, so a config without that key presumably raises here --
    confirm that callers always set it.
    """
    # Query the rank once; it is needed for both the guard and the
    # fallback filename.
    rank = dist.get_global_rank()
    if rank != 0:
        return

    save_folder = self.training_config['save_folder']
    save_latest_filename = self.training_config.get('save_latest_filename')

    if not save_latest_filename:
        save_latest_filename = f'latest-rank{rank}.pt'

    checkpoint = f'{save_folder}/{save_latest_filename}'
    # 'final' is passed as the interval label for this last eval run.
    self.launch_run(checkpoint, 'final')

def _get_current_run(self) -> Run:
if os.environ.get(MOSAICML_PLATFORM_ENV_VAR,
'false').lower() == 'false':
Expand All @@ -246,6 +268,8 @@ def _get_current_run(self) -> Run:
return get_run(run_name, include_details=True)

def launch_run(self, checkpoint: str, current_interval: str) -> Run:
log.info(f'Launching eval run for {checkpoint} at {current_interval}')

cfg = self.current_run.submitted_config
default_compute = {
'gpus': 8,
Expand Down

0 comments on commit 1184531

Please sign in to comment.