Skip to content

Commit

Permalink
serialize time
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Jan 6, 2024
1 parent 96e3047 commit 85bcfcc
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
24 changes: 24 additions & 0 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,30 @@ def __init__(
log.info('Initialized AsyncEval callback. Will generate runs at ' +
f'interval {interval}, checking at {self.check_interval}')

def state_dict(self) -> Dict[str, Any]:
checkpoints_evaled = []
for i, (c, rn) in self.checkpoints_evaled.items():
interval_dict = {
'value': i.value,
'unit': i.unit.value,
}
checkpoints_evaled.append((interval_dict, c, rn))

return {
'checkpoints_evaled': checkpoints_evaled,
}

def load_state_dict(self, state_dict: Dict[str, Any]):
previous_checkpoints_evaled = state_dict.get('checkpoints_evaled', [])
if previous_checkpoints_evaled:
for (i, c, rn) in previous_checkpoints_evaled:
interval = Time(i['value'], TimeUnit(i['unit']))
self.checkpoints_evaled[interval] = (c, rn)

log.info(
f'Loaded previous checkpoints evaled: {self.checkpoints_evaled}'
)

@staticmethod
def _get_ready_sharded_checkpoints(
checkpointer_checkpoints: Dict[str, Timestamp],
Expand Down
35 changes: 35 additions & 0 deletions tests/callbacks/test_async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,41 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock,
assert parameters['run_name'] == 'eval-1ba-foo_bar' # original run


@patch('llmfoundry.callbacks.async_eval_callback.get_run',
return_value=FAKE_RUN)
def test_async_eval_state(mock_create_run: MagicMock):
callback = AsyncEval(BASIC_PARAMS, interval='2ba')

assert not callback.checkpoints_evaled

state_dict = callback.state_dict()
assert state_dict['checkpoints_evaled'] == []

callback.load_state_dict(state_dict)
assert not callback.checkpoints_evaled

callback.checkpoints_evaled = {
Time(1, TimeUnit.BATCH): ('checkpoint/path', 'run-name'),
}
state_dict = callback.state_dict()
assert state_dict['checkpoints_evaled'] == [
(
{
'value': 1,
'unit': 'ba',
},
'checkpoint/path',
'run-name',
),
]

callback.checkpoints_evaled = {}
callback.load_state_dict(state_dict)
assert callback.checkpoints_evaled == {
Time(1, TimeUnit.BATCH): ('checkpoint/path', 'run-name'),
}


INTEGRATION_GIT_LLMFOUNDRY = {
'integration_type': 'git_repo',
'git_repo': 'mosaicml/llm-foundry',
Expand Down

0 comments on commit 85bcfcc

Please sign in to comment.