From bdcce63781596d88b02d813c1e0ff875f67bd12b Mon Sep 17 00:00:00 2001 From: Anna Date: Fri, 26 Jan 2024 17:20:03 -0800 Subject: [PATCH] Add support for eval_loader & eval_subset_num_batches in async callback (#834) * Skip evalloader in training if using async eval * add support for subset_num_batches * remove todo * eval first * rename arg * fix * small updates * om * fix test * eval run config --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/callbacks/async_eval_callback.py | 78 ++++++++++++++++----- llmfoundry/utils/builders.py | 2 +- scripts/eval/eval.py | 17 +++-- scripts/train/train.py | 46 ++++++------ tests/callbacks/test_async_eval_callback.py | 54 ++++++++++---- 5 files changed, 139 insertions(+), 58 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 4227448d87..6cd57d440c 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -19,7 +19,7 @@ from composer.utils import dist from composer.utils.misc import create_interval_scheduler -from mcli import ComputeConfig, Run, RunConfig, create_run, get_run +from mcli import Run, RunConfig, create_run, get_run log = logging.getLogger(__name__) @@ -33,7 +33,9 @@ OPTIONAL_PARAMS_FOR_EVAL = { 'dist_timeout', 'eval_gauntlet', + 'eval_loader', 'fsdp_config', + 'eval_subset_num_batches', 'icl_subset_num_batches', 'loggers', 'precision', @@ -175,50 +177,84 @@ def validate_interval(interval: Union[str, int, Time], return async_interval +def validate_eval_run_config( + eval_run_config: Optional[Dict[str, Any]]) -> Dict[str, Any]: + + if not eval_run_config: + return {} + + run_config = eval_run_config.copy() + + supported_keys = {'image', 'command', 'compute', 'scheduling'} + found_unsupported = set() + for key in run_config: + if key not in supported_keys: + found_unsupported.add(key) + + if found_unsupported: + raise ValueError( + f'Unsupported eval run config keys found: {", ".join(found_unsupported)}' + + f'. Supported keys: {supported_keys}') + + return run_config + + class AsyncEval(Callback): """Run the eval loop asynchronously as part of a MosaicML platform run. This callback is currently experimental. The API may change in the future. Args: - training_config: Dict[str, Any]: The config from the training run + training_params: Dict[str, Any]: The parameter config from the training run interval: Union[str, int, Time]: The interval describing how often eval runs should be launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`. - compute: Optional[Union[ComputeConfig, Dict[str, Any]]]: The compute configuration to - use for the eval run. If not provided, the same cluster as the current run and a - single, full GPU node will be used. + eval_run_config: Optional[Dict[str, Any]]: A subset of mcli run config values to use + for the eval run. If not specified, any fields from run config will be created + dynamically from the training run config and the interval. The following fields + are supported: + - ``image``: Image of the eval run. Default: same as training run + - ``command``: Command to run for the eval run. Default: calls + `composer scripts/eval/eval.py $PARAMETERS`. If custom setup is needed, + the command should include calling the eval script with $PARAMETERS + - ``compute``: Compute to use for the eval run. Default: same cluster as + the training run and a single node (8 GPUs) + - ``scheduling``: Scheduling to use for the eval run. Default: same as training run + + All fields are optional, but if specified, must be valid for a mcli run config. We + provide this optional config to give you the most flexibility in customizing the eval + run, but it is recommended to use the default values unless you have a specific use case """ def __init__( self, - training_config: Dict[str, Any], + training_params: Dict[str, Any], interval: Union[str, int, Time], - compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None, + eval_run_config: Optional[Dict[str, Any]] = None, ): for required in ('save_interval', 'save_folder'): - if required not in training_config: + if required not in training_params: raise ValueError(f'{required} required for async eval') - self.checkpoint_save_folder = training_config['save_folder'] - self.training_config = training_config + self.checkpoint_save_folder = training_params['save_folder'] + self.training_params = training_params + self.eval_run_config = validate_eval_run_config(eval_run_config) self.interval = validate_interval(interval, - self.training_config['save_interval']) + self.training_params['save_interval']) self.check_interval = create_interval_scheduler( interval, # There is a custom close to ensure that the final checkpoint # (which is the most important) is evaled after it is written include_end_of_training=False, ) - self.compute = compute self.last_checkpoint: Optional[str] = None # Run these during init to fail fast in any of the error cases self.current_run = self._get_current_run() get_eval_parameters( - parameters=training_config, + parameters=training_params, checkpoint='test', training_run_name=self.current_run.name, ) @@ -259,7 +295,7 @@ def close(self, state: State, logger: Logger) -> None: if dist.get_global_rank() != 0: return - save_latest_filename = self.training_config.get('save_latest_filename', + save_latest_filename = self.training_params.get('save_latest_filename', None) if not save_latest_filename: @@ -297,7 +333,7 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run: run_name = get_run_name(self.current_run.name, str(current_interval)) params = get_eval_parameters( - parameters=self.training_config, + parameters=self.training_params, checkpoint=checkpoint, training_run_name=self.current_run.name, ) @@ -347,12 +383,16 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run: # TODO: This just runs an eval run, but we also want to attach the # deployment, which would require a hf conversion and parametrizing the # dependent_deployment in the run config - command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS' + default_command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS' run_config = RunConfig( name=run_name, - image=self.current_run.image, - compute=self.compute or default_compute, - command=command, + image=self.eval_run_config.get('image', self.current_run.image), + command=self.eval_run_config.get('command', default_command), + compute=self.eval_run_config.get('compute', default_compute), + scheduling=self.eval_run_config.get( + 'scheduling', + self.current_run.submitted_config.scheduling, + ), integrations=integrations, env_variables=cfg.env_variables, metadata=cfg.metadata, diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 29642381f8..42f817b386 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -213,7 +213,7 @@ def build_callback( raise ValueError( 'Parameters config is required for async eval callback') - return AsyncEval(**kwargs, training_config=config) + return AsyncEval(**kwargs, training_params=config) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index c783a4f513..d4ba39acfa 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -118,6 +118,7 @@ def evaluate_model( python_log_level: Optional[str], precision: str, eval_gauntlet_df: Optional[pd.DataFrame], + eval_subset_num_batches: int, icl_subset_num_batches: Optional[int], metadata: Optional[Dict[str, str]], logged_config: DictConfig, @@ -224,7 +225,8 @@ def evaluate_model( if torch.cuda.is_available(): torch.cuda.synchronize() a = time.time() - trainer.eval(eval_dataloader=evaluators) + trainer.eval(eval_dataloader=evaluators, + subset_num_batches=eval_subset_num_batches) if torch.cuda.is_available(): torch.cuda.synchronize() b = time.time() @@ -299,10 +301,14 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: 'loggers', must_exist=False, default_value={}) - icl_subset_num_batches: int = pop_config(cfg, - 'icl_subset_num_batches', - must_exist=False, - default_value=None) + eval_subset_num_batches: int = pop_config(cfg, + 'eval_subset_num_batches', + must_exist=False, + default_value=-1) + icl_subset_num_batches: Optional[int] = pop_config(cfg, + 'icl_subset_num_batches', + must_exist=False, + default_value=None) metadata: Optional[Dict[str, str]] = pop_config(cfg, 'metadata', must_exist=False, @@ -356,6 +362,7 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: python_log_level=python_log_level, precision=precision, eval_gauntlet_df=eval_gauntlet_df, + eval_subset_num_batches=eval_subset_num_batches, icl_subset_num_batches=icl_subset_num_batches, metadata=metadata, logged_config=logged_cfg, diff --git a/scripts/train/train.py b/scripts/train/train.py index f28f8718ba..7bb5e71394 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -283,7 +283,8 @@ def main(cfg: DictConfig) -> Trainer: callback_configs: Optional[DictConfig] = pop_config(cfg, 'callbacks', must_exist=False, - default_value=None) + default_value=None, + convert=True) algorithm_configs: Optional[DictConfig] = pop_config(cfg, 'algorithms', must_exist=False, @@ -519,8 +520,7 @@ def main(cfg: DictConfig) -> Trainer: for name, callback_cfg in callback_configs.items() ] if callback_configs else [] - use_async_eval = any( - isinstance(callback, AsyncEval) for callback in callbacks) + use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks) # Algorithms algorithms = [ @@ -540,22 +540,28 @@ def main(cfg: DictConfig) -> Trainer: mosaicml_logger.log_metrics({'data_validated': time.time()}) ## Evaluation - log.info('Building eval loader...') - eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len - # TODO: evaluators should not be built at all if use_async_eval is True - # This will be fixed when eval_loader support is fully added to AsyncEval - evaluators, _, eval_gauntlet_callback = build_evaluators( - eval_loader_config, - icl_tasks_config if not use_async_eval else None, - eval_gauntlet_config if not use_async_eval else None, - tokenizer=tokenizer, - device_eval_batch_size=device_eval_batch_size, - icl_seq_len=eval_icl_seq_len, - icl_subset_num_batches=icl_subset_num_batches, - ) - - if eval_gauntlet_callback is not None and not use_async_eval: - callbacks.append(eval_gauntlet_callback) + if use_async_eval: + evaluators = [] + if eval_first: + warnings.warn( + 'AsyncEval callback does not support eval_first=True. Ignoring.' + ) + eval_first = False + + else: + log.info('Building eval loader...') + eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len + evaluators, _, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks_config, + eval_gauntlet_config, + tokenizer=tokenizer, + device_eval_batch_size=device_eval_batch_size, + icl_seq_len=eval_icl_seq_len, + icl_subset_num_batches=icl_subset_num_batches, + ) + if eval_gauntlet_callback is not None: + callbacks.append(eval_gauntlet_callback) # Build Model log.info('Initializing model...') @@ -582,7 +588,7 @@ def main(cfg: DictConfig) -> Trainer: optimizer = build_optimizer(model, optimizer_name, optimizer_config) # Now add the eval metrics - if eval_loader_config is not None: + if eval_loader_config is not None and not use_async_eval: train_metrics = model.get_metrics(is_train=True) evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics) diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index b3a1e98f79..92cb738d9c 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -11,10 +11,10 @@ from llmfoundry.callbacks.async_eval_callback import (AsyncEval, get_eval_parameters, get_run_name, + validate_eval_run_config, validate_interval) from mcli import Run, RunConfig, RunStatus -# here RUN_NAME = 'foo_bar-1234' BASIC_PARAMS = { 'save_interval': '1ba', @@ -191,6 +191,29 @@ def test_validate_interval(): assert validate_interval('2ep', two_epochs) == two_epochs +def test_validate_eval_run_config(): + assert validate_eval_run_config(None) == {} + assert validate_eval_run_config({}) == {} + + with pytest.raises(ValueError): + validate_eval_run_config({'foo': 'bar'}) + + valid_config = { + 'image': 'example_image', + 'command': 'example_command', + 'compute': { + 'gpus': 1, + 'cluster': 'example_cluster', + }, + 'scheduling': { + 'priority': 'high', + 'preemptible': True, + }, + } + res = validate_eval_run_config(valid_config) + assert res == valid_config + + FAKE_RUN = Run( run_uid='123', name=RUN_NAME, @@ -223,12 +246,16 @@ def test_validate_interval(): return_value=FAKE_RUN) def test_async_eval_callback_minimal(mock_create_run: MagicMock, mock_get_run: MagicMock): - callback = AsyncEval(BASIC_PARAMS, - interval='2ba', - compute={ - 'cluster': 'c2z3', - 'nodes': 2, - }) + callback = AsyncEval( + BASIC_PARAMS, + interval='2ba', + eval_run_config={ + 'compute': { + 'cluster': 'c2z3', + 'nodes': 2, + }, + }, + ) assert callback.current_run.name == RUN_NAME assert mock_get_run.call_count == 1 assert mock_get_run.call_args[0][0] == RUN_NAME @@ -310,12 +337,13 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, return_value=FAKE_RUN_WITH_INTEGRATIONS) def test_async_eval_callback_integrations(mock_create_run: MagicMock, mock_get_run: MagicMock): - callback = AsyncEval(BASIC_PARAMS, - interval='2ba', - compute={ - 'cluster': 'c2z3', - 'nodes': 2, - }) + callback = AsyncEval( + BASIC_PARAMS, + interval='2ba', + eval_run_config={'compute': { + 'cluster': 'c2z3', + 'nodes': 2, + }}) assert mock_get_run.call_count == 1 callback.launch_run('checkpoint/path', Time(1, TimeUnit.BATCH))