Skip to content

Commit

Permalink
Add support for eval_loader & eval_subset_num_batches in async callba…
Browse files Browse the repository at this point in the history
…ck (#834)

* Skip evalloader in training if using async eval

* add support for subset_num_batches

* remove todo

* eval first

* rename arg

* fix

* small updates

* om

* fix test

* eval run config

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
aspfohl and dakinggg authored Jan 27, 2024
1 parent 534f5b4 commit bdcce63
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 58 deletions.
78 changes: 59 additions & 19 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from composer.utils import dist
from composer.utils.misc import create_interval_scheduler

from mcli import ComputeConfig, Run, RunConfig, create_run, get_run
from mcli import Run, RunConfig, create_run, get_run

log = logging.getLogger(__name__)

Expand All @@ -33,7 +33,9 @@
OPTIONAL_PARAMS_FOR_EVAL = {
'dist_timeout',
'eval_gauntlet',
'eval_loader',
'fsdp_config',
'eval_subset_num_batches',
'icl_subset_num_batches',
'loggers',
'precision',
Expand Down Expand Up @@ -175,50 +177,84 @@ def validate_interval(interval: Union[str, int, Time],
return async_interval


def validate_eval_run_config(
eval_run_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:

if not eval_run_config:
return {}

run_config = eval_run_config.copy()

supported_keys = {'image', 'command', 'compute', 'scheduling'}
found_unsupported = set()
for key in run_config:
if key not in supported_keys:
found_unsupported.add(key)

if found_unsupported:
raise ValueError(
f'Unsupported eval run config keys found: {", ".join(found_unsupported)}'
+ f'. Supported keys: {supported_keys}')

return run_config


class AsyncEval(Callback):
"""Run the eval loop asynchronously as part of a MosaicML platform run.
This callback is currently experimental. The API may change in the future.
Args:
training_config: Dict[str, Any]: The config from the training run
training_params: Dict[str, Any]: The parameter config from the training run
interval: Union[str, int, Time]: The interval describing how often eval runs should be
launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
compute: Optional[Union[ComputeConfig, Dict[str, Any]]]: The compute configuration to
use for the eval run. If not provided, the same cluster as the current run and a
single, full GPU node will be used.
eval_run_config: Optional[Dict[str, Any]]: A subset of mcli run config values to use
for the eval run. If not specified, any fields from run config will be created
dynamically from the training run config and the interval. The following fields
are supported:
- ``image``: Image of the eval run. Default: same as training run
- ``command``: Command to run for the eval run. Default: calls
`composer scripts/eval/eval.py $PARAMETERS`. If custom setup is needed,
the command should include calling the eval script with $PARAMETERS
- ``compute``: Compute to use for the eval run. Default: same cluster as
the training run and a single node (8 GPUs)
- ``scheduling``: Scheduling to use for the eval run. Default: same as training run
All fields are optional, but if specified, must be valid for a mcli run config. We
provide this optional config to give you the most flexibility in customizing the eval
run, but it is recommended to use the default values unless you have a specific use case
"""

def __init__(
self,
training_config: Dict[str, Any],
training_params: Dict[str, Any],
interval: Union[str, int, Time],
compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None,
eval_run_config: Optional[Dict[str, Any]] = None,
):

for required in ('save_interval', 'save_folder'):
if required not in training_config:
if required not in training_params:
raise ValueError(f'{required} required for async eval')

self.checkpoint_save_folder = training_config['save_folder']
self.training_config = training_config
self.checkpoint_save_folder = training_params['save_folder']
self.training_params = training_params
self.eval_run_config = validate_eval_run_config(eval_run_config)
self.interval = validate_interval(interval,
self.training_config['save_interval'])
self.training_params['save_interval'])
self.check_interval = create_interval_scheduler(
interval,
# There is a custom close to ensure that the final checkpoint
# (which is the most important) is evaled after it is written
include_end_of_training=False,
)
self.compute = compute
self.last_checkpoint: Optional[str] = None

# Run these during init to fail fast in any of the error cases
self.current_run = self._get_current_run()
get_eval_parameters(
parameters=training_config,
parameters=training_params,
checkpoint='test',
training_run_name=self.current_run.name,
)
Expand Down Expand Up @@ -259,7 +295,7 @@ def close(self, state: State, logger: Logger) -> None:
if dist.get_global_rank() != 0:
return

save_latest_filename = self.training_config.get('save_latest_filename',
save_latest_filename = self.training_params.get('save_latest_filename',
None)

if not save_latest_filename:
Expand Down Expand Up @@ -297,7 +333,7 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
run_name = get_run_name(self.current_run.name, str(current_interval))

params = get_eval_parameters(
parameters=self.training_config,
parameters=self.training_params,
checkpoint=checkpoint,
training_run_name=self.current_run.name,
)
Expand Down Expand Up @@ -347,12 +383,16 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
# TODO: This just runs an eval run, but we also want to attach the
# deployment, which would require a hf conversion and parametrizing the
# dependent_deployment in the run config
command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS'
default_command = f'cd {installation_path}/scripts \n composer eval/eval.py $PARAMETERS'
run_config = RunConfig(
name=run_name,
image=self.current_run.image,
compute=self.compute or default_compute,
command=command,
image=self.eval_run_config.get('image', self.current_run.image),
command=self.eval_run_config.get('command', default_command),
compute=self.eval_run_config.get('compute', default_compute),
scheduling=self.eval_run_config.get(
'scheduling',
self.current_run.submitted_config.scheduling,
),
integrations=integrations,
env_variables=cfg.env_variables,
metadata=cfg.metadata,
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def build_callback(
raise ValueError(
'Parameters config is required for async eval callback')

return AsyncEval(**kwargs, training_config=config)
return AsyncEval(**kwargs, training_params=config)
else:
raise ValueError(f'Not sure how to build callback: {name}')

Expand Down
17 changes: 12 additions & 5 deletions scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def evaluate_model(
python_log_level: Optional[str],
precision: str,
eval_gauntlet_df: Optional[pd.DataFrame],
eval_subset_num_batches: int,
icl_subset_num_batches: Optional[int],
metadata: Optional[Dict[str, str]],
logged_config: DictConfig,
Expand Down Expand Up @@ -224,7 +225,8 @@ def evaluate_model(
if torch.cuda.is_available():
torch.cuda.synchronize()
a = time.time()
trainer.eval(eval_dataloader=evaluators)
trainer.eval(eval_dataloader=evaluators,
subset_num_batches=eval_subset_num_batches)
if torch.cuda.is_available():
torch.cuda.synchronize()
b = time.time()
Expand Down Expand Up @@ -299,10 +301,14 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
'loggers',
must_exist=False,
default_value={})
icl_subset_num_batches: int = pop_config(cfg,
'icl_subset_num_batches',
must_exist=False,
default_value=None)
eval_subset_num_batches: int = pop_config(cfg,
'eval_subset_num_batches',
must_exist=False,
default_value=-1)
icl_subset_num_batches: Optional[int] = pop_config(cfg,
'icl_subset_num_batches',
must_exist=False,
default_value=None)
metadata: Optional[Dict[str, str]] = pop_config(cfg,
'metadata',
must_exist=False,
Expand Down Expand Up @@ -356,6 +362,7 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
python_log_level=python_log_level,
precision=precision,
eval_gauntlet_df=eval_gauntlet_df,
eval_subset_num_batches=eval_subset_num_batches,
icl_subset_num_batches=icl_subset_num_batches,
metadata=metadata,
logged_config=logged_cfg,
Expand Down
46 changes: 26 additions & 20 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def main(cfg: DictConfig) -> Trainer:
callback_configs: Optional[DictConfig] = pop_config(cfg,
'callbacks',
must_exist=False,
default_value=None)
default_value=None,
convert=True)
algorithm_configs: Optional[DictConfig] = pop_config(cfg,
'algorithms',
must_exist=False,
Expand Down Expand Up @@ -519,8 +520,7 @@ def main(cfg: DictConfig) -> Trainer:
for name, callback_cfg in callback_configs.items()
] if callback_configs else []

use_async_eval = any(
isinstance(callback, AsyncEval) for callback in callbacks)
use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks)

# Algorithms
algorithms = [
Expand All @@ -540,22 +540,28 @@ def main(cfg: DictConfig) -> Trainer:
mosaicml_logger.log_metrics({'data_validated': time.time()})

## Evaluation
log.info('Building eval loader...')
eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len
# TODO: evaluators should not be built at all if use_async_eval is True
# This will be fixed when eval_loader support is fully added to AsyncEval
evaluators, _, eval_gauntlet_callback = build_evaluators(
eval_loader_config,
icl_tasks_config if not use_async_eval else None,
eval_gauntlet_config if not use_async_eval else None,
tokenizer=tokenizer,
device_eval_batch_size=device_eval_batch_size,
icl_seq_len=eval_icl_seq_len,
icl_subset_num_batches=icl_subset_num_batches,
)

if eval_gauntlet_callback is not None and not use_async_eval:
callbacks.append(eval_gauntlet_callback)
if use_async_eval:
evaluators = []
if eval_first:
warnings.warn(
'AsyncEval callback does not support eval_first=True. Ignoring.'
)
eval_first = False

else:
log.info('Building eval loader...')
eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len
evaluators, _, eval_gauntlet_callback = build_evaluators(
eval_loader_config,
icl_tasks_config,
eval_gauntlet_config,
tokenizer=tokenizer,
device_eval_batch_size=device_eval_batch_size,
icl_seq_len=eval_icl_seq_len,
icl_subset_num_batches=icl_subset_num_batches,
)
if eval_gauntlet_callback is not None:
callbacks.append(eval_gauntlet_callback)

# Build Model
log.info('Initializing model...')
Expand All @@ -582,7 +588,7 @@ def main(cfg: DictConfig) -> Trainer:
optimizer = build_optimizer(model, optimizer_name, optimizer_config)

# Now add the eval metrics
if eval_loader_config is not None:
if eval_loader_config is not None and not use_async_eval:
train_metrics = model.get_metrics(is_train=True)
evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)

Expand Down
54 changes: 41 additions & 13 deletions tests/callbacks/test_async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from llmfoundry.callbacks.async_eval_callback import (AsyncEval,
get_eval_parameters,
get_run_name,
validate_eval_run_config,
validate_interval)
from mcli import Run, RunConfig, RunStatus

# here
RUN_NAME = 'foo_bar-1234'
BASIC_PARAMS = {
'save_interval': '1ba',
Expand Down Expand Up @@ -191,6 +191,29 @@ def test_validate_interval():
assert validate_interval('2ep', two_epochs) == two_epochs


def test_validate_eval_run_config():
assert validate_eval_run_config(None) == {}
assert validate_eval_run_config({}) == {}

with pytest.raises(ValueError):
validate_eval_run_config({'foo': 'bar'})

valid_config = {
'image': 'example_image',
'command': 'example_command',
'compute': {
'gpus': 1,
'cluster': 'example_cluster',
},
'scheduling': {
'priority': 'high',
'preemptible': True,
},
}
res = validate_eval_run_config(valid_config)
assert res == valid_config


FAKE_RUN = Run(
run_uid='123',
name=RUN_NAME,
Expand Down Expand Up @@ -223,12 +246,16 @@ def test_validate_interval():
return_value=FAKE_RUN)
def test_async_eval_callback_minimal(mock_create_run: MagicMock,
mock_get_run: MagicMock):
callback = AsyncEval(BASIC_PARAMS,
interval='2ba',
compute={
'cluster': 'c2z3',
'nodes': 2,
})
callback = AsyncEval(
BASIC_PARAMS,
interval='2ba',
eval_run_config={
'compute': {
'cluster': 'c2z3',
'nodes': 2,
},
},
)
assert callback.current_run.name == RUN_NAME
assert mock_get_run.call_count == 1
assert mock_get_run.call_args[0][0] == RUN_NAME
Expand Down Expand Up @@ -310,12 +337,13 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock,
return_value=FAKE_RUN_WITH_INTEGRATIONS)
def test_async_eval_callback_integrations(mock_create_run: MagicMock,
mock_get_run: MagicMock):
callback = AsyncEval(BASIC_PARAMS,
interval='2ba',
compute={
'cluster': 'c2z3',
'nodes': 2,
})
callback = AsyncEval(
BASIC_PARAMS,
interval='2ba',
eval_run_config={'compute': {
'cluster': 'c2z3',
'nodes': 2,
}})
assert mock_get_run.call_count == 1

callback.launch_run('checkpoint/path', Time(1, TimeUnit.BATCH))
Expand Down

0 comments on commit bdcce63

Please sign in to comment.