
Commit

use parameters from train.py to capture overrides and mounted parameters file

aspfohl committed Nov 13, 2023
1 parent 7baa53f commit b58ccf9
Showing 5 changed files with 52 additions and 43 deletions.
17 changes: 8 additions & 9 deletions llmfoundry/callbacks/async_eval_callback.py
@@ -1,8 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""
-Run the eval loop asynchronously as part of a MosaicML platform run.
+"""Run the eval loop asynchronously as part of a MosaicML platform run.
 
 This callback is currently experimental. The API may change in the future.
 """
@@ -93,6 +92,7 @@ class AsyncEval(Callback):
     This callback is currently experimental. The API may change in the future.
 
     Args:
+        training_config: Dict[str, Any]: The config from the training run
         interval: Union[str, int, Time]: The interval describing how often eval runs should be
             launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
             Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
@@ -104,20 +104,19 @@ class AsyncEval(Callback):
 
     def __init__(
         self,
+        training_config: Dict[str, Any],
         interval: Union[str, int, Time],
         compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None,
     ):
+        self.training_config = training_config
         self.check_interval = create_interval_scheduler(interval)
         self.compute = compute
         self.count = 0
         self.last_launch: Optional[Time] = None
 
         # Run these during init to fail fast in any of the error cases
         self.current_run = self._get_current_run()
-        self.get_eval_parameters(
-            self.current_run.submitted_config.parameters or {},
-            self.current_run.name,
-        )
+        self.get_eval_parameters(training_config, self.current_run.name)
         log.info(
             f'Initialized AsyncEval callback. Will generate runs at interval {interval}'
         )
@@ -126,8 +125,8 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
         del logger
         if all([
                 state.get_elapsed_duration() is not None,
-                self.check_interval(state, event), self.last_launch
-                != state.timestamp.batch,
+                self.check_interval(state, event),
+                self.last_launch != state.timestamp.batch,
                 dist.get_global_rank() == 0
         ]):
             self.launch_run()
@@ -198,7 +197,7 @@ def launch_run(self) -> Run:
             'gpus': 8,
             'cluster': self.current_run.cluster,
         }
-        params = self.get_eval_parameters(cfg.parameters or {},
+        params = self.get_eval_parameters(self.training_config,
                                           self.current_run.name)
 
         # TODO: This just runs an eval run, but we also want to attach the
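Taken together, the callback now receives the resolved training configuration at construction time instead of reading it back from the platform run's submitted parameters, which would miss command-line overrides and mounted parameters files. A minimal sketch of the new constructor usage, assuming the run executes on the MosaicML platform with RUN_NAME set; the config values below are illustrative placeholders, not taken from the repository:

    from llmfoundry.callbacks.async_eval_callback import AsyncEval

    # Illustrative config dict; in practice train.py passes its fully resolved config.
    training_config = {
        'device_eval_batch_size': 2,
        'icl_tasks': 'eval/yamls/tasks.yaml',   # placeholder ICL task spec
        'max_seq_len': 2048,
        'model': {'name': 'mpt_causal_lm'},     # placeholder model section
        'tokenizer': {'name': 'EleutherAI/gpt-neox-20b'},
        'save_folder': 's3://my-bucket/checkpoints',  # placeholder checkpoint location
    }

    callback = AsyncEval(
        training_config,                         # new first argument
        interval='1000ba',                       # launch an eval run every 1000 batches
        compute={'cluster': 'r1z1', 'gpus': 8},  # placeholder compute config
    )
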
8 changes: 6 additions & 2 deletions llmfoundry/utils/builders.py
@@ -73,7 +73,11 @@ def build_icl_data_and_gauntlet(
     return icl_evaluators, logger_keys, eval_gauntlet_cb
 
 
-def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
+def build_callback(
+    name: str,
+    kwargs: Dict[str, Any],
+    config: Dict[str, Any],
+) -> Callback:
     if name == 'lr_monitor':
         return LRMonitor()
     elif name == 'memory_monitor':
@@ -119,7 +123,7 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
     elif name == 'hf_checkpointer':
         return HuggingFaceCheckpointer(**kwargs)
     elif name == 'async_eval':
-        return AsyncEval(**kwargs)
+        return AsyncEval(**kwargs, training_config=config)
     else:
         raise ValueError(f'Not sure how to build callback: {name}')
 
2 changes: 1 addition & 1 deletion scripts/train/train.py
@@ -504,7 +504,7 @@ def main(cfg: DictConfig) -> Trainer:
 
     # Callbacks
     callbacks: List[Callback] = [
-        build_callback(str(name), callback_cfg)
+        build_callback(str(name), callback_cfg, cfg)
         for name, callback_cfg in callback_configs.items()
     ] if callback_configs else []
 
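Combined with the builders.py change above, the entry point now threads the full training config into every callback builder, and build_callback forwards it to AsyncEval as training_config (the other callback branches ignore it). A simplified sketch of the wiring; the config loading and surrounding scaffolding here are invented for illustration:

    from omegaconf import OmegaConf as om

    from llmfoundry.utils.builders import build_callback

    # Hypothetical config load; train.py actually builds cfg from its YAML plus CLI overrides.
    cfg = om.load('train_config.yaml')
    callback_configs = cfg.get('callbacks') or {}

    # Mirrors the updated list comprehension in main(): the full config is now
    # passed alongside each callback's own kwargs.
    callbacks = [
        build_callback(str(name), callback_cfg, cfg)
        for name, callback_cfg in callback_configs.items()
    ] if callback_configs else []
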
45 changes: 22 additions & 23 deletions tests/callbacks/test_async_eval_callback.py
@@ -10,6 +10,23 @@
 from mcli import Run, RunConfig, RunStatus
 
 RUN_NAME = 'foo_bar-1234'
+BASIC_PARAMS = {
+    'device_eval_batch_size': 2,
+    'icl_tasks': 'icl_task_example',
+    'max_seq_len': 3,
+    'model': {
+        'name': 'model_example',
+        'config_overrides': {
+            'attn_config': {
+                'foo': 'bar'
+            }
+        }
+    },
+    'tokenizer': {
+        'tokenizer_example': 'tokenizer_example',
+    },
+    'save_folder': 'save_folder_example',
+}
 
 
 def test_get_run_name():
@@ -37,7 +54,7 @@ def test_fails_when_not_on_platform():
             match=
             'AsyncEval callback is only supported when running on the MosaicML platform'
     ):
-        AsyncEval(interval='2ba')
+        AsyncEval(BASIC_PARAMS, interval='2ba')
 
 
 def test_fails_when_no_run_name():
@@ -50,26 +67,7 @@ def test_fails_when_no_run_name():
             match=
             'RUN_NAME environment variable must be set to use the AsyncEval callback'
     ):
-        AsyncEval(interval='2ba')
-
-
-BASIC_PARAMS = {
-    'device_eval_batch_size': 2,
-    'icl_tasks': 'icl_task_example',
-    'max_seq_len': 3,
-    'model': {
-        'name': 'model_example',
-        'config_overrides': {
-            'attn_config': {
-                'foo': 'bar'
-            }
-        }
-    },
-    'tokenizer': {
-        'tokenizer_example': 'tokenizer_example',
-    },
-    'save_folder': 'save_folder_example',
-}
+        AsyncEval(BASIC_PARAMS, interval='2ba')
 
 
 def test_get_eval_parameters():
@@ -185,7 +183,7 @@ def test_get_eval_parameters():
         name=RUN_NAME,
         image='fake-image',
         command='echo hi',
-        parameters=BASIC_PARAMS,
+        parameters={},
     ),
 )
 
@@ -196,7 +194,8 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock,
                     return_value=FAKE_RUN)
 def test_async_eval_callback_minimal(mock_create_run: MagicMock,
                                      mock_get_run: MagicMock):
-    callback = AsyncEval(interval='2ba',
+    callback = AsyncEval(BASIC_PARAMS,
+                         interval='2ba',
                          compute={
                              'cluster': 'c2z3',
                              'nodes': 2,
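Because the training config now arrives through the constructor, the mocked run no longer needs to carry submitted parameters (hence parameters={} above), and BASIC_PARAMS moves to the top of the module so every test can pass it positionally. The updated minimal test reduces to this pattern, with the compute values copied from the diff:

    # Sketch of the updated test construction; BASIC_PARAMS doubles as the
    # training config that AsyncEval previously read from the platform run.
    callback = AsyncEval(
        BASIC_PARAMS,
        interval='2ba',
        compute={
            'cluster': 'c2z3',
            'nodes': 2,
        },
    )
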
23 changes: 15 additions & 8 deletions tests/test_builders.py
@@ -37,7 +37,7 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict):
 
 def test_build_callback_fails():
     with pytest.raises(ValueError):
-        build_callback('nonexistent_callback', {})
+        build_callback('nonexistent_callback', {}, {})
 
 
 @pytest.mark.parametrize(
@@ -53,12 +53,15 @@ def test_build_generate_callback(
                            autospec=True) as mock_generate:
         mock_generate.return_value = None
         build_callback(
-            'generate_callback', {
+            'generate_callback',
+            {
                 'prompts': ['hello'],
                 interval_key: interval_value,
                 'foo': 'bar',
                 'something': 'else',
-            })
+            },
+            {},
+        )
 
     assert mock_generate.call_count == 1
     _, _, kwargs = mock_generate.mock_calls[0]
@@ -73,8 +76,12 @@ def test_build_generate_callback_unspecified_interval():
     with mock.patch.object(Generate, '__init__',
                            autospec=True) as mock_generate:
         mock_generate.return_value = None
-        build_callback('generate_callback', {
-            'prompts': ['hello'],
-            'foo': 'bar',
-            'something': 'else',
-        })
+        build_callback(
+            'generate_callback',
+            {
+                'prompts': ['hello'],
+                'foo': 'bar',
+                'something': 'else',
+            },
+            {},
+        )
