From b58ccf9c33d6d1b3227762fb7a8c1552a350a56d Mon Sep 17 00:00:00 2001
From: Anna Pfohl
Date: Mon, 13 Nov 2023 10:41:47 -0800
Subject: [PATCH] use parameters from train.py to capture overrides and
 mounted parameters file

---
 llmfoundry/callbacks/async_eval_callback.py | 17 ++++----
 llmfoundry/utils/builders.py                |  8 +++-
 scripts/train/train.py                      |  2 +-
 tests/callbacks/test_async_eval_callback.py | 45 ++++++++++-----------
 tests/test_builders.py                      | 23 +++++++----
 5 files changed, 52 insertions(+), 43 deletions(-)

diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py
index 14c5a6aade..c8581049d9 100644
--- a/llmfoundry/callbacks/async_eval_callback.py
+++ b/llmfoundry/callbacks/async_eval_callback.py
@@ -1,8 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""
-Run the eval loop asynchronously as part of a MosaicML platform run.
+"""Run the eval loop asynchronously as part of a MosaicML platform run.
 
 This callback is currently experimental. The API may change in the future.
 """
@@ -93,6 +92,7 @@ class AsyncEval(Callback):
     This callback is currently experimental. The API may change in the future.
 
     Args:
+        training_config: Dict[str, Any]: The config from the training run
         interval: Union[str, int, Time]: The interval describing how often eval runs should be
             launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
             Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
@@ -104,9 +104,11 @@ class AsyncEval(Callback):
 
     def __init__(
         self,
+        training_config: Dict[str, Any],
         interval: Union[str, int, Time],
         compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None,
     ):
+        self.training_config = training_config
         self.check_interval = create_interval_scheduler(interval)
         self.compute = compute
         self.count = 0
@@ -114,10 +116,7 @@ def __init__(
 
         # Run these during init to fail fast in any of the error cases
         self.current_run = self._get_current_run()
-        self.get_eval_parameters(
-            self.current_run.submitted_config.parameters or {},
-            self.current_run.name,
-        )
+        self.get_eval_parameters(training_config, self.current_run.name)
         log.info(
             f'Initialized AsyncEval callback. Will generate runs at interval {interval}'
         )
@@ -126,8 +125,8 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
         del logger
         if all([
             state.get_elapsed_duration() is not None,
-            self.check_interval(state, event), self.last_launch
-            != state.timestamp.batch,
+            self.check_interval(state, event),
+            self.last_launch != state.timestamp.batch,
             dist.get_global_rank() == 0
         ]):
             self.launch_run()
@@ -198,7 +197,7 @@ def launch_run(self) -> Run:
             'gpus': 8,
             'cluster': self.current_run.cluster,
         }
-        params = self.get_eval_parameters(cfg.parameters or {},
+        params = self.get_eval_parameters(self.training_config,
                                           self.current_run.name)
 
         # TODO: This just runs an eval run, but we also want to attach the
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 9c3e94c29b..fd446df18c 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -73,7 +73,11 @@ def build_icl_data_and_gauntlet(
     return icl_evaluators, logger_keys, eval_gauntlet_cb
 
 
-def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
+def build_callback(
+    name: str,
+    kwargs: Dict[str, Any],
+    config: Dict[str, Any],
+) -> Callback:
     if name == 'lr_monitor':
         return LRMonitor()
     elif name == 'memory_monitor':
@@ -119,7 +123,7 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
     elif name == 'hf_checkpointer':
         return HuggingFaceCheckpointer(**kwargs)
     elif name == 'async_eval':
-        return AsyncEval(**kwargs)
+        return AsyncEval(**kwargs, training_config=config)
     else:
         raise ValueError(f'Not sure how to build callback: {name}')
 
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 925470c4e4..64b6d7ba1c 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -504,7 +504,7 @@ def main(cfg: DictConfig) -> Trainer:
 
     # Callbacks
     callbacks: List[Callback] = [
-        build_callback(str(name), callback_cfg)
+        build_callback(str(name), callback_cfg, cfg)
         for name, callback_cfg in callback_configs.items()
     ] if callback_configs else []
 
diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py
index b595405b62..acdb5f5375 100644
--- a/tests/callbacks/test_async_eval_callback.py
+++ b/tests/callbacks/test_async_eval_callback.py
@@ -10,6 +10,23 @@ from mcli import Run, RunConfig, RunStatus
 
 RUN_NAME = 'foo_bar-1234'
+BASIC_PARAMS = {
+    'device_eval_batch_size': 2,
+    'icl_tasks': 'icl_task_example',
+    'max_seq_len': 3,
+    'model': {
+        'name': 'model_example',
+        'config_overrides': {
+            'attn_config': {
+                'foo': 'bar'
+            }
+        }
+    },
+    'tokenizer': {
+        'tokenizer_example': 'tokenizer_example',
+    },
+    'save_folder': 'save_folder_example',
+}
 
 
 def test_get_run_name():
@@ -37,7 +54,7 @@ def test_fails_when_not_on_platform():
         match=
         'AsyncEval callback is only supported when running on the MosaicML platform'
     ):
-        AsyncEval(interval='2ba')
+        AsyncEval(BASIC_PARAMS, interval='2ba')
 
 
 def test_fails_when_no_run_name():
@@ -50,26 +67,7 @@ def test_fails_when_no_run_name():
         match=
         'RUN_NAME environment variable must be set to use the AsyncEval callback'
     ):
-        AsyncEval(interval='2ba')
-
-
-BASIC_PARAMS = {
-    'device_eval_batch_size': 2,
-    'icl_tasks': 'icl_task_example',
-    'max_seq_len': 3,
-    'model': {
-        'name': 'model_example',
-        'config_overrides': {
-            'attn_config': {
-                'foo': 'bar'
-            }
-        }
-    },
-    'tokenizer': {
-        'tokenizer_example': 'tokenizer_example',
-    },
-    'save_folder': 'save_folder_example',
-}
+        AsyncEval(BASIC_PARAMS, interval='2ba')
 
 
 def test_get_eval_parameters():
@@ -185,7 +183,7 @@ def test_get_eval_parameters():
         name=RUN_NAME,
         image='fake-image',
         command='echo hi',
-        parameters=BASIC_PARAMS,
+        parameters={},
     ),
 )
@@ -196,7 +194,8 @@
        return_value=FAKE_RUN)
 def test_async_eval_callback_minimal(mock_create_run: MagicMock,
                                      mock_get_run: MagicMock):
-    callback = AsyncEval(interval='2ba',
+    callback = AsyncEval(BASIC_PARAMS,
+                         interval='2ba',
                          compute={
                              'cluster': 'c2z3',
                              'nodes': 2,
diff --git a/tests/test_builders.py b/tests/test_builders.py
index 0d24d2154f..3fd638e9f1 100644
--- a/tests/test_builders.py
+++ b/tests/test_builders.py
@@ -37,7 +37,7 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict):
 
 def test_build_callback_fails():
     with pytest.raises(ValueError):
-        build_callback('nonexistent_callback', {})
+        build_callback('nonexistent_callback', {}, {})
 
 
 @pytest.mark.parametrize(
@@ -53,12 +53,15 @@ def test_build_generate_callback(
                            autospec=True) as mock_generate:
         mock_generate.return_value = None
         build_callback(
-            'generate_callback', {
+            'generate_callback',
+            {
                 'prompts': ['hello'],
                 interval_key: interval_value,
                 'foo': 'bar',
                 'something': 'else',
-            })
+            },
+            {},
+        )
 
     assert mock_generate.call_count == 1
     _, _, kwargs = mock_generate.mock_calls[0]
@@ -73,8 +76,12 @@ def test_build_generate_callback_unspecified_interval():
     with mock.patch.object(Generate, '__init__',
                            autospec=True) as mock_generate:
         mock_generate.return_value = None
-        build_callback('generate_callback', {
-            'prompts': ['hello'],
-            'foo': 'bar',
-            'something': 'else',
-        })
+        build_callback(
+            'generate_callback',
+            {
+                'prompts': ['hello'],
+                'foo': 'bar',
+                'something': 'else',
+            },
+            {},
+        )
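
Note (not part of the patch): a minimal sketch of how the new plumbing fits
together end to end. Only the two signatures this patch introduces are assumed
real: build_callback(name, kwargs, config) and AsyncEval(training_config,
interval=..., compute=...). Every config value below is a hypothetical
placeholder, and constructing AsyncEval for real still requires the MosaicML
platform environment (RUN_NAME set, mcli reachable), which the unit tests
mock out via get_run/create_run.

    from llmfoundry.utils.builders import build_callback

    # In train.py, `cfg` is the fully resolved training config, so CLI
    # overrides and any mounted parameters file are already merged in by the
    # time callbacks are built. A hand-rolled stand-in for that config, with
    # the keys AsyncEval.get_eval_parameters validates at init:
    training_cfg = {
        'device_eval_batch_size': 2,
        'icl_tasks': 'my_icl_tasks',                   # placeholder
        'max_seq_len': 2048,
        'model': {'name': 'mpt_causal_lm'},            # placeholder
        'tokenizer': {'name': 'EleutherAI/gpt-neox-20b'},
        'save_folder': 's3://my-bucket/checkpoints',   # placeholder
    }

    # The resolved config is now threaded through explicitly, instead of the
    # callback re-reading current_run.submitted_config.parameters, which does
    # not reflect overrides or a parameters file applied after submission.
    callback = build_callback(
        'async_eval',
        {'interval': '1000ba', 'compute': {'cluster': 'c2z3', 'nodes': 2}},
        training_cfg,
    )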