diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index adf7cfc54d..cff14fa19a 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -222,6 +222,30 @@ def __init__( log.info('Initialized AsyncEval callback. Will generate runs at ' + f'interval {interval}, checking at {self.check_interval}') + def state_dict(self) -> Dict[str, Any]: + checkpoints_evaled = [] + for i, (c, rn) in self.checkpoints_evaled.items(): + interval_dict = { + 'value': i.value, + 'unit': i.unit.value, + } + checkpoints_evaled.append((interval_dict, c, rn)) + + return { + 'checkpoints_evaled': checkpoints_evaled, + } + + def load_state_dict(self, state_dict: Dict[str, Any]): + previous_checkpoints_evaled = state_dict.get('checkpoints_evaled', []) + if previous_checkpoints_evaled: + for (i, c, rn) in previous_checkpoints_evaled: + interval = Time(i['value'], TimeUnit(i['unit'])) + self.checkpoints_evaled[interval] = (c, rn) + + log.info( + f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}' + ) + @staticmethod def _get_ready_sharded_checkpoints( checkpointer_checkpoints: Dict[str, Timestamp], diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index dcc7222471..2eb8ff54b9 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -284,6 +284,41 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, assert parameters['run_name'] == 'eval-1ba-foo_bar' # original run +@patch('llmfoundry.callbacks.async_eval_callback.get_run', + return_value=FAKE_RUN) +def test_async_eval_state(mock_create_run: MagicMock): + callback = AsyncEval(BASIC_PARAMS, interval='2ba') + + assert not callback.checkpoints_evaled + + state_dict = callback.state_dict() + assert state_dict['checkpoints_evaled'] == [] + + callback.load_state_dict(state_dict) + assert not callback.checkpoints_evaled + + callback.checkpoints_evaled = { + Time(1, TimeUnit.BATCH): ('checkpoint/path', 'run-name'), + } + state_dict = callback.state_dict() + assert state_dict['checkpoints_evaled'] == [ + ( + { + 'value': 1, + 'unit': 'ba', + }, + 'checkpoint/path', + 'run-name', + ), + ] + + callback.checkpoints_evaled = {} + callback.load_state_dict(state_dict) + assert callback.checkpoints_evaled == { + Time(1, TimeUnit.BATCH): ('checkpoint/path', 'run-name'), + } + + INTEGRATION_GIT_LLMFOUNDRY = { 'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry',