From 11845313df3880af73a516d11ec8d06247ddfd48 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Tue, 5 Dec 2023 23:30:03 +0000 Subject: [PATCH] last checkpoint --- llmfoundry/callbacks/async_eval_callback.py | 26 ++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index d77a311250..a91452616f 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -190,7 +190,12 @@ def __init__( else: self.interval = interval - self.check_interval = create_interval_scheduler(interval) + self.check_interval = create_interval_scheduler( + interval, + # There is a custom post_close to ensure that the final checkpoint + # (which is the most important) is evaled after it is written + include_end_of_training=False, + ) self.compute = compute self.last_checkpoint: Optional[str] = None @@ -219,6 +224,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: checkpoint = get_latest_checkpoint(event, state) if not checkpoint: return # warnings logged in get_latest_checkpoint + if checkpoint == self.last_checkpoint: # Do not eval a checkpoint that has already been evaluated. log.info( @@ -229,6 +235,22 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self.launch_run(checkpoint, current_interval) self.last_checkpoint = checkpoint + def post_close(self) -> None: + if dist.get_global_rank() != 0: + return + self.training_config + + save_folder = self.training_config['save_folder'] + save_latest_filename = self.training_config.get('save_latest_filename', + None) + + if not save_latest_filename: + rank = dist.get_global_rank() + save_latest_filename = f'latest-rank{rank}.pt' + + checkpoint = f'{save_folder}/{save_latest_filename}' + self.launch_run(checkpoint, 'final') + def _get_current_run(self) -> Run: if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'false': @@ -246,6 +268,8 @@ def _get_current_run(self) -> Run: return get_run(run_name, include_details=True) def launch_run(self, checkpoint: str, current_interval: str) -> Run: + log.info(f'Launching eval run for {checkpoint} at {current_interval}') + cfg = self.current_run.submitted_config default_compute = { 'gpus': 8,