Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async eval callback #702

Merged
merged 66 commits into from
Dec 19, 2023
Merged
Changes from 1 commit
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
e70a2f4
Async eval callback
aspfohl Oct 27, 2023
acd8b2e
add very basic tests
aspfohl Nov 8, 2023
aef814b
more tests
aspfohl Nov 8, 2023
007ae90
bump mcli
aspfohl Nov 8, 2023
d04aa40
Merge branch 'main' into anna/asynceval
aspfohl Nov 9, 2023
6cd020f
woop, missing import
aspfohl Nov 9, 2023
9fbe7a1
instance not specified error
aspfohl Nov 9, 2023
547ec21
fixes
aspfohl Nov 9, 2023
9819456
fix typing
aspfohl Nov 9, 2023
ba871d7
small testing fixes
aspfohl Nov 9, 2023
21e9880
launch new run only on main process
aspfohl Nov 9, 2023
47a8255
logger name
aspfohl Nov 9, 2023
08c24be
items
aspfohl Nov 9, 2023
bf415f0
format
aspfohl Nov 10, 2023
ecacdac
Merge branch 'main' into anna/asynceval
aspfohl Nov 10, 2023
5616ae4
Update llmfoundry/callbacks/async_eval_callback.py
aspfohl Nov 10, 2023
bc1647a
Update llmfoundry/callbacks/async_eval_callback.py
aspfohl Nov 10, 2023
3358837
feedback
aspfohl Nov 10, 2023
28e47df
Apply suggestions from code review
aspfohl Nov 13, 2023
78cc0b8
small updates
aspfohl Nov 13, 2023
7baa53f
Merge branch 'main' into anna/asynceval
aspfohl Nov 13, 2023
b58ccf9
use parameters from train.py to capture overrides and mounted paramet…
aspfohl Nov 13, 2023
d85ee5e
config_overrides
aspfohl Nov 14, 2023
0e96fea
Merge branch 'main' into anna/asynceval
aspfohl Nov 14, 2023
194774d
updates
aspfohl Nov 15, 2023
ba1280a
Merge branch 'main' into anna/asynceval
aspfohl Nov 15, 2023
08857d5
fix test
aspfohl Nov 15, 2023
6ce8b77
small fixes
aspfohl Nov 15, 2023
de155f7
add logging
aspfohl Nov 15, 2023
e5f9e9e
remove last launch check
aspfohl Nov 15, 2023
3f518f9
better logging
aspfohl Nov 15, 2023
ea742ef
fix parameters
aspfohl Nov 16, 2023
e3623f3
fix double unit in the name
aspfohl Nov 16, 2023
2deef3f
sadz
aspfohl Nov 16, 2023
91bdf43
Merge branch 'main' into anna/asynceval
aspfohl Nov 16, 2023
e8f4661
Merge branch 'main' into anna/asynceval
dakinggg Nov 30, 2023
f9e2dc7
Merge branch 'main' into anna/asynceval
aspfohl Dec 2, 2023
add7fbb
fies
aspfohl Dec 2, 2023
238086f
git integration path validation and update
aspfohl Dec 4, 2023
53a9943
detect forks, better error/comment
aspfohl Dec 4, 2023
1f35a7b
version import
aspfohl Dec 4, 2023
99f48cb
merge with main
aspfohl Dec 4, 2023
1184531
last checkpoint
aspfohl Dec 5, 2023
7af7383
post_close -> close
aspfohl Dec 6, 2023
9337af0
add todos, fix path bug
aspfohl Dec 6, 2023
87ffd86
add missing args
aspfohl Dec 6, 2023
e940f1c
remove eval_loader in callback too
aspfohl Dec 6, 2023
bb040d1
remove fit end event (already doing on close)
aspfohl Dec 6, 2023
14f386f
misc fixes
aspfohl Dec 6, 2023
ac37d09
fix test
aspfohl Dec 6, 2023
535d5fb
Merge branch 'main' into anna/asynceval
aspfohl Dec 7, 2023
9e11cf7
add back eval interval
aspfohl Dec 8, 2023
aa652f3
build_loggers and add tests
aspfohl Dec 8, 2023
818f4ac
Merge branch 'main' into anna/asynceval
aspfohl Dec 8, 2023
0e4f085
updates
aspfohl Dec 11, 2023
bf06c19
Merge branch 'main' into anna/asynceval
aspfohl Dec 11, 2023
dc25b2a
typing
aspfohl Dec 11, 2023
6865b96
changes
aspfohl Dec 11, 2023
9754b03
Merge branch 'main' into anna/asynceval
aspfohl Dec 11, 2023
1ac70cc
typing?
aspfohl Dec 11, 2023
1a485d4
Merge branch 'main' into anna/asynceval
aspfohl Dec 13, 2023
cd2a31d
metadata in eval.py
aspfohl Dec 18, 2023
f6393a9
Merge branch 'main' into anna/asynceval
aspfohl Dec 18, 2023
04865db
actually, just log metadata on every model eval
aspfohl Dec 19, 2023
557aca8
Merge branch 'main' into anna/asynceval
aspfohl Dec 19, 2023
2fd2317
Merge branch 'main' into anna/asynceval
aspfohl Dec 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
misc fixes
  • Loading branch information
aspfohl committed Dec 6, 2023
commit 14f386f8ea22048176e837f669ecd981f2f5982f
60 changes: 44 additions & 16 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@

import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

from composer.callbacks import CheckpointSaver
@@ -94,7 +95,8 @@ def get_latest_checkpoint(event: Event, state: State) -> Optional[str]:
log.warning('No saved checkpoints found on the checkpointer')
return None

return checkpointer.saved_checkpoints[-1]
latest = checkpointer.saved_checkpoints[-1]
return str(Path(latest).parts[-1])


def get_eval_parameters(
@@ -153,6 +155,35 @@ def get_eval_parameters(
return subset_keys


def _to_time(value: Union[str, int, Time]) -> Time:
    """Coerce a timestring, a bare integer, or a ``Time`` into a ``Time``.

    Strings are parsed with ``Time.from_timestring``; bare integers are
    interpreted as a number of epochs; ``Time`` instances are returned as-is.
    """
    if isinstance(value, str):
        return Time.from_timestring(value)
    if isinstance(value, int):
        return Time(value, TimeUnit.EPOCH)
    return value


def validate_interval(interval: Union[str, int, Time],
                      save_interval: Union[str, int, Time]) -> Time:
    """Normalize the async eval interval and check it against the save interval.

    Both arguments may be a timestring, an integer (taken as epochs), or a
    ``Time``.

    Args:
        interval: How often async eval should be launched.
        save_interval: How often checkpoints are saved.

    Returns:
        ``interval`` converted to a ``Time``.

    Raises:
        ValueError: If the two intervals use different units, if the save
            interval is zero, if the eval interval is shorter than the save
            interval, or if the eval interval is not a whole multiple of the
            save interval.
    """
    # Convert the save interval first, matching the original evaluation order.
    new_save_interval = _to_time(save_interval)
    result = _to_time(interval)

    if new_save_interval.unit != result.unit:
        raise ValueError(
            'Save interval and async eval interval must be in the same unit')
    if new_save_interval.value == 0:
        # Guard the modulo below: without this a zero save interval would
        # surface as a ZeroDivisionError instead of a validation error.
        raise ValueError('Save interval must be non-zero for async eval')
    if result < new_save_interval:
        raise ValueError(
            'Async eval interval must be equal or greater (less frequent) than save interval'
        )
    if result.value % new_save_interval.value != 0:
        raise ValueError(
            'Async eval interval must be a multiple of save interval')
    return result


class AsyncEval(Callback):
aspfohl marked this conversation as resolved.
Show resolved Hide resolved
"""Run the eval loop asynchronously as part of a MosaicML platform run.

@@ -176,15 +207,14 @@ def __init__(
compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None,
):
aspfohl marked this conversation as resolved.
Show resolved Hide resolved

self.training_config = training_config

if isinstance(interval, str):
self.interval = Time.from_timestring(interval)
elif isinstance(interval, int):
self.interval = Time(interval, TimeUnit.EPOCH)
else:
self.interval = interval
for required in ('save_interval', 'save_folder'):
if required not in training_config:
raise ValueError(f'{required} required for async eval')

self.checkpoint_save_folder = training_config['save_folder']
self.training_config = training_config
self.interval = validate_interval(interval,
self.training_config['save_interval'])
self.check_interval = create_interval_scheduler(
interval,
# There is a custom close to ensure that the final checkpoint
@@ -220,34 +250,32 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
if not checkpoint:
return # warnings logged in get_latest_checkpoint

if checkpoint == self.last_checkpoint:
full_checkpoint = f'{self.checkpoint_save_folder}/{checkpoint}'
if full_checkpoint == self.last_checkpoint:
# Do not eval a checkpoint that has already been evaluated.
log.info(
'Skipping async eval because the checkpoint has not changed'
)
return

self.launch_run(checkpoint, current_interval)
self.last_checkpoint = checkpoint
self.launch_run(full_checkpoint, current_interval)
self.last_checkpoint = full_checkpoint

def close(self, state: State, logger: Logger) -> None:
del state
del logger

if dist.get_global_rank() != 0:
return
self.training_config

# TODO: enforce this exists before
save_folder = self.training_config['save_folder']
save_latest_filename = self.training_config.get('save_latest_filename',
None)

if not save_latest_filename:
rank = dist.get_global_rank()
save_latest_filename = f'latest-rank{rank}.pt'

checkpoint = f'{save_folder}/{save_latest_filename}'
checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'
self.launch_run(checkpoint, 'final')

def _get_current_run(self) -> Run:
19 changes: 18 additions & 1 deletion tests/callbacks/test_async_eval_callback.py
Original file line number Diff line number Diff line change
@@ -6,10 +6,12 @@
from unittest.mock import MagicMock, patch

import pytest
from composer.core import Time, TimeUnit

from llmfoundry.callbacks.async_eval_callback import (AsyncEval,
get_eval_parameters,
get_run_name)
get_run_name,
validate_interval)
from mcli import Run, RunConfig, RunStatus

# here
@@ -164,6 +166,21 @@ def test_get_eval_parameters():
}


def test_validate_interval():
    """Exercise ``validate_interval`` with rejected and accepted interval pairs."""
    # (eval interval, save interval) pairs that must raise ValueError.
    rejected_pairs = (
        ('1ba', '1ep'),  # different units
        ('1ba', '2ba'),  # checkpointing happens less often
        ('3ba', '2ba'),  # not a multiple
    )
    for eval_interval, save_interval in rejected_pairs:
        with pytest.raises(ValueError):
            validate_interval(eval_interval, save_interval)

    assert validate_interval('2ba', '1ba') == Time(2, TimeUnit.BATCH)

    # Integers, Time objects and timestrings should all normalize to the
    # same two-epoch Time.
    two_epochs = Time(2, TimeUnit.EPOCH)
    accepted_pairs = (
        (2, 2),
        (two_epochs, two_epochs),
        ('2ep', two_epochs),
    )
    for eval_interval, save_interval in accepted_pairs:
        assert validate_interval(eval_interval, save_interval) == two_epochs


FAKE_RUN = Run(
run_uid='123',
name=RUN_NAME,
Loading