From e70a2f47750b2b9b7232bfed850d01a4282f10cf Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Fri, 27 Oct 2023 10:52:21 -0700 Subject: [PATCH 01/49] Async eval callback --- llmfoundry/callbacks/__init__.py | 2 + llmfoundry/callbacks/async_eval_callback.py | 156 ++++++++++++++++++++ llmfoundry/utils/builders.py | 8 +- scripts/train/train.py | 7 +- 4 files changed, 168 insertions(+), 5 deletions(-) create mode 100644 llmfoundry/callbacks/async_eval_callback.py diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index 62ffcd565c..08e9337681 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 try: + from llmfoundry.callbacks.async_eval_callback import AsyncEval from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet from llmfoundry.callbacks.fdiff_callback import FDiffMetrics from llmfoundry.callbacks.generate_callback import Generate @@ -28,4 +29,5 @@ 'EvalGauntlet', 'ModelGauntlet', 'HuggingFaceCheckpointer', + 'AsyncEval', ] diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py new file mode 100644 index 0000000000..b965e4172c --- /dev/null +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -0,0 +1,156 @@ +# Copyright 2023 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +import logging +import os +from typing import Any, Dict, Optional, Union + +from composer.core import Callback, Event, State, Time +from composer.loggers import Logger +from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR, + RUN_NAME_ENV_VAR) +from composer.utils import create_interval_scheduler, dist +from mcli.api.runs import ComputeConfig # TODO: should be available in root + +from mcli import Run, RunConfig, create_run, get_run + +log = logging.getLogger(__name__) + +MAX_RUN_NAME_LENGTH = 40 + +# Note: train parameter names. See comments if they are different from eval +REQUIRED_PARAMS_FOR_EVAL = { + 'device_eval_batch_size', + 'icl_tasks', # only required for eval + 'max_seq_len', + 'model', # models + 'save_folder', # required, but used as load_path +} +OPTIONAL_PARAMS_FOR_EVAL = { + 'dist_timeout', + 'eval_gauntlet', + 'fsdp_config', # fsdp_dict_cfg + 'icl_subset_num_batches', + 'loggers', + 'precision', + 'python_log_level', + 'seed', +} + + +def get_run_name(previous_run_name: str, count: int) -> str: + return f'eval{count}-{previous_run_name[:MAX_RUN_NAME_LENGTH]}' + + +def get_load_path(save_folder: str, + save_latest_filename: Optional[str] = None) -> str: + # TODO: check that the prefix is remote and not a local file (not supported of course) + + if not save_latest_filename: + rank = dist.get_global_rank() + save_latest_filename = f'latest-rank{rank}.pt' + + return f'{save_folder}/{save_latest_filename}' + + +class AsyncEval(Callback): + """Run the eval loop asynchronously as part of a MosaicML platform run + + Args: + interval: Union[str, int, Time]: The interval describing how often eval runs should be + launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. + Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, + :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`. 
+ """ + + def __init__( + self, + interval: Union[str, int, Time], + compute: Optional[ComputeConfig] = None, + ): + self.check_interval = create_interval_scheduler(interval) + self.compute = compute + self.count = 0 + + # Run these during init to fail fast in any of the error cases + self.current_run = self._get_current_run() + self._get_eval_parameters() + + def run_event(self, event: Event, state: State, logger: Logger) -> None: + del logger + if state.get_elapsed_duration() is not None and self.check_interval( + state, event): + new_run = self._launch_run() + logger.info(f'Launched new run {new_run.name} for eval loop') + self.count += 1 + + def _get_current_run(self) -> Run: + if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, + 'false').lower() == 'false': + raise Exception( + 'AsyncEval callback is only supported when running on the MosaicML platform' + ) + + run_name = os.environ.get(RUN_NAME_ENV_VAR, None) + if not run_name: + raise Exception( + 'RUN_NAME environment variable must be set to use the AsyncEval callback' + ) + + # allows the MapiException to be raised if the run doesn't exist + return get_run(run_name, include_details=True) + + def _get_eval_parameters(self) -> Dict[str, Any]: + cfg_params = self.current_run.submitted_config.parameters or {} + looking_for = REQUIRED_PARAMS_FOR_EVAL.copy() + + # Go through all parameters and pull out the ones needed for eval + subset_keys = {} + for key in cfg_params: + if key in OPTIONAL_PARAMS_FOR_EVAL: + subset_keys[key] = cfg_params[key] + elif key in REQUIRED_PARAMS_FOR_EVAL: + subset_keys[key] = cfg_params[key] + looking_for.remove(key) + + if looking_for: + raise Exception( + f'Missing the following required parameters for async eval: {looking_for}' + ) + + # Convert the save_folder to a load_path + subset_keys['load_path'] = get_load_path( + subset_keys.pop('save_folder'), + cfg_params.get('save_latest_filename', None)) + + # Rename the keys to match the eval script + subset_keys['models'] = [cfg_params.pop('model')] + if 'fsdp_cfg' in subset_keys: + subset_keys['fsdp_dict_cfg'] = cfg_params.pop('fsdp_cfg') + + cfg_params['run_name'] = get_run_name(self.current_run.name, self.count) + return cfg_params + + def _launch_run(self) -> Run: + cfg = self.current_run.submitted_config + default_compute = { + 'nodes': 1, + 'cluster': self.current_run.cluster, + } + params = self._get_eval_parameters() + + # TODO: This just runs an eval run, but we also want to attach the + # deployment, which would require a hf conversion and parametrizing the + # dependent_deployment in the run config + command = 'cd llm-foundry/scripts \n composer eval/eval.py $PARAMETERS' + c = RunConfig( + name=get_run_name(self.current_run.name, self.count), + image=self.current_run.image, + compute=self.compute or default_compute, + command=command, + integrations=cfg.integrations, + env_variables=cfg.env_variables, + metadata=cfg.metadata, + parameters=params, + ) + + return create_run(c) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index f027afb0ce..38e6a210e1 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -27,9 +27,9 @@ from torch.optim.optimizer import Optimizer from transformers import AutoTokenizer, PreTrainedTokenizerBase -from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling, - HuggingFaceCheckpointer, LayerFreezing, - MonolithicCheckpointSaver, +from llmfoundry.callbacks import (AsyncEval, EvalGauntlet, FDiffMetrics, + GlobalLRScaling, HuggingFaceCheckpointer, + LayerFreezing, 
MonolithicCheckpointSaver, ScheduledGarbageCollector) from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, DecoupledLionW, DecoupledLionW_8bit) @@ -118,6 +118,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': return HuggingFaceCheckpointer(**kwargs) + elif name == 'async_eval': + return AsyncEval(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/scripts/train/train.py b/scripts/train/train.py index e29f2c9a47..f9c16dcdcb 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -518,6 +518,9 @@ def main(cfg: DictConfig) -> Trainer: for name, callback_cfg in callback_configs.items() ] if callback_configs else [] + use_async_eval = any( + isinstance(callback, Evaluator.Async) for callback in callbacks) + # Algorithms algorithms = [ build_algorithm(str(name), algorithm_cfg) @@ -556,14 +559,14 @@ def main(cfg: DictConfig) -> Trainer: eval_gauntlet_callback = None - if icl_tasks_config is not None: + if icl_tasks_config is not None and not use_async_eval: icl_evaluators, _, eval_gauntlet_callback = build_icl_data_and_gauntlet( icl_tasks_config, eval_gauntlet_config, tokenizer, device_eval_batch_size, icl_seq_len if icl_seq_len else max_seq_len, icl_subset_num_batches) evaluators.extend(icl_evaluators) - if eval_gauntlet_callback is not None: + if eval_gauntlet_callback is not None and not use_async_eval: callbacks.append(eval_gauntlet_callback) # Build Model From acd8b2e7bd1d8c82f8290766d794ce18273ab014 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Tue, 7 Nov 2023 17:49:16 -0800 Subject: [PATCH 02/49] add very basic tests --- llmfoundry/callbacks/async_eval_callback.py | 46 ++++-- scripts/train/train.py | 3 +- tests/callbacks/test_async_eval_callback.py | 163 ++++++++++++++++++++ 3 files changed, 194 insertions(+), 18 deletions(-) create mode 100644 tests/callbacks/test_async_eval_callback.py diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index b965e4172c..9f1b3bacfa 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -1,4 +1,4 @@ -# Copyright 2023 MosaicML LLM Foundry authors +# Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 import logging import os @@ -8,7 +8,8 @@ from composer.loggers import Logger from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR, RUN_NAME_ENV_VAR) -from composer.utils import create_interval_scheduler, dist +from composer.utils import dist +from composer.utils.misc import create_interval_scheduler from mcli.api.runs import ComputeConfig # TODO: should be available in root from mcli import Run, RunConfig, create_run, get_run @@ -53,7 +54,7 @@ def get_load_path(save_folder: str, class AsyncEval(Callback): - """Run the eval loop asynchronously as part of a MosaicML platform run + """Run the eval loop asynchronously as part of a MosaicML platform run. 
Args: interval: Union[str, int, Time]: The interval describing how often eval runs should be @@ -73,14 +74,17 @@ def __init__( # Run these during init to fail fast in any of the error cases self.current_run = self._get_current_run() - self._get_eval_parameters() + self.get_eval_parameters( + self.current_run.submitted_config.parameters or {}, + self.current_run.name, + ) def run_event(self, event: Event, state: State, logger: Logger) -> None: del logger if state.get_elapsed_duration() is not None and self.check_interval( state, event): new_run = self._launch_run() - logger.info(f'Launched new run {new_run.name} for eval loop') + log.info(f'Launched new run {new_run.name} for eval loop') self.count += 1 def _get_current_run(self) -> Run: @@ -99,17 +103,21 @@ def _get_current_run(self) -> Run: # allows the MapiException to be raised if the run doesn't exist return get_run(run_name, include_details=True) - def _get_eval_parameters(self) -> Dict[str, Any]: - cfg_params = self.current_run.submitted_config.parameters or {} + def get_eval_parameters( + self, + parameters: Dict[str, Any], + run_name: str, + count: int = 0, + ) -> Dict[str, Any]: looking_for = REQUIRED_PARAMS_FOR_EVAL.copy() # Go through all parameters and pull out the ones needed for eval subset_keys = {} - for key in cfg_params: + for key in parameters: if key in OPTIONAL_PARAMS_FOR_EVAL: - subset_keys[key] = cfg_params[key] + subset_keys[key] = parameters[key] elif key in REQUIRED_PARAMS_FOR_EVAL: - subset_keys[key] = cfg_params[key] + subset_keys[key] = parameters[key] looking_for.remove(key) if looking_for: @@ -120,15 +128,15 @@ def _get_eval_parameters(self) -> Dict[str, Any]: # Convert the save_folder to a load_path subset_keys['load_path'] = get_load_path( subset_keys.pop('save_folder'), - cfg_params.get('save_latest_filename', None)) + parameters.get('save_latest_filename', None)) # Rename the keys to match the eval script - subset_keys['models'] = [cfg_params.pop('model')] - if 'fsdp_cfg' in subset_keys: - subset_keys['fsdp_dict_cfg'] = cfg_params.pop('fsdp_cfg') + subset_keys['models'] = [subset_keys.pop('model')] + if 'fsdp_config' in subset_keys: + subset_keys['fsdp_dict_cfg'] = subset_keys.pop('fsdp_config') - cfg_params['run_name'] = get_run_name(self.current_run.name, self.count) - return cfg_params + subset_keys['run_name'] = get_run_name(run_name, count) + return subset_keys def _launch_run(self) -> Run: cfg = self.current_run.submitted_config @@ -136,7 +144,11 @@ def _launch_run(self) -> Run: 'nodes': 1, 'cluster': self.current_run.cluster, } - params = self._get_eval_parameters() + params = self.get_eval_parameters( + cfg.parameters or {}, + self.current_run.name, + self.count, + ) # TODO: This just runs an eval run, but we also want to attach the # deployment, which would require a hf conversion and parametrizing the diff --git a/scripts/train/train.py b/scripts/train/train.py index f9c16dcdcb..28397b507a 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -26,6 +26,7 @@ from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM, MPTForCausalLM, build_finetuning_dataloader, build_text_denoising_dataloader) +from llmfoundry.callbacks import AsyncEval from llmfoundry.data.text_data import build_text_dataloader from llmfoundry.utils.builders import (build_algorithm, build_callback, build_icl_data_and_gauntlet, @@ -519,7 +520,7 @@ def main(cfg: DictConfig) -> Trainer: ] if callback_configs else [] use_async_eval = any( - isinstance(callback, Evaluator.Async) for callback in callbacks) + 
isinstance(callback, AsyncEval) for callback in callbacks) # Algorithms algorithms = [ diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py new file mode 100644 index 0000000000..6312ab0863 --- /dev/null +++ b/tests/callbacks/test_async_eval_callback.py @@ -0,0 +1,163 @@ +from unittest.mock import patch + +import pytest + +from llmfoundry.callbacks import AsyncEval +from llmfoundry.callbacks.async_eval_callback import get_run_name +from mcli import Run, RunConfig, RunStatus + +RUN_NAME = 'foo_bar' + + +def test_get_run_name(): + a = get_run_name('foo', 0) + assert a == 'eval0-foo' + + b = get_run_name(50 * 'foo', 1) + assert b == 'eval1-foofoofoofoofoofoofoofoofoofoofoofoofoof' + + +@pytest.fixture(autouse=True, scope='module') +def set_os_env_vars(): + with patch.dict('os.environ', { + 'MOSAICML_PLATFORM': 'true', + 'RUN_NAME': RUN_NAME + }): + yield + + +def test_fails_when_not_on_platform(): + with patch.dict('os.environ', {'MOSAICML_PLATFORM': 'false'}): + with pytest.raises( + Exception, + match= + 'AsyncEval callback is only supported when running on the MosaicML platform' + ): + AsyncEval(interval='2ba') + + +def test_fails_when_no_run_name(): + with patch.dict('os.environ', { + 'MOSAICML_PLATFORM': 'true', + 'RUN_NAME': '' + }): + with pytest.raises( + Exception, + match= + 'RUN_NAME environment variable must be set to use the AsyncEval callback' + ): + AsyncEval(interval='2ba') + + +def test_get_eval_parameters(): + with pytest.raises( + Exception, + match='Missing the following required parameters for async eval:'): + AsyncEval.get_eval_parameters(None, {}, RUN_NAME) + + # minimal example + params = AsyncEval.get_eval_parameters( + None, { + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'model': { + 'model_example': 'model_example' + }, + 'save_folder': 'save_folder_example', + }, RUN_NAME) + assert params == { + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'load_path': 'save_folder_example/latest-rank0.pt', + 'run_name': 'eval0-foo_bar', + 'models': [{ + 'model_example': 'model_example' + }], + } + + # maximal example + params2 = AsyncEval.get_eval_parameters( + None, + { + # required + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'model': { + 'model_example': 'model_example' + }, + 'save_folder': 'save_folder_example', + # optional + 'dist_timeout': 1, + 'eval_gauntlet': 'eval_gauntlet_example', + 'fsdp_config': { + 'fsdp_cfg_example': 'fsdp_cfg_example' + }, + 'icl_subset_num_batches': 4, + 'loggers': { + 'loggers_example': 'loggers_example' + }, + 'precision': 'precision_example', + 'python_log_level': 'debug', + 'seed': 5, + # ignore this + 'ignore_this': 'ignore_this', + }, + RUN_NAME) + assert params2 == { + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'run_name': 'eval0-foo_bar', + 'dist_timeout': 1, + 'models': [{ + 'model_example': 'model_example' + }], + 'eval_gauntlet': 'eval_gauntlet_example', + 'fsdp_dict_cfg': { + 'fsdp_cfg_example': 'fsdp_cfg_example' + }, + 'icl_subset_num_batches': 4, + 'loggers': { + 'loggers_example': 'loggers_example' + }, + 'precision': 'precision_example', + 'python_log_level': 'debug', + 'seed': 5, + 'load_path': 'save_folder_example/latest-rank0.pt' + } + + +@patch('llmfoundry.callbacks.async_eval_callback.get_run', + return_value=Run( + run_uid='123', + name=RUN_NAME, + status=RunStatus.RUNNING, + created_at='2021-01-01', + 
updated_at='2021-01-01', + created_by='me', + priority='low', + preemptible=False, + retry_on_system_failure=True, + cluster='c1z2', + gpu_type="a100", + gpus=16, + cpus=0, + node_count=2, + latest_resumption=None, + submitted_config=RunConfig( + parameters={ + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'model': 'model_example', + 'save_folder': 'save_folder_example', + }), + )) +@patch('llmfoundry.callbacks.async_eval_callback.create_run', return_value=None) +def test_async_eval_callback_minimal(mock_get_run, mock_create_run): + callback = AsyncEval(interval='2ba') + assert callback.current_run.name == RUN_NAME + # todo From aef814b68f98e12783a8930a4d852403917398af Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 8 Nov 2023 09:17:39 -0800 Subject: [PATCH 03/49] more tests --- llmfoundry/callbacks/async_eval_callback.py | 33 ++++---- tests/callbacks/test_async_eval_callback.py | 93 ++++++++++++++------- 2 files changed, 83 insertions(+), 43 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 9f1b3bacfa..386740de95 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -61,12 +61,15 @@ class AsyncEval(Callback): launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`. + compute: Optional[Union[ComputeConfig, Dict[str, Any]]]: The compute configuration to + use for the eval run. If not provided, the same cluster as the current run and a + single GPU node will be used. """ def __init__( self, interval: Union[str, int, Time], - compute: Optional[ComputeConfig] = None, + compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None, ): self.check_interval = create_interval_scheduler(interval) self.compute = compute @@ -78,25 +81,27 @@ def __init__( self.current_run.submitted_config.parameters or {}, self.current_run.name, ) + log.info( + f'Initialized AsyncEval callback. 
Will generate runs at interval {interval}' + ) def run_event(self, event: Event, state: State, logger: Logger) -> None: del logger if state.get_elapsed_duration() is not None and self.check_interval( state, event): - new_run = self._launch_run() - log.info(f'Launched new run {new_run.name} for eval loop') + self.launch_run() self.count += 1 def _get_current_run(self) -> Run: if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'false': - raise Exception( + raise RuntimeError( 'AsyncEval callback is only supported when running on the MosaicML platform' ) run_name = os.environ.get(RUN_NAME_ENV_VAR, None) if not run_name: - raise Exception( + raise RuntimeError( 'RUN_NAME environment variable must be set to use the AsyncEval callback' ) @@ -107,7 +112,6 @@ def get_eval_parameters( self, parameters: Dict[str, Any], run_name: str, - count: int = 0, ) -> Dict[str, Any]: looking_for = REQUIRED_PARAMS_FOR_EVAL.copy() @@ -135,20 +139,17 @@ def get_eval_parameters( if 'fsdp_config' in subset_keys: subset_keys['fsdp_dict_cfg'] = subset_keys.pop('fsdp_config') - subset_keys['run_name'] = get_run_name(run_name, count) + subset_keys['run_name'] = get_run_name(run_name, 0) return subset_keys - def _launch_run(self) -> Run: + def launch_run(self) -> Run: cfg = self.current_run.submitted_config default_compute = { 'nodes': 1, 'cluster': self.current_run.cluster, } - params = self.get_eval_parameters( - cfg.parameters or {}, - self.current_run.name, - self.count, - ) + params = self.get_eval_parameters(cfg.parameters or {}, + self.current_run.name) # TODO: This just runs an eval run, but we also want to attach the # deployment, which would require a hf conversion and parametrizing the @@ -165,4 +166,8 @@ def _launch_run(self) -> Run: parameters=params, ) - return create_run(c) + new_run = create_run(c) + log.info( + f'Launched new run {new_run.name} inside eval loop with config: \n{new_run.submitted_config}' + ) + return new_run diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index 6312ab0863..af06ff811b 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -130,34 +130,69 @@ def test_get_eval_parameters(): } +FAKE_RUN = Run( + run_uid='123', + name=RUN_NAME, + image="fake-image", + status=RunStatus.RUNNING, + created_at='2021-01-01', + updated_at='2021-01-01', + created_by='me', + priority='low', + preemptible=False, + retry_on_system_failure=True, + cluster='c1z2', + gpu_type="a100", + gpus=16, + cpus=0, + node_count=2, + latest_resumption=None, + submitted_config=RunConfig( + name=RUN_NAME, + image='fake-image', + command='echo hi', + parameters={ + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'model': 'model_example', + 'save_folder': 'save_folder_example', + }, + ), +) + + @patch('llmfoundry.callbacks.async_eval_callback.get_run', - return_value=Run( - run_uid='123', - name=RUN_NAME, - status=RunStatus.RUNNING, - created_at='2021-01-01', - updated_at='2021-01-01', - created_by='me', - priority='low', - preemptible=False, - retry_on_system_failure=True, - cluster='c1z2', - gpu_type="a100", - gpus=16, - cpus=0, - node_count=2, - latest_resumption=None, - submitted_config=RunConfig( - parameters={ - 'device_eval_batch_size': 2, - 'icl_tasks': 'icl_task_example', - 'max_seq_len': 3, - 'model': 'model_example', - 'save_folder': 'save_folder_example', - }), - )) -@patch('llmfoundry.callbacks.async_eval_callback.create_run', return_value=None) -def 
test_async_eval_callback_minimal(mock_get_run, mock_create_run): - callback = AsyncEval(interval='2ba') + return_value=FAKE_RUN) +@patch('llmfoundry.callbacks.async_eval_callback.create_run', + return_value=FAKE_RUN) +def test_async_eval_callback_minimal(mock_create_run, mock_get_run): + callback = AsyncEval(interval='2ba', + compute={ + 'cluster': 'c2z3', + 'nodes': 2, + }) assert callback.current_run.name == RUN_NAME - # todo + assert mock_get_run.call_count == 1 + assert mock_get_run.call_args[0][0] == RUN_NAME + + callback.count += 2 + callback.launch_run() + assert mock_create_run.call_count == 1 + + run_config_created = mock_create_run.call_args[0][0] + assert run_config_created.name == 'eval2-foo_bar' + assert run_config_created.image == 'fake-image' + assert run_config_created.command + + compute = run_config_created.compute + assert compute['cluster'] == 'c2z3' + assert compute['nodes'] == 2 + + parameters = run_config_created.parameters + assert parameters['device_eval_batch_size'] == 2 + assert parameters['icl_tasks'] == 'icl_task_example' + assert parameters['max_seq_len'] == 3 + assert parameters['load_path'] == 'save_folder_example/latest-rank0.pt' + assert parameters['models'] == ['model_example'] + assert parameters['run_name'] == 'eval0-foo_bar' # original run From 007ae906859a292a7f2f95e85034b40194eaa1e0 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 8 Nov 2023 09:44:57 -0800 Subject: [PATCH 04/49] bump mcli --- .github/workflows/pytest-gpu.yaml | 1 - setup.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/pytest-gpu.yaml b/.github/workflows/pytest-gpu.yaml index 45b49366c9..d8a8d73d06 100644 --- a/.github/workflows/pytest-gpu.yaml +++ b/.github/workflows/pytest-gpu.yaml @@ -52,7 +52,6 @@ jobs: run: | set -ex python -m pip install mosaicml-cli - mcli init --mcloud mcli version - name: Submit Run id: tests diff --git a/setup.py b/setup.py index 63aac9d752..b9a503083b 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ 'einops==0.5.0', 'omegaconf>=2.2.3,<3', 'slack-sdk<4', - 'mosaicml-cli>=0.3,<1', + 'mosaicml-cli>=0.5.20,<1', 'onnx==1.14.0', 'onnxruntime==1.15.1', 'cmake>=3.25.0,<=3.26.3', # required for triton-pre-mlir below From 6cd020f81991c8c286f7d3618a18083d235dd106 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 11:42:51 -0800 Subject: [PATCH 05/49] woop, missing import --- scripts/train/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/train/train.py b/scripts/train/train.py index 1c17e1f9c3..925470c4e4 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -25,6 +25,7 @@ from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM, MPTForCausalLM) +from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader from llmfoundry.utils.builders import (build_algorithm, build_callback, build_icl_data_and_gauntlet, From 9fbe7a1460dc3e4d82b0244709b2192fd207add2 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 11:51:02 -0800 Subject: [PATCH 06/49] instance not specified error --- llmfoundry/callbacks/async_eval_callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 386740de95..df1256c733 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -145,7 +145,7 @@ def get_eval_parameters( def launch_run(self) -> Run: cfg = 
self.current_run.submitted_config default_compute = { - 'nodes': 1, + 'gpus': 8, 'cluster': self.current_run.cluster, } params = self.get_eval_parameters(cfg.parameters or {}, @@ -166,7 +166,7 @@ def launch_run(self) -> Run: parameters=params, ) - new_run = create_run(c) + new_run = create_run(c, timeout=60) log.info( f'Launched new run {new_run.name} inside eval loop with config: \n{new_run.submitted_config}' ) From 547ec218158176d42ea15a6bef76dbff3d1bb449 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 14:28:58 -0800 Subject: [PATCH 07/49] fixes --- llmfoundry/callbacks/async_eval_callback.py | 42 ++++++--- tests/callbacks/__init__.py | 0 tests/callbacks/test_async_eval_callback.py | 99 +++++++++++++-------- 3 files changed, 93 insertions(+), 48 deletions(-) create mode 100644 tests/callbacks/__init__.py diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index df1256c733..0b44d053fd 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging import os -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union from composer.core import Callback, Event, State, Time from composer.loggers import Logger @@ -18,12 +18,12 @@ MAX_RUN_NAME_LENGTH = 40 -# Note: train parameter names. See comments if they are different from eval REQUIRED_PARAMS_FOR_EVAL = { 'device_eval_batch_size', - 'icl_tasks', # only required for eval + 'icl_tasks', # only required for eval, may not be specified in pure training 'max_seq_len', - 'model', # models + 'model', # converted into models + 'tokenizer', # converted into models 'save_folder', # required, but used as load_path } OPTIONAL_PARAMS_FOR_EVAL = { @@ -39,7 +39,9 @@ def get_run_name(previous_run_name: str, count: int) -> str: - return f'eval{count}-{previous_run_name[:MAX_RUN_NAME_LENGTH]}' + *name_without_uuid_suffix, _ = previous_run_name.split('-') + name_suffix = '-'.join(name_without_uuid_suffix)[:MAX_RUN_NAME_LENGTH] + return f'eval{count}-{name_suffix}' def get_load_path(save_folder: str, @@ -53,6 +55,24 @@ def get_load_path(save_folder: str, return f'{save_folder}/{save_latest_filename}' +def get_eval_models_dict( + model: Dict[str, Any], + tokenizer: Dict[str, Any], +) -> List[Dict[str, Any]]: + name = model.get('name') + + cfg_overrides = model.pop('cfg_overrides', {}) + for key in cfg_overrides: + model[key] = cfg_overrides[key] + + new_model = {'model_name': name, 'model': model} + + if tokenizer: + new_model['tokenizer'] = tokenizer + + return [new_model] + + class AsyncEval(Callback): """Run the eval loop asynchronously as part of a MosaicML platform run. 
@@ -74,6 +94,7 @@ def __init__( self.check_interval = create_interval_scheduler(interval) self.compute = compute self.count = 0 + self.last_launch: Optional[Time] = None # Run these during init to fail fast in any of the error cases self.current_run = self._get_current_run() @@ -88,8 +109,10 @@ def __init__( def run_event(self, event: Event, state: State, logger: Logger) -> None: del logger if state.get_elapsed_duration() is not None and self.check_interval( - state, event): + state, event) and self.last_launch != state.timestamp.batch: self.launch_run() + + self.last_launch = state.timestamp.batch self.count += 1 def _get_current_run(self) -> Run: @@ -134,10 +157,9 @@ def get_eval_parameters( subset_keys.pop('save_folder'), parameters.get('save_latest_filename', None)) - # Rename the keys to match the eval script - subset_keys['models'] = [subset_keys.pop('model')] - if 'fsdp_config' in subset_keys: - subset_keys['fsdp_dict_cfg'] = subset_keys.pop('fsdp_config') + # Create new eval models list + subset_keys['models'] = get_eval_models_dict( + subset_keys.pop('model'), subset_keys.pop('tokenizer')) subset_keys['run_name'] = get_run_name(run_name, 0) return subset_keys diff --git a/tests/callbacks/__init__.py b/tests/callbacks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index af06ff811b..caf5e72868 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -1,19 +1,21 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + from unittest.mock import patch import pytest -from llmfoundry.callbacks import AsyncEval -from llmfoundry.callbacks.async_eval_callback import get_run_name +from llmfoundry.callbacks.async_eval_callback import AsyncEval, get_run_name from mcli import Run, RunConfig, RunStatus -RUN_NAME = 'foo_bar' +RUN_NAME = 'foo_bar-1234' def test_get_run_name(): - a = get_run_name('foo', 0) + a = get_run_name('foo-1234', 0) assert a == 'eval0-foo' - b = get_run_name(50 * 'foo', 1) + b = get_run_name(50 * 'foo-1234', 1) assert b == 'eval1-foofoofoofoofoofoofoofoofoofoofoofoofoof' @@ -49,6 +51,25 @@ def test_fails_when_no_run_name(): AsyncEval(interval='2ba') +BASIC_PARAMS = { + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'model': { + 'name': 'model_example', + 'cfg_overrides': { + 'attn_config': { + 'foo': 'bar' + } + } + }, + 'tokenizer': { + 'tokenizer_example': 'tokenizer_example', + }, + 'save_folder': 'save_folder_example', +} + + def test_get_eval_parameters(): with pytest.raises( Exception, @@ -56,24 +77,29 @@ def test_get_eval_parameters(): AsyncEval.get_eval_parameters(None, {}, RUN_NAME) # minimal example - params = AsyncEval.get_eval_parameters( - None, { - 'device_eval_batch_size': 2, - 'icl_tasks': 'icl_task_example', - 'max_seq_len': 3, - 'model': { - 'model_example': 'model_example' - }, - 'save_folder': 'save_folder_example', - }, RUN_NAME) + params = AsyncEval.get_eval_parameters(None, BASIC_PARAMS, RUN_NAME) assert params == { - 'device_eval_batch_size': 2, - 'icl_tasks': 'icl_task_example', - 'max_seq_len': 3, - 'load_path': 'save_folder_example/latest-rank0.pt', - 'run_name': 'eval0-foo_bar', + 'device_eval_batch_size': + 2, + 'icl_tasks': + 'icl_task_example', + 'max_seq_len': + 3, + 'load_path': + 'save_folder_example/latest-rank0.pt', + 'run_name': + 'eval0-foo_bar', 'models': [{ - 'model_example': 'model_example' + 
'model_name': 'model_example', + 'model': { + 'name': 'model_example', + 'attn_config': { + 'foo': 'bar' + }, + }, + 'tokenizer': { + 'tokenizer_example': 'tokenizer_example' + }, }], } @@ -82,13 +108,7 @@ def test_get_eval_parameters(): None, { # required - 'device_eval_batch_size': 2, - 'icl_tasks': 'icl_task_example', - 'max_seq_len': 3, - 'model': { - 'model_example': 'model_example' - }, - 'save_folder': 'save_folder_example', + **BASIC_PARAMS, # optional 'dist_timeout': 1, 'eval_gauntlet': 'eval_gauntlet_example', @@ -113,7 +133,16 @@ def test_get_eval_parameters(): 'run_name': 'eval0-foo_bar', 'dist_timeout': 1, 'models': [{ - 'model_example': 'model_example' + 'model_name': 'model_example', + 'model': { + 'name': 'model_example', + 'attn_config': { + 'foo': 'bar' + }, + }, + 'tokenizer': { + 'tokenizer_example': 'tokenizer_example' + }, }], 'eval_gauntlet': 'eval_gauntlet_example', 'fsdp_dict_cfg': { @@ -133,7 +162,7 @@ def test_get_eval_parameters(): FAKE_RUN = Run( run_uid='123', name=RUN_NAME, - image="fake-image", + image='fake-image', status=RunStatus.RUNNING, created_at='2021-01-01', updated_at='2021-01-01', @@ -142,7 +171,7 @@ def test_get_eval_parameters(): preemptible=False, retry_on_system_failure=True, cluster='c1z2', - gpu_type="a100", + gpu_type='a100', gpus=16, cpus=0, node_count=2, @@ -151,13 +180,7 @@ def test_get_eval_parameters(): name=RUN_NAME, image='fake-image', command='echo hi', - parameters={ - 'device_eval_batch_size': 2, - 'icl_tasks': 'icl_task_example', - 'max_seq_len': 3, - 'model': 'model_example', - 'save_folder': 'save_folder_example', - }, + parameters=BASIC_PARAMS, ), ) From 9819456bf7443b86682932dee5da9e5ac6dd4878 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 14:42:42 -0800 Subject: [PATCH 08/49] fix typing --- tests/callbacks/test_async_eval_callback.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index caf5e72868..274d80b9d8 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -1,7 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import patch +import datetime +from unittest.mock import MagicMock, patch import pytest @@ -74,10 +75,13 @@ def test_get_eval_parameters(): with pytest.raises( Exception, match='Missing the following required parameters for async eval:'): - AsyncEval.get_eval_parameters(None, {}, RUN_NAME) + AsyncEval.get_eval_parameters(None, {}, RUN_NAME) # type: ignore # minimal example - params = AsyncEval.get_eval_parameters(None, BASIC_PARAMS, RUN_NAME) + params = AsyncEval.get_eval_parameters( + None, # type: ignore + BASIC_PARAMS, + RUN_NAME) assert params == { 'device_eval_batch_size': 2, @@ -105,7 +109,7 @@ def test_get_eval_parameters(): # maximal example params2 = AsyncEval.get_eval_parameters( - None, + None, # type: ignore { # required **BASIC_PARAMS, @@ -164,8 +168,8 @@ def test_get_eval_parameters(): name=RUN_NAME, image='fake-image', status=RunStatus.RUNNING, - created_at='2021-01-01', - updated_at='2021-01-01', + created_at=datetime.datetime(2021, 1, 1), + updated_at=datetime.datetime(2021, 1, 1), created_by='me', priority='low', preemptible=False, @@ -175,7 +179,7 @@ def test_get_eval_parameters(): gpus=16, cpus=0, node_count=2, - latest_resumption=None, + latest_resumption=None, # type: ignore submitted_config=RunConfig( name=RUN_NAME, 
image='fake-image', @@ -189,7 +193,8 @@ def test_get_eval_parameters(): return_value=FAKE_RUN) @patch('llmfoundry.callbacks.async_eval_callback.create_run', return_value=FAKE_RUN) -def test_async_eval_callback_minimal(mock_create_run, mock_get_run): +def test_async_eval_callback_minimal(mock_create_run: MagicMock, + mock_get_run: MagicMock): callback = AsyncEval(interval='2ba', compute={ 'cluster': 'c2z3', From ba871d776dcb621e3bfc157eb09b4fe3e59c18c0 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 14:50:31 -0800 Subject: [PATCH 09/49] small testing fixes --- llmfoundry/callbacks/async_eval_callback.py | 7 ++++--- tests/callbacks/test_async_eval_callback.py | 17 ++++++++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 0b44d053fd..135aac974e 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -10,7 +10,8 @@ RUN_NAME_ENV_VAR) from composer.utils import dist from composer.utils.misc import create_interval_scheduler -from mcli.api.runs import ComputeConfig # TODO: should be available in root +from mcli.api.runs import \ + ComputeConfig # TODO: available in root in mcli 0.5.27+ from mcli import Run, RunConfig, create_run, get_run @@ -29,7 +30,7 @@ OPTIONAL_PARAMS_FOR_EVAL = { 'dist_timeout', 'eval_gauntlet', - 'fsdp_config', # fsdp_dict_cfg + 'fsdp_config', 'icl_subset_num_batches', 'loggers', 'precision', @@ -40,7 +41,7 @@ def get_run_name(previous_run_name: str, count: int) -> str: *name_without_uuid_suffix, _ = previous_run_name.split('-') - name_suffix = '-'.join(name_without_uuid_suffix)[:MAX_RUN_NAME_LENGTH] + name_suffix = ('-'.join(name_without_uuid_suffix))[:MAX_RUN_NAME_LENGTH] return f'eval{count}-{name_suffix}' diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index 274d80b9d8..65571f7b77 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -16,7 +16,7 @@ def test_get_run_name(): a = get_run_name('foo-1234', 0) assert a == 'eval0-foo' - b = get_run_name(50 * 'foo-1234', 1) + b = get_run_name(50 * 'foo' + '-1234', 1) assert b == 'eval1-foofoofoofoofoofoofoofoofoofoofoofoofoof' @@ -149,7 +149,7 @@ def test_get_eval_parameters(): }, }], 'eval_gauntlet': 'eval_gauntlet_example', - 'fsdp_dict_cfg': { + 'fsdp_config': { 'fsdp_cfg_example': 'fsdp_cfg_example' }, 'icl_subset_num_batches': 4, @@ -222,5 +222,16 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, assert parameters['icl_tasks'] == 'icl_task_example' assert parameters['max_seq_len'] == 3 assert parameters['load_path'] == 'save_folder_example/latest-rank0.pt' - assert parameters['models'] == ['model_example'] + assert parameters['models'] == [{ + 'model_name': 'model_example', + 'model': { + 'name': 'model_example', + 'attn_config': { + 'foo': 'bar' + } + }, + 'tokenizer': { + 'tokenizer_example': 'tokenizer_example' + } + }] assert parameters['run_name'] == 'eval0-foo_bar' # original run From 21e988022613bf0b235229085836bcf4701b23a2 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 15:09:00 -0800 Subject: [PATCH 10/49] launch new run only on main process --- llmfoundry/callbacks/async_eval_callback.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 135aac974e..e07d2fa869 100644 --- 
a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -109,8 +109,12 @@ def __init__( def run_event(self, event: Event, state: State, logger: Logger) -> None: del logger - if state.get_elapsed_duration() is not None and self.check_interval( - state, event) and self.last_launch != state.timestamp.batch: + if all([ + state.get_elapsed_duration() is not None, + self.check_interval(state, event), self.last_launch + != state.timestamp.batch, + dist.get_global_rank() == 0 + ]): self.launch_run() self.last_launch = state.timestamp.batch From 47a8255c486a3080b498599ff27b918f34f7fcb4 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 15:46:01 -0800 Subject: [PATCH 11/49] logger name --- llmfoundry/callbacks/async_eval_callback.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index e07d2fa869..02124b6d44 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -162,6 +162,13 @@ def get_eval_parameters( subset_keys.pop('save_folder'), parameters.get('save_latest_filename', None)) + # Update the loggers to use the training run name + for logger, config in subset_keys.get('loggers', []): + if logger == 'wandb': + config['name'] = config.get('name', run_name) + elif logger == 'mlflow': + config['run_name'] = config.get('run_name', run_name) + # Create new eval models list subset_keys['models'] = get_eval_models_dict( subset_keys.pop('model'), subset_keys.pop('tokenizer')) From 08c24be7fba4b8a9bc6dceb7ebad0afa5cd89ef5 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 15:59:27 -0800 Subject: [PATCH 12/49] items --- llmfoundry/callbacks/async_eval_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 02124b6d44..d7f6af4f53 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -163,7 +163,7 @@ def get_eval_parameters( parameters.get('save_latest_filename', None)) # Update the loggers to use the training run name - for logger, config in subset_keys.get('loggers', []): + for logger, config in subset_keys.get('loggers', {}).items(): if logger == 'wandb': config['name'] = config.get('name', run_name) elif logger == 'mlflow': From bf415f0939b26d5a56e0e636d63358d8669f109e Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 9 Nov 2023 16:59:37 -0800 Subject: [PATCH 13/49] format --- llmfoundry/callbacks/async_eval_callback.py | 4 ++-- tests/callbacks/__init__.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index d7f6af4f53..29d7ad59e4 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -111,8 +111,8 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: del logger if all([ state.get_elapsed_duration() is not None, - self.check_interval(state, event), self.last_launch - != state.timestamp.batch, + self.check_interval(state, event), + self.last_launch != state.timestamp.batch, dist.get_global_rank() == 0 ]): self.launch_run() diff --git a/tests/callbacks/__init__.py b/tests/callbacks/__init__.py index e69de29bb2..05d33100f5 100644 --- a/tests/callbacks/__init__.py +++ b/tests/callbacks/__init__.py @@ -0,0 +1,3 @@ +# Copyright 2022 
MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + From 5616ae46c76b334be29e3d6ff92701e309bcfe61 Mon Sep 17 00:00:00 2001 From: Anna Date: Fri, 10 Nov 2023 09:37:16 -0800 Subject: [PATCH 14/49] Update llmfoundry/callbacks/async_eval_callback.py Co-authored-by: Mihir Patel --- llmfoundry/callbacks/async_eval_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 29d7ad59e4..c04e1e0130 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -133,7 +133,7 @@ def _get_current_run(self) -> Run: 'RUN_NAME environment variable must be set to use the AsyncEval callback' ) - # allows the MapiException to be raised if the run doesn't exist + # Allows the MapiException to be raised if the run doesn't exist return get_run(run_name, include_details=True) def get_eval_parameters( From bc1647a2c3766824934722f99da57364098e8d35 Mon Sep 17 00:00:00 2001 From: Anna Date: Fri, 10 Nov 2023 09:37:33 -0800 Subject: [PATCH 15/49] Update llmfoundry/callbacks/async_eval_callback.py Co-authored-by: Mihir Patel --- llmfoundry/callbacks/async_eval_callback.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index c04e1e0130..781a1b212f 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -1,5 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 + import logging import os from typing import Any, Dict, List, Optional, Union From 3358837be94af9d38020c0963b49fa8a1a036f6f Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Fri, 10 Nov 2023 09:55:25 -0800 Subject: [PATCH 16/49] feedback --- llmfoundry/callbacks/async_eval_callback.py | 18 ++++++++++++------ setup.py | 2 +- tests/callbacks/__init__.py | 1 - 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 781a1b212f..bc5c770bb3 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -11,10 +11,8 @@ RUN_NAME_ENV_VAR) from composer.utils import dist from composer.utils.misc import create_interval_scheduler -from mcli.api.runs import \ - ComputeConfig # TODO: available in root in mcli 0.5.27+ -from mcli import Run, RunConfig, create_run, get_run +from mcli import ComputeConfig, Run, RunConfig, create_run, get_run log = logging.getLogger(__name__) @@ -42,7 +40,14 @@ def get_run_name(previous_run_name: str, count: int) -> str: *name_without_uuid_suffix, _ = previous_run_name.split('-') - name_suffix = ('-'.join(name_without_uuid_suffix))[:MAX_RUN_NAME_LENGTH] + name_suffix = ('-'.join(name_without_uuid_suffix)) + + if len(name_suffix) > MAX_RUN_NAME_LENGTH: + log.warning( + f'Training run name {name_suffix} may be too long, truncating to {MAX_RUN_NAME_LENGTH} characters' + ) + name_suffix = name_suffix[:MAX_RUN_NAME_LENGTH] + return f'eval{count}-{name_suffix}' @@ -190,7 +195,7 @@ def launch_run(self) -> Run: # deployment, which would require a hf conversion and parametrizing the # dependent_deployment in the run config command = 'cd llm-foundry/scripts \n composer eval/eval.py $PARAMETERS' - c = RunConfig( + run_config = RunConfig( name=get_run_name(self.current_run.name, self.count), image=self.current_run.image, compute=self.compute or 
default_compute, @@ -201,7 +206,8 @@ def launch_run(self) -> Run: parameters=params, ) - new_run = create_run(c, timeout=60) + # Increase default timeout of 10s just in case + new_run = create_run(run_config, timeout=20) log.info( f'Launched new run {new_run.name} inside eval loop with config: \n{new_run.submitted_config}' ) diff --git a/setup.py b/setup.py index 6f96bf4a5c..ba383a4d7f 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ 'einops==0.5.0', 'omegaconf>=2.2.3,<3', 'slack-sdk<4', - 'mosaicml-cli>=0.5.20,<1', + 'mosaicml-cli>=0.5.27,<1', 'onnx==1.14.0', 'onnxruntime==1.15.1', 'cmake>=3.25.0,<=3.26.3', # required for triton-pre-mlir below diff --git a/tests/callbacks/__init__.py b/tests/callbacks/__init__.py index 05d33100f5..f6c1f9f3ab 100644 --- a/tests/callbacks/__init__.py +++ b/tests/callbacks/__init__.py @@ -1,3 +1,2 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - From 28e47df403f9ef489931c1c478d2e7501b4fe197 Mon Sep 17 00:00:00 2001 From: Anna Date: Mon, 13 Nov 2023 09:39:50 -0800 Subject: [PATCH 17/49] Apply suggestions from code review Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/callbacks/async_eval_callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index bc5c770bb3..6a19c9ef50 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -55,7 +55,7 @@ def get_load_path(save_folder: str, save_latest_filename: Optional[str] = None) -> str: # TODO: check that the prefix is remote and not a local file (not supported of course) - if not save_latest_filename: + if save_latest_filename is None: rank = dist.get_global_rank() save_latest_filename = f'latest-rank{rank}.pt' @@ -90,7 +90,7 @@ class AsyncEval(Callback): :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`. compute: Optional[Union[ComputeConfig, Dict[str, Any]]]: The compute configuration to use for the eval run. If not provided, the same cluster as the current run and a - single GPU node will be used. + single, full GPU node will be used. """ def __init__( From 78cc0b8ecd9cde432ae2efd1c7d98a8dbef403ae Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 13 Nov 2023 09:59:44 -0800 Subject: [PATCH 18/49] small updates --- llmfoundry/callbacks/async_eval_callback.py | 28 ++++++++++++++------- tests/callbacks/test_async_eval_callback.py | 3 ++- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 6a19c9ef50..14c5a6aade 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -1,6 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +""" +Run the eval loop asynchronously as part of a MosaicML platform run. + +This callback is currently experimental. The API may change in the future. 
+""" + import logging import os from typing import Any, Dict, List, Optional, Union @@ -42,6 +48,7 @@ def get_run_name(previous_run_name: str, count: int) -> str: *name_without_uuid_suffix, _ = previous_run_name.split('-') name_suffix = ('-'.join(name_without_uuid_suffix)) + # A run name that is too long will fail a createRun call if len(name_suffix) > MAX_RUN_NAME_LENGTH: log.warning( f'Training run name {name_suffix} may be too long, truncating to {MAX_RUN_NAME_LENGTH} characters' @@ -68,7 +75,7 @@ def get_eval_models_dict( ) -> List[Dict[str, Any]]: name = model.get('name') - cfg_overrides = model.pop('cfg_overrides', {}) + cfg_overrides = model.pop('config_overrides', {}) for key in cfg_overrides: model[key] = cfg_overrides[key] @@ -83,6 +90,8 @@ def get_eval_models_dict( class AsyncEval(Callback): """Run the eval loop asynchronously as part of a MosaicML platform run. + This callback is currently experimental. The API may change in the future. + Args: interval: Union[str, int, Time]: The interval describing how often eval runs should be launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. @@ -117,8 +126,8 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: del logger if all([ state.get_elapsed_duration() is not None, - self.check_interval(state, event), - self.last_launch != state.timestamp.batch, + self.check_interval(state, event), self.last_launch + != state.timestamp.batch, dist.get_global_rank() == 0 ]): self.launch_run() @@ -168,12 +177,13 @@ def get_eval_parameters( subset_keys.pop('save_folder'), parameters.get('save_latest_filename', None)) - # Update the loggers to use the training run name - for logger, config in subset_keys.get('loggers', {}).items(): - if logger == 'wandb': - config['name'] = config.get('name', run_name) - elif logger == 'mlflow': - config['run_name'] = config.get('run_name', run_name) + # TODO: Update this and parametrize step when the composer loggers support + # it. 
For now, eval runs will be logged to separate experiment tracker runs + # for logger, config in subset_keys.get('loggers', {}).items(): + # if logger == 'wandb': + # config['name'] = config.get('name', run_name) + # elif logger == 'mlflow': + # config['run_name'] = config.get('run_name', run_name) # Create new eval models list subset_keys['models'] = get_eval_models_dict( diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index 65571f7b77..b595405b62 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -16,6 +16,7 @@ def test_get_run_name(): a = get_run_name('foo-1234', 0) assert a == 'eval0-foo' + # Run name should be truncated b = get_run_name(50 * 'foo' + '-1234', 1) assert b == 'eval1-foofoofoofoofoofoofoofoofoofoofoofoofoof' @@ -58,7 +59,7 @@ def test_fails_when_no_run_name(): 'max_seq_len': 3, 'model': { 'name': 'model_example', - 'cfg_overrides': { + 'config_overrides': { 'attn_config': { 'foo': 'bar' } From b58ccf9c33d6d1b3227762fb7a8c1552a350a56d Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 13 Nov 2023 10:41:47 -0800 Subject: [PATCH 19/49] use parameters from train.py to capture overrides and mounted parameters file --- llmfoundry/callbacks/async_eval_callback.py | 17 ++++---- llmfoundry/utils/builders.py | 8 +++- scripts/train/train.py | 2 +- tests/callbacks/test_async_eval_callback.py | 45 ++++++++++----------- tests/test_builders.py | 23 +++++++---- 5 files changed, 52 insertions(+), 43 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 14c5a6aade..c8581049d9 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -1,8 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -""" -Run the eval loop asynchronously as part of a MosaicML platform run. +"""Run the eval loop asynchronously as part of a MosaicML platform run. This callback is currently experimental. The API may change in the future. """ @@ -93,6 +92,7 @@ class AsyncEval(Callback): This callback is currently experimental. The API may change in the future. Args: + training_config: Dict[str, Any]: The config from the training run interval: Union[str, int, Time]: The interval describing how often eval runs should be launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, @@ -104,9 +104,11 @@ class AsyncEval(Callback): def __init__( self, + training_config: Dict[str, Any], interval: Union[str, int, Time], compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None, ): + self.training_config = training_config self.check_interval = create_interval_scheduler(interval) self.compute = compute self.count = 0 @@ -114,10 +116,7 @@ def __init__( # Run these during init to fail fast in any of the error cases self.current_run = self._get_current_run() - self.get_eval_parameters( - self.current_run.submitted_config.parameters or {}, - self.current_run.name, - ) + self.get_eval_parameters(training_config, self.current_run.name) log.info( f'Initialized AsyncEval callback. 
Will generate runs at interval {interval}' ) @@ -126,8 +125,8 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: del logger if all([ state.get_elapsed_duration() is not None, - self.check_interval(state, event), self.last_launch - != state.timestamp.batch, + self.check_interval(state, event), + self.last_launch != state.timestamp.batch, dist.get_global_rank() == 0 ]): self.launch_run() @@ -198,7 +197,7 @@ def launch_run(self) -> Run: 'gpus': 8, 'cluster': self.current_run.cluster, } - params = self.get_eval_parameters(cfg.parameters or {}, + params = self.get_eval_parameters(self.training_config, self.current_run.name) # TODO: This just runs an eval run, but we also want to attach the diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 9c3e94c29b..fd446df18c 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -73,7 +73,11 @@ def build_icl_data_and_gauntlet( return icl_evaluators, logger_keys, eval_gauntlet_cb -def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: +def build_callback( + name: str, + kwargs: Dict[str, Any], + config: Dict[str, Any], +) -> Callback: if name == 'lr_monitor': return LRMonitor() elif name == 'memory_monitor': @@ -119,7 +123,7 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: elif name == 'hf_checkpointer': return HuggingFaceCheckpointer(**kwargs) elif name == 'async_eval': - return AsyncEval(**kwargs) + return AsyncEval(**kwargs, training_config=config) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/scripts/train/train.py b/scripts/train/train.py index 925470c4e4..64b6d7ba1c 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -504,7 +504,7 @@ def main(cfg: DictConfig) -> Trainer: # Callbacks callbacks: List[Callback] = [ - build_callback(str(name), callback_cfg) + build_callback(str(name), callback_cfg, cfg) for name, callback_cfg in callback_configs.items() ] if callback_configs else [] diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index b595405b62..acdb5f5375 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -10,6 +10,23 @@ from mcli import Run, RunConfig, RunStatus RUN_NAME = 'foo_bar-1234' +BASIC_PARAMS = { + 'device_eval_batch_size': 2, + 'icl_tasks': 'icl_task_example', + 'max_seq_len': 3, + 'model': { + 'name': 'model_example', + 'config_overrides': { + 'attn_config': { + 'foo': 'bar' + } + } + }, + 'tokenizer': { + 'tokenizer_example': 'tokenizer_example', + }, + 'save_folder': 'save_folder_example', +} def test_get_run_name(): @@ -37,7 +54,7 @@ def test_fails_when_not_on_platform(): match= 'AsyncEval callback is only supported when running on the MosaicML platform' ): - AsyncEval(interval='2ba') + AsyncEval(BASIC_PARAMS, interval='2ba') def test_fails_when_no_run_name(): @@ -50,26 +67,7 @@ def test_fails_when_no_run_name(): match= 'RUN_NAME environment variable must be set to use the AsyncEval callback' ): - AsyncEval(interval='2ba') - - -BASIC_PARAMS = { - 'device_eval_batch_size': 2, - 'icl_tasks': 'icl_task_example', - 'max_seq_len': 3, - 'model': { - 'name': 'model_example', - 'config_overrides': { - 'attn_config': { - 'foo': 'bar' - } - } - }, - 'tokenizer': { - 'tokenizer_example': 'tokenizer_example', - }, - 'save_folder': 'save_folder_example', -} + AsyncEval(BASIC_PARAMS, interval='2ba') def test_get_eval_parameters(): @@ -185,7 +183,7 @@ def test_get_eval_parameters(): 
name=RUN_NAME, image='fake-image', command='echo hi', - parameters=BASIC_PARAMS, + parameters={}, ), ) @@ -196,7 +194,8 @@ def test_get_eval_parameters(): return_value=FAKE_RUN) def test_async_eval_callback_minimal(mock_create_run: MagicMock, mock_get_run: MagicMock): - callback = AsyncEval(interval='2ba', + callback = AsyncEval(BASIC_PARAMS, + interval='2ba', compute={ 'cluster': 'c2z3', 'nodes': 2, diff --git a/tests/test_builders.py b/tests/test_builders.py index 0d24d2154f..3fd638e9f1 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -37,7 +37,7 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict): def test_build_callback_fails(): with pytest.raises(ValueError): - build_callback('nonexistent_callback', {}) + build_callback('nonexistent_callback', {}, {}) @pytest.mark.parametrize( @@ -53,12 +53,15 @@ def test_build_generate_callback( autospec=True) as mock_generate: mock_generate.return_value = None build_callback( - 'generate_callback', { + 'generate_callback', + { 'prompts': ['hello'], interval_key: interval_value, 'foo': 'bar', 'something': 'else', - }) + }, + {}, + ) assert mock_generate.call_count == 1 _, _, kwargs = mock_generate.mock_calls[0] @@ -73,8 +76,12 @@ def test_build_generate_callback_unspecified_interval(): with mock.patch.object(Generate, '__init__', autospec=True) as mock_generate: mock_generate.return_value = None - build_callback('generate_callback', { - 'prompts': ['hello'], - 'foo': 'bar', - 'something': 'else', - }) + build_callback( + 'generate_callback', + { + 'prompts': ['hello'], + 'foo': 'bar', + 'something': 'else', + }, + {}, + ) From d85ee5ed6984e57953e9adcf60c46f80c702be53 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Tue, 14 Nov 2023 11:49:54 -0800 Subject: [PATCH 20/49] config_overrides --- llmfoundry/callbacks/async_eval_callback.py | 9 ++------- tests/callbacks/test_async_eval_callback.py | 20 +++++++++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index c8581049d9..8a9bf7b4cc 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -73,11 +73,6 @@ def get_eval_models_dict( tokenizer: Dict[str, Any], ) -> List[Dict[str, Any]]: name = model.get('name') - - cfg_overrides = model.pop('config_overrides', {}) - for key in cfg_overrides: - model[key] = cfg_overrides[key] - new_model = {'model_name': name, 'model': model} if tokenizer: @@ -125,8 +120,8 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: del logger if all([ state.get_elapsed_duration() is not None, - self.check_interval(state, event), - self.last_launch != state.timestamp.batch, + self.check_interval(state, event), self.last_launch + != state.timestamp.batch, dist.get_global_rank() == 0 ]): self.launch_run() diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index acdb5f5375..d5169d2327 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -96,8 +96,10 @@ def test_get_eval_parameters(): 'model_name': 'model_example', 'model': { 'name': 'model_example', - 'attn_config': { - 'foo': 'bar' + 'config_overrides': { + 'attn_config': { + 'foo': 'bar' + }, }, }, 'tokenizer': { @@ -139,8 +141,10 @@ def test_get_eval_parameters(): 'model_name': 'model_example', 'model': { 'name': 'model_example', - 'attn_config': { - 'foo': 'bar' + 'config_overrides': { + 
'attn_config': { + 'foo': 'bar' + }, }, }, 'tokenizer': { @@ -226,9 +230,11 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, 'model_name': 'model_example', 'model': { 'name': 'model_example', - 'attn_config': { - 'foo': 'bar' - } + 'config_overrides': { + 'attn_config': { + 'foo': 'bar' + }, + }, }, 'tokenizer': { 'tokenizer_example': 'tokenizer_example' From 194774d8ac4cb5e710a74ac00dd9fedee86f97db Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 15 Nov 2023 11:17:57 -0800 Subject: [PATCH 21/49] updates --- llmfoundry/callbacks/async_eval_callback.py | 208 ++++++++++++-------- tests/callbacks/test_async_eval_callback.py | 43 ++-- 2 files changed, 150 insertions(+), 101 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 8a9bf7b4cc..54f9ca2b72 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -8,9 +8,10 @@ import logging import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union -from composer.core import Callback, Event, State, Time +from composer.callbacks import CheckpointSaver +from composer.core import Callback, Event, State, Time, TimeUnit from composer.loggers import Logger from composer.loggers.mosaicml_logger import (MOSAICML_PLATFORM_ENV_VAR, RUN_NAME_ENV_VAR) @@ -29,7 +30,6 @@ 'max_seq_len', 'model', # converted into models 'tokenizer', # converted into models - 'save_folder', # required, but used as load_path } OPTIONAL_PARAMS_FOR_EVAL = { 'dist_timeout', @@ -42,9 +42,20 @@ 'seed', } +NAME_PREFIX = 'eval' -def get_run_name(previous_run_name: str, count: int) -> str: - *name_without_uuid_suffix, _ = previous_run_name.split('-') + +def get_run_name(training_run_name: str, current_interval: str) -> str: + """Get the new eval run name. + + Args: + training_run_name: The name of the current training run + current_interval: The current interval string of the training run + + Returns: + The new run name + """ + *name_without_uuid_suffix, _ = training_run_name.split('-') name_suffix = ('-'.join(name_without_uuid_suffix)) # A run name that is too long will fail a createRun call @@ -54,31 +65,85 @@ def get_run_name(previous_run_name: str, count: int) -> str: ) name_suffix = name_suffix[:MAX_RUN_NAME_LENGTH] - return f'eval{count}-{name_suffix}' + return '-'.join([NAME_PREFIX, current_interval, name_suffix]) -def get_load_path(save_folder: str, - save_latest_filename: Optional[str] = None) -> str: - # TODO: check that the prefix is remote and not a local file (not supported of course) +def get_latest_checkpoint(event: Event, state: State) -> Optional[str]: + """Get the latest checkpoint from the training run. 
- if save_latest_filename is None: - rank = dist.get_global_rank() - save_latest_filename = f'latest-rank{rank}.pt' + Args: + state: The current state of the training run - return f'{save_folder}/{save_latest_filename}' + Returns: + The path to the latest checkpoint, or None if there is not a latest checkpoint + """ + checkpointer = None + for callback in state.callbacks: + if isinstance(callback, CheckpointSaver): + checkpointer = callback + break + if not checkpointer: + return None -def get_eval_models_dict( - model: Dict[str, Any], - tokenizer: Dict[str, Any], -) -> List[Dict[str, Any]]: - name = model.get('name') - new_model = {'model_name': name, 'model': model} + if event.name == Event.FIT_END: + # Use the latest symlink for the end of training + return checkpointer.latest_filename - if tokenizer: - new_model['tokenizer'] = tokenizer + if not checkpointer.saved_checkpoints: + return None + + return checkpointer.saved_checkpoints[-1] + + +def get_eval_parameters( + parameters: Dict[str, Any], + checkpoint: str, + training_run_name: str, +) -> Dict[str, Any]: + """Get the parameters needed for the eval run. + + Args: + parameters: The parameters from the training run + checkpoint: The path to the latest checkpoint + training_run_name: The name of the training run + + Returns: + The parameters needed for the eval run as a dict + """ + looking_for = REQUIRED_PARAMS_FOR_EVAL.copy() + + # Go through all parameters and pull out the ones needed for eval + subset_keys = {} + for key in parameters: + if key in OPTIONAL_PARAMS_FOR_EVAL: + subset_keys[key] = parameters[key] + elif key in REQUIRED_PARAMS_FOR_EVAL: + subset_keys[key] = parameters[key] + looking_for.remove(key) + + if looking_for: + raise Exception( + f'Missing the following required parameters for async eval: {looking_for}' + ) + + # Convert the save_folder to a load_path + subset_keys['load_path'] = checkpoint + + for logger, config in subset_keys.get('loggers', {}).items(): + if logger == 'wandb': + config['group'] = config.get('name', training_run_name) + del config['name'] + + # Create new eval models list + model = subset_keys.pop('model') + new_models = {'model_name': model.get('name'), 'model': model} - return [new_model] + tokenizer = subset_keys.pop('tokenizer', None) + if tokenizer: + new_models['tokenizer'] = tokenizer + subset_keys['models'] = [new_models] + return subset_keys class AsyncEval(Callback): @@ -103,31 +168,54 @@ def __init__( interval: Union[str, int, Time], compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None, ): + self.training_config = training_config + + if isinstance(interval, str): + self.interval = Time.from_timestring(interval) + elif isinstance(interval, int): + self.interval = Time(interval, TimeUnit.EPOCH) + else: + self.interval = interval + self.check_interval = create_interval_scheduler(interval) self.compute = compute - self.count = 0 self.last_launch: Optional[Time] = None + self.last_checkpoint: Optional[str] = None # Run these during init to fail fast in any of the error cases self.current_run = self._get_current_run() - self.get_eval_parameters(training_config, self.current_run.name) + get_eval_parameters( + parameters=training_config, + checkpoint='test', + training_run_name=self.current_run.name, + ) log.info( f'Initialized AsyncEval callback. 
Will generate runs at interval {interval}' ) def run_event(self, event: Event, state: State, logger: Logger) -> None: del logger - if all([ - state.get_elapsed_duration() is not None, - self.check_interval(state, event), self.last_launch - != state.timestamp.batch, - dist.get_global_rank() == 0 - ]): - self.launch_run() + + should_launch_run = all([ + state.get_elapsed_duration() is not None, + self.check_interval(state, event), + self.last_launch != state.timestamp.batch, + dist.get_global_rank() == 0, + ]) + + if should_launch_run: + current_interval = f'{state.timestamp.get(self.interval.unit)}{self.interval.unit.value}' + + checkpoint = get_latest_checkpoint(event, state) + if not checkpoint or checkpoint == self.last_checkpoint: + # Do not eval a checkpoint that has already been evaluated. + return + + self.launch_run(checkpoint, current_interval) self.last_launch = state.timestamp.batch - self.count += 1 + self.last_checkpoint = checkpoint def _get_current_run(self) -> Run: if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, @@ -145,62 +233,28 @@ def _get_current_run(self) -> Run: # Allows the MapiException to be raised if the run doesn't exist return get_run(run_name, include_details=True) - def get_eval_parameters( - self, - parameters: Dict[str, Any], - run_name: str, - ) -> Dict[str, Any]: - looking_for = REQUIRED_PARAMS_FOR_EVAL.copy() - - # Go through all parameters and pull out the ones needed for eval - subset_keys = {} - for key in parameters: - if key in OPTIONAL_PARAMS_FOR_EVAL: - subset_keys[key] = parameters[key] - elif key in REQUIRED_PARAMS_FOR_EVAL: - subset_keys[key] = parameters[key] - looking_for.remove(key) - - if looking_for: - raise Exception( - f'Missing the following required parameters for async eval: {looking_for}' - ) - - # Convert the save_folder to a load_path - subset_keys['load_path'] = get_load_path( - subset_keys.pop('save_folder'), - parameters.get('save_latest_filename', None)) - - # TODO: Update this and parametrize step when the composer loggers support - # it. 
For now, eval runs will be logged to separate experiment tracker runs - # for logger, config in subset_keys.get('loggers', {}).items(): - # if logger == 'wandb': - # config['name'] = config.get('name', run_name) - # elif logger == 'mlflow': - # config['run_name'] = config.get('run_name', run_name) - - # Create new eval models list - subset_keys['models'] = get_eval_models_dict( - subset_keys.pop('model'), subset_keys.pop('tokenizer')) - - subset_keys['run_name'] = get_run_name(run_name, 0) - return subset_keys - - def launch_run(self) -> Run: + def launch_run(self, checkpoint: str, current_interval: str) -> Run: cfg = self.current_run.submitted_config default_compute = { 'gpus': 8, 'cluster': self.current_run.cluster, } - params = self.get_eval_parameters(self.training_config, - self.current_run.name) + + run_name = get_run_name(self.current_run.name, current_interval) + + params = get_eval_parameters( + parameters=self.training_config, + checkpoint=checkpoint, + training_run_name=self.current_run.name, + ) + params['run_name'] = run_name # TODO: This just runs an eval run, but we also want to attach the # deployment, which would require a hf conversion and parametrizing the # dependent_deployment in the run config command = 'cd llm-foundry/scripts \n composer eval/eval.py $PARAMETERS' run_config = RunConfig( - name=get_run_name(self.current_run.name, self.count), + name=run_name, image=self.current_run.image, compute=self.compute or default_compute, command=command, diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index d5169d2327..8694130dec 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -6,7 +6,9 @@ import pytest -from llmfoundry.callbacks.async_eval_callback import AsyncEval, get_run_name +from llmfoundry.callbacks.async_eval_callback import (AsyncEval, + get_eval_parameters, + get_run_name) from mcli import Run, RunConfig, RunStatus RUN_NAME = 'foo_bar-1234' @@ -25,17 +27,16 @@ 'tokenizer': { 'tokenizer_example': 'tokenizer_example', }, - 'save_folder': 'save_folder_example', } def test_get_run_name(): - a = get_run_name('foo-1234', 0) - assert a == 'eval0-foo' + a = get_run_name('foo-1234', '1ba') + assert a == 'eval-1ba-foo' # Run name should be truncated - b = get_run_name(50 * 'foo' + '-1234', 1) - assert b == 'eval1-foofoofoofoofoofoofoofoofoofoofoofoofoof' + b = get_run_name(50 * 'foo' + '-1234', '1ba') + assert b == 'eval-1ba-foofoofoofoofoofoofoofoofoofoofoofoofoof' @pytest.fixture(autouse=True, scope='module') @@ -74,13 +75,10 @@ def test_get_eval_parameters(): with pytest.raises( Exception, match='Missing the following required parameters for async eval:'): - AsyncEval.get_eval_parameters(None, {}, RUN_NAME) # type: ignore + get_eval_parameters({}, 'checkpoints/file', RUN_NAME) # minimal example - params = AsyncEval.get_eval_parameters( - None, # type: ignore - BASIC_PARAMS, - RUN_NAME) + params = get_eval_parameters(BASIC_PARAMS, 'checkpoints/file', RUN_NAME) assert params == { 'device_eval_batch_size': 2, @@ -89,9 +87,7 @@ def test_get_eval_parameters(): 'max_seq_len': 3, 'load_path': - 'save_folder_example/latest-rank0.pt', - 'run_name': - 'eval0-foo_bar', + 'checkpoints/file', 'models': [{ 'model_name': 'model_example', 'model': { @@ -109,8 +105,7 @@ def test_get_eval_parameters(): } # maximal example - params2 = AsyncEval.get_eval_parameters( - None, # type: ignore + params2 = get_eval_parameters( { # required **BASIC_PARAMS, @@ -130,12 +125,13 @@ def 
test_get_eval_parameters(): # ignore this 'ignore_this': 'ignore_this', }, - RUN_NAME) + 'checkpoints/file', + RUN_NAME, + ) assert params2 == { 'device_eval_batch_size': 2, 'icl_tasks': 'icl_task_example', 'max_seq_len': 3, - 'run_name': 'eval0-foo_bar', 'dist_timeout': 1, 'models': [{ 'model_name': 'model_example', @@ -162,7 +158,7 @@ def test_get_eval_parameters(): 'precision': 'precision_example', 'python_log_level': 'debug', 'seed': 5, - 'load_path': 'save_folder_example/latest-rank0.pt' + 'load_path': 'checkpoints/file', } @@ -208,12 +204,11 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, assert mock_get_run.call_count == 1 assert mock_get_run.call_args[0][0] == RUN_NAME - callback.count += 2 - callback.launch_run() + callback.launch_run('checkpoint/path', '1ba') assert mock_create_run.call_count == 1 run_config_created = mock_create_run.call_args[0][0] - assert run_config_created.name == 'eval2-foo_bar' + assert run_config_created.name == 'eval-1ba-foo_bar' assert run_config_created.image == 'fake-image' assert run_config_created.command @@ -225,7 +220,7 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, assert parameters['device_eval_batch_size'] == 2 assert parameters['icl_tasks'] == 'icl_task_example' assert parameters['max_seq_len'] == 3 - assert parameters['load_path'] == 'save_folder_example/latest-rank0.pt' + assert parameters['load_path'] == 'checkpoint/path' assert parameters['models'] == [{ 'model_name': 'model_example', 'model': { @@ -240,4 +235,4 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, 'tokenizer_example': 'tokenizer_example' } }] - assert parameters['run_name'] == 'eval0-foo_bar' # original run + assert parameters['run_name'] == 'eval-1ba-foo_bar' # original run From 08857d541bc0892a40410f5c3c2338cf7dd12a02 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 15 Nov 2023 14:19:25 -0800 Subject: [PATCH 22/49] fix test --- tests/test_builders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_builders.py b/tests/test_builders.py index 893a384f7f..be593edabd 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -109,7 +109,7 @@ def test_build_hf_checkpointer_callback(): 'save_interval': save_interval, 'mlflow_logging_config': mlflow_logging_config_dict }), - {}) + config={}) assert mock_hf_checkpointer.call_count == 1 _, _, kwargs = mock_hf_checkpointer.mock_calls[0] From 6ce8b7767fc5dfd0a7014adb7ace81789ec4cf40 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 15 Nov 2023 14:31:38 -0800 Subject: [PATCH 23/49] small fixes --- llmfoundry/callbacks/async_eval_callback.py | 8 +++++--- scripts/train/train.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 54f9ca2b72..380fe14905 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -124,7 +124,7 @@ def get_eval_parameters( if looking_for: raise Exception( - f'Missing the following required parameters for async eval: {looking_for}' + f'Missing the following required parameters for async eval: {looking_for}\n{parameters}' ) # Convert the save_folder to a load_path @@ -132,8 +132,10 @@ def get_eval_parameters( for logger, config in subset_keys.get('loggers', {}).items(): if logger == 'wandb': - config['group'] = config.get('name', training_run_name) - del config['name'] + config['group'] = config.pop('name', training_run_name) + + # mlflow currently 
does not support grouping, so this will just launch + # a new mlflow run # Create new eval models list model = subset_keys.pop('model') diff --git a/scripts/train/train.py b/scripts/train/train.py index 64b6d7ba1c..6b33663c91 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -504,7 +504,7 @@ def main(cfg: DictConfig) -> Trainer: # Callbacks callbacks: List[Callback] = [ - build_callback(str(name), callback_cfg, cfg) + build_callback(str(name), callback_cfg, logged_cfg) for name, callback_cfg in callback_configs.items() ] if callback_configs else [] From de155f7ea24e7385b03d3b234bf35dc9d05b9f12 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 15 Nov 2023 15:03:44 -0800 Subject: [PATCH 24/49] add logging --- llmfoundry/callbacks/async_eval_callback.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 380fe14905..73280d9c00 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -84,6 +84,7 @@ def get_latest_checkpoint(event: Event, state: State) -> Optional[str]: break if not checkpointer: + log.warning('No checkpoint saver callback found') return None if event.name == Event.FIT_END: @@ -91,6 +92,7 @@ def get_latest_checkpoint(event: Event, state: State) -> Optional[str]: return checkpointer.latest_filename if not checkpointer.saved_checkpoints: + log.warning('No saved checkpoints found on the checkpointer') return None return checkpointer.saved_checkpoints[-1] From e5f9e9e29a1d06a289dea2e3adefec9cf62c294b Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 15 Nov 2023 15:27:02 -0800 Subject: [PATCH 25/49] remove last launch check --- llmfoundry/callbacks/async_eval_callback.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 73280d9c00..8031879e31 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -184,7 +184,6 @@ def __init__( self.check_interval = create_interval_scheduler(interval) self.compute = compute - self.last_launch: Optional[Time] = None self.last_checkpoint: Optional[str] = None # Run these during init to fail fast in any of the error cases @@ -204,7 +203,6 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: should_launch_run = all([ state.get_elapsed_duration() is not None, self.check_interval(state, event), - self.last_launch != state.timestamp.batch, dist.get_global_rank() == 0, ]) @@ -217,8 +215,6 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: return self.launch_run(checkpoint, current_interval) - - self.last_launch = state.timestamp.batch self.last_checkpoint = checkpoint def _get_current_run(self) -> Run: From 3f518f9567127c640f0108a4dc8b8fd5cb8f1107 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 15 Nov 2023 15:56:03 -0800 Subject: [PATCH 26/49] better logging --- llmfoundry/callbacks/async_eval_callback.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 8031879e31..7f4b88b660 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -210,8 +210,13 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: current_interval = 
f'{state.timestamp.get(self.interval.unit)}{self.interval.unit.value}' checkpoint = get_latest_checkpoint(event, state) - if not checkpoint or checkpoint == self.last_checkpoint: + if not checkpoint: + return # warnings logged in get_latest_checkpoint + if checkpoint == self.last_checkpoint: # Do not eval a checkpoint that has already been evaluated. + log.info( + 'Skipping async eval because the checkpoint has not changed' + ) return self.launch_run(checkpoint, current_interval) @@ -264,9 +269,7 @@ def launch_run(self, checkpoint: str, current_interval: str) -> Run: parameters=params, ) - # Increase default timeout of 10s just in case - new_run = create_run(run_config, timeout=20) - log.info( - f'Launched new run {new_run.name} inside eval loop with config: \n{new_run.submitted_config}' - ) + log.info(f'Creating new run with config: \n{run_config}') + new_run = create_run(run_config) + log.info(f'Launched new run {new_run.name} inside eval loop') return new_run From ea742efd4f72929a94aacd9f8841b09ed02c7989 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 15 Nov 2023 17:07:00 -0800 Subject: [PATCH 27/49] fix parameters --- scripts/train/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index 6b33663c91..0598d8a1aa 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -504,7 +504,7 @@ def main(cfg: DictConfig) -> Trainer: # Callbacks callbacks: List[Callback] = [ - build_callback(str(name), callback_cfg, logged_cfg) + build_callback(str(name), callback_cfg, om.to_container(logged_cfg)) for name, callback_cfg in callback_configs.items() ] if callback_configs else [] From e3623f3d24c631b6f616afab671bc84fac1f5150 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 15 Nov 2023 18:04:14 -0800 Subject: [PATCH 28/49] fix double unit in the name --- llmfoundry/callbacks/async_eval_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 7f4b88b660..8c9af549f8 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -207,7 +207,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: ]) if should_launch_run: - current_interval = f'{state.timestamp.get(self.interval.unit)}{self.interval.unit.value}' + current_interval = f'{state.timestamp.get(self.interval.unit)}' checkpoint = get_latest_checkpoint(event, state) if not checkpoint: From 2deef3fa43040931bc87a876977c05aaf602a97d Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 15 Nov 2023 18:04:37 -0800 Subject: [PATCH 29/49] sadz --- llmfoundry/callbacks/async_eval_callback.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 8c9af549f8..c3d8202afb 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -207,8 +207,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: ]) if should_launch_run: - current_interval = f'{state.timestamp.get(self.interval.unit)}' - + current_interval = state.timestamp.get(self.interval.unit) checkpoint = get_latest_checkpoint(event, state) if not checkpoint: return # warnings logged in get_latest_checkpoint From add7fbbee444f8434f5069a9defd00fcbb56b017 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Sat, 2 Dec 2023 01:06:11 +0000 Subject: [PATCH 30/49] 
fies --- llmfoundry/callbacks/async_eval_callback.py | 38 +++++++++++++-------- llmfoundry/utils/builders.py | 5 ++- scripts/train/train.py | 27 ++++++++------- tests/callbacks/test_async_eval_callback.py | 2 +- 4 files changed, 43 insertions(+), 29 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index c3d8202afb..b77767052c 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -22,8 +22,6 @@ log = logging.getLogger(__name__) -MAX_RUN_NAME_LENGTH = 40 - REQUIRED_PARAMS_FOR_EVAL = { 'device_eval_batch_size', 'icl_tasks', # only required for eval, may not be specified in pure training @@ -34,6 +32,7 @@ OPTIONAL_PARAMS_FOR_EVAL = { 'dist_timeout', 'eval_gauntlet', + 'eval_loader', 'fsdp_config', 'icl_subset_num_batches', 'loggers', @@ -42,7 +41,8 @@ 'seed', } -NAME_PREFIX = 'eval' +RUN_NAME_PREFIX = 'eval' +MAX_RUN_NAME_BASE_LENGTH = 55 def get_run_name(training_run_name: str, current_interval: str) -> str: @@ -55,23 +55,27 @@ def get_run_name(training_run_name: str, current_interval: str) -> str: Returns: The new run name """ - *name_without_uuid_suffix, _ = training_run_name.split('-') - name_suffix = ('-'.join(name_without_uuid_suffix)) + name_without_uuid_suffix = training_run_name.rsplit('-', 1)[0] + + max_length = MAX_RUN_NAME_BASE_LENGTH - len(RUN_NAME_PREFIX) - len( + current_interval) - 2 # A run name that is too long will fail a createRun call - if len(name_suffix) > MAX_RUN_NAME_LENGTH: + if len(name_without_uuid_suffix) > max_length: + new_name = name_without_uuid_suffix[:max_length] log.warning( - f'Training run name {name_suffix} may be too long, truncating to {MAX_RUN_NAME_LENGTH} characters' - ) - name_suffix = name_suffix[:MAX_RUN_NAME_LENGTH] + f'Training run name {name_without_uuid_suffix} may be too long,' + + f' truncating to {new_name}') + name_without_uuid_suffix = new_name - return '-'.join([NAME_PREFIX, current_interval, name_suffix]) + return f'{RUN_NAME_PREFIX}-{current_interval}-{name_without_uuid_suffix}' def get_latest_checkpoint(event: Event, state: State) -> Optional[str]: """Get the latest checkpoint from the training run. 
Args: + event: The current run event state: The current state of the training run Returns: @@ -89,7 +93,7 @@ def get_latest_checkpoint(event: Event, state: State) -> Optional[str]: if event.name == Event.FIT_END: # Use the latest symlink for the end of training - return checkpointer.latest_filename + return str(checkpointer.latest_filename) if not checkpointer.saved_checkpoints: log.warning('No saved checkpoints found on the checkpointer') @@ -126,7 +130,7 @@ def get_eval_parameters( if looking_for: raise Exception( - f'Missing the following required parameters for async eval: {looking_for}\n{parameters}' + f'Missing the following required parameters for async eval: {looking_for}' ) # Convert the save_folder to a load_path @@ -141,10 +145,14 @@ def get_eval_parameters( # Create new eval models list model = subset_keys.pop('model') - new_models = {'model_name': model.get('name'), 'model': model} + + model_name = model.get('name', None) + if not model_name: + raise Exception(f'Async evaluation requires "name" keys for models') + new_models = {'model_name': model_name, 'model': model} tokenizer = subset_keys.pop('tokenizer', None) - if tokenizer: + if tokenizer is not None: new_models['tokenizer'] = tokenizer subset_keys['models'] = [new_models] return subset_keys @@ -207,7 +215,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: ]) if should_launch_run: - current_interval = state.timestamp.get(self.interval.unit) + current_interval = str(state.timestamp.get(self.interval.unit)) checkpoint = get_latest_checkpoint(event, state) if not checkpoint: return # warnings logged in get_latest_checkpoint diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index da4f46d568..23533238e8 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -160,7 +160,7 @@ def build_icl_data_and_gauntlet( def build_callback( name: str, kwargs: Union[DictConfig, Dict[str, Any]], - config: Dict[str, Any], + config: Any = None, ) -> Callback: if name == 'lr_monitor': return LRMonitor() @@ -209,6 +209,9 @@ def build_callback( kwargs = om.to_object(kwargs) # pyright: ignore return HuggingFaceCheckpointer(**kwargs) elif name == 'async_eval': + if not config: + raise ValueError( + 'Parameters config is required for async eval callback') return AsyncEval(**kwargs, training_config=config) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/scripts/train/train.py b/scripts/train/train.py index eefc29f2f8..73b8c236c1 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -529,18 +529,21 @@ def main(cfg: DictConfig) -> Trainer: mosaicml_logger.log_metrics({'data_validated': time.time()}) ## Evaluation - print('Building eval loader...') - - eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len - evaluators, _, eval_gauntlet_callback = build_evaluators( - eval_loader_config, - icl_tasks_config, - eval_gauntlet_config, - tokenizer=tokenizer, - device_eval_batch_size=device_eval_batch_size, - icl_seq_len=eval_icl_seq_len, - icl_subset_num_batches=icl_subset_num_batches, - ) + if use_async_eval: + print('Using async eval, skipping eval loader') + evaluators, eval_gauntlet_callback = [], None + else: + print('Building eval loader...') + eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len + evaluators, _, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks_config, + eval_gauntlet_config, + tokenizer=tokenizer, + device_eval_batch_size=device_eval_batch_size, + 
icl_seq_len=eval_icl_seq_len, + icl_subset_num_batches=icl_subset_num_batches, + ) if eval_gauntlet_callback is not None and not use_async_eval: callbacks.append(eval_gauntlet_callback) diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index 8694130dec..c3de124b19 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -36,7 +36,7 @@ def test_get_run_name(): # Run name should be truncated b = get_run_name(50 * 'foo' + '-1234', '1ba') - assert b == 'eval-1ba-foofoofoofoofoofoofoofoofoofoofoofoofoof' + assert b == 'eval-1ba-foofoofoofoofoofoofoofoofoofoofoofoofoofoofoof' @pytest.fixture(autouse=True, scope='module') From 238086f6a94f500540cb9ebe45d1446c8bd88377 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 4 Dec 2023 18:14:44 +0000 Subject: [PATCH 31/49] git integration path validation and update --- llmfoundry/callbacks/async_eval_callback.py | 27 +++++++++- tests/callbacks/test_async_eval_callback.py | 59 ++++++++++++++++++++- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index b77767052c..5726a5f5dc 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -261,16 +261,39 @@ def launch_run(self, checkpoint: str, current_interval: str) -> Run: ) params['run_name'] = run_name + integrations = cfg.integrations + found_llm_foundry, installation_path = False, 'llm-foundry' + for i in integrations: + if i['integration_type'] != 'git_repo' or i[ + 'git_repo'] != 'mosaicml/llm-foundry': + continue + + found_llm_foundry = True + if i['path']: + installation_path = i['path'] + + if not found_llm_foundry: + log.warning( + 'No github integration found for llm-foundry. 
Adding installation for latest' + ) + integrations.append({ + 'integration_type': 'git_repo', + 'git_repo': 'mosaicml/llm-foundry', + 'git_branch': 'v0.4.0', + 'pip_install': '-e .[gpu]', + 'ssh_clone': False, + }) + # TODO: This just runs an eval run, but we also want to attach the # deployment, which would require a hf conversion and parametrizing the # dependent_deployment in the run config - command = 'cd llm-foundry/scripts \n composer eval/eval.py $PARAMETERS' + command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS' run_config = RunConfig( name=run_name, image=self.current_run.image, compute=self.compute or default_compute, command=command, - integrations=cfg.integrations, + integrations=integrations, env_variables=cfg.env_variables, metadata=cfg.metadata, parameters=params, diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index c3de124b19..79711522bb 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import datetime +from copy import deepcopy from unittest.mock import MagicMock, patch import pytest @@ -11,6 +12,7 @@ get_run_name) from mcli import Run, RunConfig, RunStatus +# here RUN_NAME = 'foo_bar-1234' BASIC_PARAMS = { 'device_eval_batch_size': 2, @@ -210,7 +212,15 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, run_config_created = mock_create_run.call_args[0][0] assert run_config_created.name == 'eval-1ba-foo_bar' assert run_config_created.image == 'fake-image' - assert run_config_created.command + + print(run_config_created) + assert 'cd llm-foundry/scripts' in run_config_created.command + + integrations = run_config_created.integrations + assert len(integrations) == 1 + assert integrations[0]['integration_type'] == 'git_repo' + assert integrations[0]['git_repo'] == 'mosaicml/llm-foundry' + assert 'git_branch' in integrations[0] compute = run_config_created.compute assert compute['cluster'] == 'c2z3' @@ -236,3 +246,50 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, } }] assert parameters['run_name'] == 'eval-1ba-foo_bar' # original run + + +INTEGRATION_GIT_LLMFOUNDRY = { + 'integration_type': 'git_repo', + 'git_repo': 'mosaicml/llm-foundry', + 'git_branch': 'custom_branch', + 'path': 'custom/llm-foundry', + 'pip_install': '-e .[gpu]', + 'ssh_clone': False, +} +INTEGRATION_GIT_RANDOM = { + 'integration_type': 'git_repo', + 'git_repo': 'another-repo', + 'git_branch': 'foobar', +} + +FAKE_RUN_WITH_INTEGRATIONS = deepcopy(FAKE_RUN) +FAKE_RUN_WITH_INTEGRATIONS.submitted_config.integrations = [ + INTEGRATION_GIT_LLMFOUNDRY, INTEGRATION_GIT_RANDOM +] + + +@patch('llmfoundry.callbacks.async_eval_callback.get_run', + return_value=FAKE_RUN_WITH_INTEGRATIONS) +@patch('llmfoundry.callbacks.async_eval_callback.create_run', + return_value=FAKE_RUN_WITH_INTEGRATIONS) +def test_async_eval_callback_integrations(mock_create_run: MagicMock, + mock_get_run: MagicMock): + callback = AsyncEval(BASIC_PARAMS, + interval='2ba', + compute={ + 'cluster': 'c2z3', + 'nodes': 2, + }) + assert mock_get_run.call_count == 1 + + callback.launch_run('checkpoint/path', '1ba') + assert mock_create_run.call_count == 1 + run_config_created = mock_create_run.call_args[0][0] + + assert len(run_config_created.integrations) == 2 + # order should be retained + assert run_config_created.integrations[0] == INTEGRATION_GIT_LLMFOUNDRY + assert run_config_created.integrations[1] == 
INTEGRATION_GIT_RANDOM + + custom_path = run_config_created.integrations[0]['path'] + assert f'cd {custom_path}/scripts' in run_config_created.command From 53a99438c85f7643b74a9369d61fd7baef035098 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 4 Dec 2023 18:43:06 +0000 Subject: [PATCH 32/49] detect forks, better error/comment --- llmfoundry/callbacks/async_eval_callback.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 5726a5f5dc..ca791c4cab 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -264,8 +264,10 @@ def launch_run(self, checkpoint: str, current_interval: str) -> Run: integrations = cfg.integrations found_llm_foundry, installation_path = False, 'llm-foundry' for i in integrations: - if i['integration_type'] != 'git_repo' or i[ - 'git_repo'] != 'mosaicml/llm-foundry': + if i['integration_type'] != 'git_repo': + continue + + if not i['git_repo'].endswith('llm-foundry'): # detects forks continue found_llm_foundry = True @@ -273,9 +275,16 @@ def launch_run(self, checkpoint: str, current_interval: str) -> Run: installation_path = i['path'] if not found_llm_foundry: + # If github integration is not found, foundry is likely installed + # through the run command. In this case, we'll add the integration + # so the eval run will still work. However, it could cause unexpected + # behaviors because its not using custom repos or branches specified + # in the training run. For this reason, we'll log a warning log.warning( - 'No github integration found for llm-foundry. Adding installation for latest' - ) + 'No github integration found for llm-foundry. ' + + 'Adding installation to eval run for latest foundry release. ' + + 'To use a fork, custom branch, or custom version, configure ' + + 'llm-foundry installation through a github integration') integrations.append({ 'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', From 1f35a7b5cc61c8e27f2998077b279ce61692497c Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 4 Dec 2023 18:51:09 +0000 Subject: [PATCH 33/49] version import --- llmfoundry/callbacks/async_eval_callback.py | 9 ++++++--- tests/callbacks/test_async_eval_callback.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index ca791c4cab..d77a311250 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -275,20 +275,23 @@ def launch_run(self, checkpoint: str, current_interval: str) -> Run: installation_path = i['path'] if not found_llm_foundry: + from llmfoundry import __version__ as latest_foundry_version + # If github integration is not found, foundry is likely installed # through the run command. In this case, we'll add the integration # so the eval run will still work. However, it could cause unexpected # behaviors because its not using custom repos or branches specified # in the training run. For this reason, we'll log a warning + version = f'v{latest_foundry_version}' log.warning( - 'No github integration found for llm-foundry. ' + - 'Adding installation to eval run for latest foundry release. ' + + 'No github integration found for llm-foundry. Adding installation ' + + f'to eval run for latest foundry release ({version}). 
' + 'To use a fork, custom branch, or custom version, configure ' + 'llm-foundry installation through a github integration') integrations.append({ 'integration_type': 'git_repo', 'git_repo': 'mosaicml/llm-foundry', - 'git_branch': 'v0.4.0', + 'git_branch': version, 'pip_install': '-e .[gpu]', 'ssh_clone': False, }) diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index 79711522bb..c864e7defc 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -220,7 +220,7 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, assert len(integrations) == 1 assert integrations[0]['integration_type'] == 'git_repo' assert integrations[0]['git_repo'] == 'mosaicml/llm-foundry' - assert 'git_branch' in integrations[0] + assert integrations[0]['git_branch'].startswith('v') compute = run_config_created.compute assert compute['cluster'] == 'c2z3' From 11845313df3880af73a516d11ec8d06247ddfd48 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Tue, 5 Dec 2023 23:30:03 +0000 Subject: [PATCH 34/49] last checkpoint --- llmfoundry/callbacks/async_eval_callback.py | 26 ++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index d77a311250..a91452616f 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -190,7 +190,12 @@ def __init__( else: self.interval = interval - self.check_interval = create_interval_scheduler(interval) + self.check_interval = create_interval_scheduler( + interval, + # There is a custom post_close to ensure that the final checkpoint + # (which is the most important) is evaled after it is written + include_end_of_training=False, + ) self.compute = compute self.last_checkpoint: Optional[str] = None @@ -219,6 +224,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: checkpoint = get_latest_checkpoint(event, state) if not checkpoint: return # warnings logged in get_latest_checkpoint + if checkpoint == self.last_checkpoint: # Do not eval a checkpoint that has already been evaluated. 
log.info( @@ -229,6 +235,22 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self.launch_run(checkpoint, current_interval) self.last_checkpoint = checkpoint + def post_close(self) -> None: + if dist.get_global_rank() != 0: + return + self.training_config + + save_folder = self.training_config['save_folder'] + save_latest_filename = self.training_config.get('save_latest_filename', + None) + + if not save_latest_filename: + rank = dist.get_global_rank() + save_latest_filename = f'latest-rank{rank}.pt' + + checkpoint = f'{save_folder}/{save_latest_filename}' + self.launch_run(checkpoint, 'final') + def _get_current_run(self) -> Run: if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'false': @@ -246,6 +268,8 @@ def _get_current_run(self) -> Run: return get_run(run_name, include_details=True) def launch_run(self, checkpoint: str, current_interval: str) -> Run: + log.info(f'Launching eval run for {checkpoint} at {current_interval}') + cfg = self.current_run.submitted_config default_compute = { 'gpus': 8, From 7af73836e9ef1a790db57fc57473ffe6eec88575 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 6 Dec 2023 06:28:14 +0000 Subject: [PATCH 35/49] post_close -> close --- llmfoundry/callbacks/async_eval_callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index a91452616f..bde80f4a06 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -192,7 +192,7 @@ def __init__( self.check_interval = create_interval_scheduler( interval, - # There is a custom post_close to ensure that the final checkpoint + # There is a custom close to ensure that the final checkpoint # (which is the most important) is evaled after it is written include_end_of_training=False, ) @@ -235,7 +235,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self.launch_run(checkpoint, current_interval) self.last_checkpoint = checkpoint - def post_close(self) -> None: + def close(self) -> None: if dist.get_global_rank() != 0: return self.training_config From 9337af0cfb11ba8e67d9b5682e9ebc5c73dd0ff8 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 6 Dec 2023 06:47:23 +0000 Subject: [PATCH 36/49] add todos, fix path bug --- llmfoundry/callbacks/async_eval_callback.py | 3 ++- scripts/train/train.py | 26 +++++++++------------ 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index bde80f4a06..0b14e47c8c 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -240,6 +240,7 @@ def close(self) -> None: return self.training_config + # TODO: enforce this exists before save_folder = self.training_config['save_folder'] save_latest_filename = self.training_config.get('save_latest_filename', None) @@ -295,7 +296,7 @@ def launch_run(self, checkpoint: str, current_interval: str) -> Run: continue found_llm_foundry = True - if i['path']: + if i.get('path'): installation_path = i['path'] if not found_llm_foundry: diff --git a/scripts/train/train.py b/scripts/train/train.py index 73b8c236c1..58987e4e62 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -529,21 +529,17 @@ def main(cfg: DictConfig) -> Trainer: mosaicml_logger.log_metrics({'data_validated': time.time()}) ## Evaluation - if use_async_eval: - print('Using async eval, skipping eval 
loader') - evaluators, eval_gauntlet_callback = [], None - else: - print('Building eval loader...') - eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len - evaluators, _, eval_gauntlet_callback = build_evaluators( - eval_loader_config, - icl_tasks_config, - eval_gauntlet_config, - tokenizer=tokenizer, - device_eval_batch_size=device_eval_batch_size, - icl_seq_len=eval_icl_seq_len, - icl_subset_num_batches=icl_subset_num_batches, - ) + print('Building eval loader...') + eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len + evaluators, _, eval_gauntlet_callback = build_evaluators( + eval_loader_config, # TODO: async eval should not even call eval loader + icl_tasks_config if not use_async_eval else None, + eval_gauntlet_config if not use_async_eval else None, + tokenizer=tokenizer, + device_eval_batch_size=device_eval_batch_size, + icl_seq_len=eval_icl_seq_len, + icl_subset_num_batches=icl_subset_num_batches, + ) if eval_gauntlet_callback is not None and not use_async_eval: callbacks.append(eval_gauntlet_callback) From 87ffd8605a6713305af17df81929234f6d78d7e1 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 6 Dec 2023 06:59:12 +0000 Subject: [PATCH 37/49] add missing args --- llmfoundry/callbacks/async_eval_callback.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 0b14e47c8c..bfb9ef7e5a 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -235,7 +235,10 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self.launch_run(checkpoint, current_interval) self.last_checkpoint = checkpoint - def close(self) -> None: + def close(self, state: State, logger: Logger) -> None: + del state + del logger + if dist.get_global_rank() != 0: return self.training_config From e940f1c7d8f5ec0364bb8f6f317012fae7436af3 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 6 Dec 2023 07:12:58 +0000 Subject: [PATCH 38/49] remove eval_loader in callback too --- llmfoundry/callbacks/async_eval_callback.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index bfb9ef7e5a..afa6a81006 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -32,7 +32,6 @@ OPTIONAL_PARAMS_FOR_EVAL = { 'dist_timeout', 'eval_gauntlet', - 'eval_loader', 'fsdp_config', 'icl_subset_num_batches', 'loggers', From bb040d1ac5e892db32928760f7a9e1f2296dc00f Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 6 Dec 2023 07:33:36 +0000 Subject: [PATCH 39/49] remove fit end event (already doing on close) --- llmfoundry/callbacks/async_eval_callback.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index afa6a81006..014fc53498 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -90,10 +90,6 @@ def get_latest_checkpoint(event: Event, state: State) -> Optional[str]: log.warning('No checkpoint saver callback found') return None - if event.name == Event.FIT_END: - # Use the latest symlink for the end of training - return str(checkpointer.latest_filename) - if not checkpointer.saved_checkpoints: log.warning('No saved checkpoints found on the checkpointer') return None From 14f386f8ea22048176e837f669ecd981f2f5982f Mon Sep 17 00:00:00 
2001 From: Anna Pfohl Date: Wed, 6 Dec 2023 23:25:46 +0000 Subject: [PATCH 40/49] misc fixes --- llmfoundry/callbacks/async_eval_callback.py | 60 +++++++++++++++------ tests/callbacks/test_async_eval_callback.py | 19 ++++++- 2 files changed, 62 insertions(+), 17 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 014fc53498..13eafe6e3f 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -8,6 +8,7 @@ import logging import os +from pathlib import Path from typing import Any, Dict, Optional, Union from composer.callbacks import CheckpointSaver @@ -94,7 +95,8 @@ def get_latest_checkpoint(event: Event, state: State) -> Optional[str]: log.warning('No saved checkpoints found on the checkpointer') return None - return checkpointer.saved_checkpoints[-1] + latest = checkpointer.saved_checkpoints[-1] + return str(Path(latest).parts[-1]) def get_eval_parameters( @@ -153,6 +155,35 @@ def get_eval_parameters( return subset_keys +def validate_interval(interval: Union[str, int, Time], + save_interval: Union[str, int, Time]) -> Time: + if isinstance(save_interval, str): + new_save_interval: Time = Time.from_timestring(save_interval) + elif isinstance(save_interval, int): + new_save_interval: Time = Time(save_interval, TimeUnit.EPOCH) + else: + new_save_interval: Time = save_interval + + if isinstance(interval, str): + result: Time = Time.from_timestring(interval) + elif isinstance(interval, int): + result: Time = Time(interval, TimeUnit.EPOCH) + else: + result: Time = interval + + if new_save_interval.unit != result.unit: + raise ValueError( + 'Save interval and async eval interval must be in the same unit') + if result < new_save_interval: + raise ValueError( + 'Async eval interval must be equal or greater (less frequent) than save interval' + ) + if result.value % new_save_interval.value != 0: + raise ValueError( + 'Async eval interval must be a multiple of save interval') + return result + + class AsyncEval(Callback): """Run the eval loop asynchronously as part of a MosaicML platform run. @@ -176,15 +207,14 @@ def __init__( compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None, ): - self.training_config = training_config - - if isinstance(interval, str): - self.interval = Time.from_timestring(interval) - elif isinstance(interval, int): - self.interval = Time(interval, TimeUnit.EPOCH) - else: - self.interval = interval + for required in ('save_interval', 'save_folder'): + if required not in training_config: + raise ValueError(f'{required} required for async eval') + self.checkpoint_save_folder = training_config['save_folder'] + self.training_config = training_config + self.interval = validate_interval(interval, + self.training_config['save_interval']) self.check_interval = create_interval_scheduler( interval, # There is a custom close to ensure that the final checkpoint @@ -220,15 +250,16 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: if not checkpoint: return # warnings logged in get_latest_checkpoint - if checkpoint == self.last_checkpoint: + full_checkpoint = f'{self.checkpoint_save_folder}/{checkpoint}' + if full_checkpoint == self.last_checkpoint: # Do not eval a checkpoint that has already been evaluated. 
log.info( 'Skipping async eval because the checkpoint has not changed' ) return - self.launch_run(checkpoint, current_interval) - self.last_checkpoint = checkpoint + self.launch_run(full_checkpoint, current_interval) + self.last_checkpoint = full_checkpoint def close(self, state: State, logger: Logger) -> None: del state @@ -236,10 +267,7 @@ def close(self, state: State, logger: Logger) -> None: if dist.get_global_rank() != 0: return - self.training_config - # TODO: enforce this exists before - save_folder = self.training_config['save_folder'] save_latest_filename = self.training_config.get('save_latest_filename', None) @@ -247,7 +275,7 @@ def close(self, state: State, logger: Logger) -> None: rank = dist.get_global_rank() save_latest_filename = f'latest-rank{rank}.pt' - checkpoint = f'{save_folder}/{save_latest_filename}' + checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}' self.launch_run(checkpoint, 'final') def _get_current_run(self) -> Run: diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index c864e7defc..f7c52551ab 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -6,10 +6,12 @@ from unittest.mock import MagicMock, patch import pytest +from composer.core import Time, TimeUnit from llmfoundry.callbacks.async_eval_callback import (AsyncEval, get_eval_parameters, - get_run_name) + get_run_name, + validate_interval) from mcli import Run, RunConfig, RunStatus # here @@ -164,6 +166,21 @@ def test_get_eval_parameters(): } +def test_validate_interval(): + with pytest.raises(ValueError): + validate_interval('1ba', '1ep') # different units + with pytest.raises(ValueError): + validate_interval('1ba', '2ba') # checkpointing happens less often + with pytest.raises(ValueError): + validate_interval('3ba', '2ba') # not a multiple + + assert validate_interval('2ba', '1ba') == Time(2, TimeUnit.BATCH) + two_epochs = Time(2, TimeUnit.EPOCH) + assert validate_interval(2, 2) == two_epochs + assert validate_interval(two_epochs, two_epochs) == two_epochs + assert validate_interval('2ep', two_epochs) == two_epochs + + FAKE_RUN = Run( run_uid='123', name=RUN_NAME, From ac37d094d01b62d70f3e536bfd51cbd50c44ee7d Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Wed, 6 Dec 2023 23:42:39 +0000 Subject: [PATCH 41/49] fix test --- llmfoundry/callbacks/async_eval_callback.py | 10 ++++++---- tests/callbacks/test_async_eval_callback.py | 11 ++++++----- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 13eafe6e3f..d7489f387f 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -130,9 +130,6 @@ def get_eval_parameters( f'Missing the following required parameters for async eval: {looking_for}' ) - # Convert the save_folder to a load_path - subset_keys['load_path'] = checkpoint - for logger, config in subset_keys.get('loggers', {}).items(): if logger == 'wandb': config['group'] = config.pop('name', training_run_name) @@ -146,7 +143,11 @@ def get_eval_parameters( model_name = model.get('name', None) if not model_name: raise Exception(f'Async evaluation requires "name" keys for models') - new_models = {'model_name': model_name, 'model': model} + new_models = { + 'model_name': model_name, + 'model': model, + 'load_path': checkpoint + } tokenizer = subset_keys.pop('tokenizer', None) if tokenizer is not None: @@ -250,6 +251,7 @@ def 
run_event(self, event: Event, state: State, logger: Logger) -> None: if not checkpoint: return # warnings logged in get_latest_checkpoint + # TODO: ensure the checkpoint is fully written before launching the eval run full_checkpoint = f'{self.checkpoint_save_folder}/{checkpoint}' if full_checkpoint == self.last_checkpoint: # Do not eval a checkpoint that has already been evaluated. diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index f7c52551ab..a81e12a13f 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -17,6 +17,8 @@ # here RUN_NAME = 'foo_bar-1234' BASIC_PARAMS = { + 'save_interval': '1ba', + 'save_folder': 'foobar', 'device_eval_batch_size': 2, 'icl_tasks': 'icl_task_example', 'max_seq_len': 3, @@ -90,8 +92,6 @@ def test_get_eval_parameters(): 'icl_task_example', 'max_seq_len': 3, - 'load_path': - 'checkpoints/file', 'models': [{ 'model_name': 'model_example', 'model': { @@ -105,6 +105,7 @@ def test_get_eval_parameters(): 'tokenizer': { 'tokenizer_example': 'tokenizer_example' }, + 'load_path': 'checkpoints/file', }], } @@ -150,6 +151,7 @@ def test_get_eval_parameters(): 'tokenizer': { 'tokenizer_example': 'tokenizer_example' }, + 'load_path': 'checkpoints/file', }], 'eval_gauntlet': 'eval_gauntlet_example', 'fsdp_config': { @@ -162,7 +164,6 @@ def test_get_eval_parameters(): 'precision': 'precision_example', 'python_log_level': 'debug', 'seed': 5, - 'load_path': 'checkpoints/file', } @@ -247,7 +248,6 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, assert parameters['device_eval_batch_size'] == 2 assert parameters['icl_tasks'] == 'icl_task_example' assert parameters['max_seq_len'] == 3 - assert parameters['load_path'] == 'checkpoint/path' assert parameters['models'] == [{ 'model_name': 'model_example', 'model': { @@ -260,7 +260,8 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, }, 'tokenizer': { 'tokenizer_example': 'tokenizer_example' - } + }, + 'load_path': 'checkpoint/path', }] assert parameters['run_name'] == 'eval-1ba-foo_bar' # original run From 9e11cf7a3f9bcf5d21f8338a4d235774e468cad6 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Fri, 8 Dec 2023 01:25:07 +0000 Subject: [PATCH 42/49] add back eval interval --- llmfoundry/callbacks/async_eval_callback.py | 20 +++++++++---- scripts/eval/eval.py | 2 +- scripts/train/train.py | 2 +- tests/callbacks/test_async_eval_callback.py | 32 +++++++++++++++++---- 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index d7489f387f..6ce57c579c 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -103,6 +103,7 @@ def get_eval_parameters( parameters: Dict[str, Any], checkpoint: str, training_run_name: str, + interval: Time, ) -> Dict[str, Any]: """Get the parameters needed for the eval run. 
@@ -110,6 +111,7 @@ def get_eval_parameters( parameters: The parameters from the training run checkpoint: The path to the latest checkpoint training_run_name: The name of the training run + interval: The current Time interval Returns: The parameters needed for the eval run as a dict @@ -134,6 +136,13 @@ def get_eval_parameters( if logger == 'wandb': config['group'] = config.pop('name', training_run_name) + config['init_kwargs'] = config.pop('init_kwargs', {}) + config['init_kwargs']['config'] = config['init_kwargs'].pop( + 'config', {}) + config['init_kwargs']['config']['eval_interval'] = interval.value + config['init_kwargs']['config'][ + 'eval_interval_units'] = interval.unit.value + # mlflow currently does not support grouping, so this will just launch # a new mlflow run @@ -231,6 +240,7 @@ def __init__( parameters=training_config, checkpoint='test', training_run_name=self.current_run.name, + interval=Time(0, self.interval.unit), ) log.info( f'Initialized AsyncEval callback. Will generate runs at interval {interval}' @@ -246,7 +256,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: ]) if should_launch_run: - current_interval = str(state.timestamp.get(self.interval.unit)) + current_interval = state.timestamp.get(self.interval.unit) checkpoint = get_latest_checkpoint(event, state) if not checkpoint: return # warnings logged in get_latest_checkpoint @@ -264,7 +274,6 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self.last_checkpoint = full_checkpoint def close(self, state: State, logger: Logger) -> None: - del state del logger if dist.get_global_rank() != 0: @@ -278,7 +287,7 @@ def close(self, state: State, logger: Logger) -> None: save_latest_filename = f'latest-rank{rank}.pt' checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}' - self.launch_run(checkpoint, 'final') + self.launch_run(checkpoint, state.timestamp.get(self.interval.unit)) def _get_current_run(self) -> Run: if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, @@ -296,7 +305,7 @@ def _get_current_run(self) -> Run: # Allows the MapiException to be raised if the run doesn't exist return get_run(run_name, include_details=True) - def launch_run(self, checkpoint: str, current_interval: str) -> Run: + def launch_run(self, checkpoint: str, current_interval: Time) -> Run: log.info(f'Launching eval run for {checkpoint} at {current_interval}') cfg = self.current_run.submitted_config @@ -305,12 +314,13 @@ def launch_run(self, checkpoint: str, current_interval: str) -> Run: 'cluster': self.current_run.cluster, } - run_name = get_run_name(self.current_run.name, current_interval) + run_name = get_run_name(self.current_run.name, str(current_interval)) params = get_eval_parameters( parameters=self.training_config, checkpoint=checkpoint, training_run_name=self.current_run.name, + interval=current_interval, ) params['run_name'] = run_name diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 369a894720..194bba05d0 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -140,7 +140,7 @@ def evaluate_model( callbacks.append(eval_gauntlet_callback) loggers: List[LoggerDestination] = [ - build_logger(name, logger_cfg) + build_logger(name, om.to_container(logger_cfg, resolve=True)) for name, logger_cfg in loggers_cfg.items() ] diff --git a/scripts/train/train.py b/scripts/train/train.py index 58987e4e62..466e288415 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -453,7 +453,7 @@ def main(cfg: DictConfig) -> Trainer: # Loggers loggers = [ - 
build_logger(str(name), logger_cfg) + build_logger(str(name), om.to_container(logger_cfg, resolve=True)) for name, logger_cfg in logger_configs.items() ] if logger_configs else [] diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index a81e12a13f..143196aba1 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -81,10 +81,12 @@ def test_get_eval_parameters(): with pytest.raises( Exception, match='Missing the following required parameters for async eval:'): - get_eval_parameters({}, 'checkpoints/file', RUN_NAME) + get_eval_parameters({}, 'checkpoints/file', RUN_NAME, + Time(0, TimeUnit.EPOCH)) # minimal example - params = get_eval_parameters(BASIC_PARAMS, 'checkpoints/file', RUN_NAME) + params = get_eval_parameters(BASIC_PARAMS, 'checkpoints/file', RUN_NAME, + Time(0, TimeUnit.EPOCH)) assert params == { 'device_eval_batch_size': 2, @@ -122,7 +124,14 @@ def test_get_eval_parameters(): }, 'icl_subset_num_batches': 4, 'loggers': { - 'loggers_example': 'loggers_example' + 'wandb': { + 'init_kwargs': { + 'config': { + 'foo': 'bar' + }, + 'fee': 'bee' + } + } }, 'precision': 'precision_example', 'python_log_level': 'debug', @@ -132,6 +141,7 @@ def test_get_eval_parameters(): }, 'checkpoints/file', RUN_NAME, + Time(0, TimeUnit.EPOCH), ) assert params2 == { 'device_eval_batch_size': 2, @@ -159,7 +169,17 @@ def test_get_eval_parameters(): }, 'icl_subset_num_batches': 4, 'loggers': { - 'loggers_example': 'loggers_example' + 'wandb': { + 'group': 'foo_bar-1234', + 'init_kwargs': { + 'config': { + 'eval_interval': 0, + 'eval_interval_units': 'ep', + 'foo': 'bar' + }, + 'fee': 'bee' + }, + } }, 'precision': 'precision_example', 'python_log_level': 'debug', @@ -224,7 +244,7 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, assert mock_get_run.call_count == 1 assert mock_get_run.call_args[0][0] == RUN_NAME - callback.launch_run('checkpoint/path', '1ba') + callback.launch_run('checkpoint/path', Time(1, TimeUnit.BATCH)) assert mock_create_run.call_count == 1 run_config_created = mock_create_run.call_args[0][0] @@ -300,7 +320,7 @@ def test_async_eval_callback_integrations(mock_create_run: MagicMock, }) assert mock_get_run.call_count == 1 - callback.launch_run('checkpoint/path', '1ba') + callback.launch_run('checkpoint/path', Time(1, TimeUnit.BATCH)) assert mock_create_run.call_count == 1 run_config_created = mock_create_run.call_args[0][0] From aa652f349676b12da3c6d0e6e1e67a4ce6f7286f Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Fri, 8 Dec 2023 09:42:07 -0800 Subject: [PATCH 43/49] build_loggers and add tests --- llmfoundry/utils/builders.py | 18 +++++++++++++----- scripts/eval/eval.py | 2 +- scripts/train/train.py | 2 +- tests/utils/test_builders.py | 30 ++++++++++++++++++++++++++++-- 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 23533238e8..0aa41ca153 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -218,16 +218,24 @@ def build_callback( def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: + for k, v in kwargs.items(): + print(k, v, type(k), type(v)) + + kwargs_dict = { + k: v if isinstance(v, str) else om.to_container(v, resolve=True) + for k, v in kwargs.items() + } + if name == 'wandb': - return WandBLogger(**kwargs) + return WandBLogger(**kwargs_dict) elif name == 'tensorboard': - return TensorboardLogger(**kwargs) + return 
TensorboardLogger(**kwargs_dict) elif name == 'in_memory_logger': - return InMemoryLogger(**kwargs) + return InMemoryLogger(**kwargs_dict) elif name == 'mlflow': - return MLFlowLogger(**kwargs) + return MLFlowLogger(**kwargs_dict) elif name == 'inmemory': - return InMemoryLogger(**kwargs) + return InMemoryLogger(**kwargs_dict) else: raise ValueError(f'Not sure how to build logger: {name}') diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 194bba05d0..369a894720 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -140,7 +140,7 @@ def evaluate_model( callbacks.append(eval_gauntlet_callback) loggers: List[LoggerDestination] = [ - build_logger(name, om.to_container(logger_cfg, resolve=True)) + build_logger(name, logger_cfg) for name, logger_cfg in loggers_cfg.items() ] diff --git a/scripts/train/train.py b/scripts/train/train.py index 466e288415..58987e4e62 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -453,7 +453,7 @@ def main(cfg: DictConfig) -> Trainer: # Loggers loggers = [ - build_logger(str(name), om.to_container(logger_cfg, resolve=True)) + build_logger(str(name), logger_cfg) for name, logger_cfg in logger_configs.items() ] if logger_configs else [] diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 20b2c4669c..9be6630075 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -12,6 +12,7 @@ import torch.nn as nn from composer.callbacks import Generate from composer.core import Evaluator +from composer.loggers import WandBLogger from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase @@ -20,8 +21,8 @@ from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_callback, build_eval_loaders, - build_evaluators, build_optimizer, - build_tokenizer) + build_evaluators, build_logger, + build_optimizer, build_tokenizer) @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ @@ -130,6 +131,31 @@ def test_build_hf_checkpointer_callback(): assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict +def test_build_logger(): + with pytest.raises(ValueError): + _ = build_logger('unknown', {}) + + logger_cfg = DictConfig({ + 'project': 'foobar', + 'init_kwargs': { + 'config': { + 'foo': 'bar', + } + } + }) + wandb_logger = build_logger('wandb', logger_cfg) # type: ignore + assert isinstance(wandb_logger, WandBLogger) + assert wandb_logger.project == 'foobar' + + # confirm the typing conversion from DictConfig to dict, + # wandb.init() will fail if config is not explicitly + # dict type + ik = wandb_logger._init_kwargs + assert ik == {'config': {'foo': 'bar'}, 'project': 'foobar'} + assert isinstance(ik, dict) + assert isinstance(ik['config'], dict) + + class _DummyModule(nn.Module): def __init__(self, device: str = 'cpu', dtype: torch.dtype = torch.float32): From 0e4f0857217bba76c98f7f351cd1dfe9a523e2aa Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 11 Dec 2023 10:16:43 -0800 Subject: [PATCH 44/49] updates --- llmfoundry/callbacks/async_eval_callback.py | 17 +++++-------- llmfoundry/utils/builders.py | 2 +- scripts/train/train.py | 4 ++- tests/callbacks/test_async_eval_callback.py | 28 ++++++++++----------- 4 files changed, 23 insertions(+), 28 deletions(-) diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 6ce57c579c..8352a9e283 100644 --- 
a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -103,7 +103,6 @@ def get_eval_parameters( parameters: Dict[str, Any], checkpoint: str, training_run_name: str, - interval: Time, ) -> Dict[str, Any]: """Get the parameters needed for the eval run. @@ -111,7 +110,6 @@ def get_eval_parameters( parameters: The parameters from the training run checkpoint: The path to the latest checkpoint training_run_name: The name of the training run - interval: The current Time interval Returns: The parameters needed for the eval run as a dict @@ -136,13 +134,6 @@ def get_eval_parameters( if logger == 'wandb': config['group'] = config.pop('name', training_run_name) - config['init_kwargs'] = config.pop('init_kwargs', {}) - config['init_kwargs']['config'] = config['init_kwargs'].pop( - 'config', {}) - config['init_kwargs']['config']['eval_interval'] = interval.value - config['init_kwargs']['config'][ - 'eval_interval_units'] = interval.unit.value - # mlflow currently does not support grouping, so this will just launch # a new mlflow run @@ -240,7 +231,6 @@ def __init__( parameters=training_config, checkpoint='test', training_run_name=self.current_run.name, - interval=Time(0, self.interval.unit), ) log.info( f'Initialized AsyncEval callback. Will generate runs at interval {interval}' @@ -320,7 +310,6 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run: parameters=self.training_config, checkpoint=checkpoint, training_run_name=self.current_run.name, - interval=current_interval, ) params['run_name'] = run_name @@ -359,6 +348,12 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run: 'ssh_clone': False, }) + # This will record the timestamp and make it available for grouping + # and plotting in wandb + metadata = cfg.metadata + metadata['eval_timestamp'] = current_interval.value + metadata['eval_timestamp_unit'] = current_interval.unit.value + # TODO: This just runs an eval run, but we also want to attach the # deployment, which would require a hf conversion and parametrizing the # dependent_deployment in the run config diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 0aa41ca153..0b4d2f89ca 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -160,7 +160,7 @@ def build_icl_data_and_gauntlet( def build_callback( name: str, kwargs: Union[DictConfig, Dict[str, Any]], - config: Any = None, + config: Dict[str, Any] = None, ) -> Callback: if name == 'lr_monitor': return LRMonitor() diff --git a/scripts/train/train.py b/scripts/train/train.py index 58987e4e62..db66821fe3 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -531,8 +531,10 @@ def main(cfg: DictConfig) -> Trainer: ## Evaluation print('Building eval loader...') eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len + # TODO: evaluators should not be built at all if use_async_eval is True + # This will be fixed when eval_loader support is fully added to AsyncEval evaluators, _, eval_gauntlet_callback = build_evaluators( - eval_loader_config, # TODO: async eval should not even call eval loader + eval_loader_config, icl_tasks_config if not use_async_eval else None, eval_gauntlet_config if not use_async_eval else None, tokenizer=tokenizer, diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index 143196aba1..b3a1e98f79 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -81,12 +81,10 @@ 
def test_get_eval_parameters(): with pytest.raises( Exception, match='Missing the following required parameters for async eval:'): - get_eval_parameters({}, 'checkpoints/file', RUN_NAME, - Time(0, TimeUnit.EPOCH)) + get_eval_parameters({}, 'checkpoints/file', RUN_NAME) # minimal example - params = get_eval_parameters(BASIC_PARAMS, 'checkpoints/file', RUN_NAME, - Time(0, TimeUnit.EPOCH)) + params = get_eval_parameters(BASIC_PARAMS, 'checkpoints/file', RUN_NAME) assert params == { 'device_eval_batch_size': 2, @@ -126,9 +124,6 @@ def test_get_eval_parameters(): 'loggers': { 'wandb': { 'init_kwargs': { - 'config': { - 'foo': 'bar' - }, 'fee': 'bee' } } @@ -141,7 +136,6 @@ def test_get_eval_parameters(): }, 'checkpoints/file', RUN_NAME, - Time(0, TimeUnit.EPOCH), ) assert params2 == { 'device_eval_batch_size': 2, @@ -172,11 +166,6 @@ def test_get_eval_parameters(): 'wandb': { 'group': 'foo_bar-1234', 'init_kwargs': { - 'config': { - 'eval_interval': 0, - 'eval_interval_units': 'ep', - 'foo': 'bar' - }, 'fee': 'bee' }, } @@ -244,14 +233,23 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock, assert mock_get_run.call_count == 1 assert mock_get_run.call_args[0][0] == RUN_NAME - callback.launch_run('checkpoint/path', Time(1, TimeUnit.BATCH)) + launch_time = Time(1, TimeUnit.BATCH) + callback.launch_run('checkpoint/path', launch_time) assert mock_create_run.call_count == 1 run_config_created = mock_create_run.call_args[0][0] assert run_config_created.name == 'eval-1ba-foo_bar' assert run_config_created.image == 'fake-image' - print(run_config_created) + metadata = run_config_created.metadata + assert 'eval_timestamp' in metadata + assert isinstance(metadata['eval_timestamp'], int) + assert metadata['eval_timestamp'] == launch_time.value + + assert 'eval_timestamp_unit' in metadata + assert isinstance(metadata['eval_timestamp_unit'], str) + assert metadata['eval_timestamp_unit'] == launch_time.unit.value + assert 'cd llm-foundry/scripts' in run_config_created.command integrations = run_config_created.integrations From dc25b2a46aec9288e2409575003394d91ac1ce42 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 11 Dec 2023 11:02:37 -0800 Subject: [PATCH 45/49] typing --- llmfoundry/utils/builders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 0b4d2f89ca..f0edd81c75 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -160,7 +160,7 @@ def build_icl_data_and_gauntlet( def build_callback( name: str, kwargs: Union[DictConfig, Dict[str, Any]], - config: Dict[str, Any] = None, + config: Optional[Dict[str, Any]] = None, ) -> Callback: if name == 'lr_monitor': return LRMonitor() @@ -209,7 +209,7 @@ def build_callback( kwargs = om.to_object(kwargs) # pyright: ignore return HuggingFaceCheckpointer(**kwargs) elif name == 'async_eval': - if not config: + if config is None: raise ValueError( 'Parameters config is required for async eval callback') return AsyncEval(**kwargs, training_config=config) From 6865b964344093bf5cea4e908423a6d1a7eff7c5 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 11 Dec 2023 11:45:11 -0800 Subject: [PATCH 46/49] changes --- llmfoundry/utils/builders.py | 10 ++++++---- scripts/train/train.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index f0edd81c75..043998fbff 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -212,15 +212,17 @@ def 
build_callback( if config is None: raise ValueError( 'Parameters config is required for async eval callback') - return AsyncEval(**kwargs, training_config=config) + + config_dict = { + k: v if isinstance(v, str) else om.to_container(v, resolve=True) + for k, v in kwargs.items() + } + return AsyncEval(**kwargs, training_config=config_dict) else: raise ValueError(f'Not sure how to build callback: {name}') def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: - for k, v in kwargs.items(): - print(k, v, type(k), type(v)) - kwargs_dict = { k: v if isinstance(v, str) else om.to_container(v, resolve=True) for k, v in kwargs.items() diff --git a/scripts/train/train.py b/scripts/train/train.py index db66821fe3..9519739950 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -504,7 +504,7 @@ def main(cfg: DictConfig) -> Trainer: # Callbacks callbacks: List[Callback] = [ - build_callback(str(name), callback_cfg, om.to_container(logged_cfg)) + build_callback(str(name), callback_cfg, logged_cfg) for name, callback_cfg in callback_configs.items() ] if callback_configs else [] From 1ac70cccd51d4e604dfbde225f42a8b98bd3789b Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 11 Dec 2023 13:13:22 -0800 Subject: [PATCH 47/49] typing? --- llmfoundry/utils/builders.py | 8 ++------ scripts/train/train.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 043998fbff..404ad604ab 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -160,7 +160,7 @@ def build_icl_data_and_gauntlet( def build_callback( name: str, kwargs: Union[DictConfig, Dict[str, Any]], - config: Optional[Dict[str, Any]] = None, + config: Any = None, ) -> Callback: if name == 'lr_monitor': return LRMonitor() @@ -213,11 +213,7 @@ def build_callback( raise ValueError( 'Parameters config is required for async eval callback') - config_dict = { - k: v if isinstance(v, str) else om.to_container(v, resolve=True) - for k, v in kwargs.items() - } - return AsyncEval(**kwargs, training_config=config_dict) + return AsyncEval(**kwargs, training_config=config) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/scripts/train/train.py b/scripts/train/train.py index 9519739950..db66821fe3 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -504,7 +504,7 @@ def main(cfg: DictConfig) -> Trainer: # Callbacks callbacks: List[Callback] = [ - build_callback(str(name), callback_cfg, logged_cfg) + build_callback(str(name), callback_cfg, om.to_container(logged_cfg)) for name, callback_cfg in callback_configs.items() ] if callback_configs else [] From cd2a31d4efee4193c9faf3479cc5308488d4767f Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 18 Dec 2023 11:38:33 -0800 Subject: [PATCH 48/49] metadata in eval.py --- scripts/eval/eval.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 369a894720..214fb49abc 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -10,6 +10,7 @@ import pandas as pd import torch +from composer.loggers import MosaicMLLogger from composer.loggers.logger_destination import LoggerDestination from composer.models.base import ComposerModel from composer.trainer import Trainer @@ -24,7 +25,8 @@ from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_evaluators, build_logger, build_tokenizer) -from llmfoundry.utils.config_utils import 
pop_config, process_init_device +from llmfoundry.utils.config_utils import (log_config, pop_config, + process_init_device) def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, @@ -114,6 +116,7 @@ def evaluate_model( precision: str, eval_gauntlet_df: Optional[pd.DataFrame], icl_subset_num_batches: Optional[int], + metadata: Optional[Dict[str, str]], ): print(f'Evaluating model: {model_cfg.model_name}', flush=True) @@ -144,6 +147,20 @@ def evaluate_model( for name, logger_cfg in loggers_cfg.items() ] + if metadata is not None: + # Flatten the metadata for logging + loggers_cfg.pop('metadata', None) + loggers_cfg.update(metadata, merge=True) + + # Find the MosaicMLLogger + mosaicml_logger = next(( + logger for logger in loggers if isinstance(logger, MosaicMLLogger)), + None) + + if mosaicml_logger is not None: + mosaicml_logger.log_metrics(metadata) + mosaicml_logger._flush_metadata(force_flush=True) + if fsdp_config and model_cfg.model.get('load_in_8bit', False): raise ValueError( 'The FSDP config block is not supported when loading ' + @@ -177,6 +194,7 @@ def evaluate_model( assert composer_model is not None + print(f'Building trainer for {model_cfg.model_name}...') trainer = Trainer( run_name=run_name, seed=seed, @@ -193,6 +211,10 @@ def evaluate_model( python_log_level=python_log_level, ) + print('Logging config') + log_config(loggers_cfg) + + print(f'Starting eval for {model_cfg.model_name}...') if torch.cuda.is_available(): torch.cuda.synchronize() a = time.time() @@ -200,6 +222,7 @@ def evaluate_model( if torch.cuda.is_available(): torch.cuda.synchronize() b = time.time() + print(f'Ran {model_cfg.model_name} eval in: {b-a} seconds') return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) @@ -270,6 +293,12 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: 'icl_subset_num_batches', must_exist=False, default_value=None) + metadata: Optional[Dict[str, str]] = pop_config(cfg, + 'metadata', + must_exist=False, + default_value=None, + convert=True) + # Pop out interpolation variables. 
pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None) @@ -313,7 +342,8 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: python_log_level=python_log_level, precision=precision, eval_gauntlet_df=eval_gauntlet_df, - icl_subset_num_batches=icl_subset_num_batches) + icl_subset_num_batches=icl_subset_num_batches, + metadata=metadata) trainers.append(trainer) if eval_gauntlet_callback is not None: From 04865dbe56f3df4b803cb419c32badcf6539e5d1 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Tue, 19 Dec 2023 19:45:15 +0000 Subject: [PATCH 49/49] actually, just log metadata on every model eval --- scripts/eval/eval.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 383d3571a1..8dbe91e6d2 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import logging import os import sys @@ -119,6 +120,7 @@ def evaluate_model( eval_gauntlet_df: Optional[pd.DataFrame], icl_subset_num_batches: Optional[int], metadata: Optional[Dict[str, str]], + logged_config: DictConfig, ): log.info(f'Evaluating model: {model_cfg.model_name}') @@ -213,8 +215,8 @@ def evaluate_model( python_log_level=python_log_level, ) - log.info('Logging config') - log_config(loggers_cfg) + log.info('Evaluation config:') + log_config(logged_config) log.info(f'Starting eval for {model_cfg.model_name}...') if torch.cuda.is_available(): @@ -231,6 +233,10 @@ def evaluate_model( def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: om.resolve(cfg) + + # Create copy of config for logging + logged_cfg: DictConfig = copy.deepcopy(cfg) + model_configs: ListConfig = pop_config(cfg, 'models', must_exist=True) eval_gauntlet_config: Optional[Union[str, DictConfig]] = pop_config( cfg, 'eval_gauntlet', must_exist=False, default_value=None) @@ -345,7 +351,8 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: precision=precision, eval_gauntlet_df=eval_gauntlet_df, icl_subset_num_batches=icl_subset_num_batches, - metadata=metadata) + metadata=metadata, + logged_config=logged_cfg) trainers.append(trainer) if eval_gauntlet_callback is not None:
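One design choice in this last patch worth spelling out: pop_config consumes keys from cfg as the eval script sets things up, so logging cfg at trainer-build time would only show whatever keys happen to be left over. Deep-copying the config up front keeps a complete snapshot for log_config. A minimal illustration with plain OmegaConf (placeholder keys, not the real eval config):

    import copy

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({'models': [{'model_name': 'demo'}], 'seed': 17})
    logged_cfg = copy.deepcopy(cfg)  # snapshot taken before anything is popped

    seed = cfg.pop('seed')           # consumed during setup, as pop_config does
    assert 'seed' not in cfg         # the working config has lost the key
    assert logged_cfg.seed == 17     # the snapshot still has it for log_config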