small updates
aspfohl committed Nov 13, 2023
1 parent 28e47df · commit 78cc0b8
Showing 2 changed files with 21 additions and 10 deletions.
28 changes: 19 additions & 9 deletions llmfoundry/callbacks/async_eval_callback.py
@@ -1,6 +1,12 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""
Run the eval loop asynchronously as part of a MosaicML platform run.
This callback is currently experimental. The API may change in the future.
"""

import logging
import os
from typing import Any, Dict, List, Optional, Union
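The new module docstring summarizes what the callback does: while a training run executes on the MosaicML platform, the callback periodically launches a separate run that performs the eval loop. A rough sketch of that launch pattern, for orientation only (the helper below is hypothetical; the exact RunConfig fields and eval command are assumptions, not the callback's actual implementation):

from mcli import RunConfig, create_run  # mosaicml-cli SDK

def launch_eval_run(run_name: str, parameters: dict) -> None:
    # Build a run config for the eval job and submit it to the platform.
    eval_run = RunConfig(
        name=run_name,
        command='composer scripts/eval/eval.py /mnt/config/parameters.yaml',
        parameters=parameters,
    )
    create_run(eval_run)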
@@ -42,6 +48,7 @@ def get_run_name(previous_run_name: str, count: int) -> str:
*name_without_uuid_suffix, _ = previous_run_name.split('-')
name_suffix = ('-'.join(name_without_uuid_suffix))

# A run name that is too long will fail a createRun call
if len(name_suffix) > MAX_RUN_NAME_LENGTH:
log.warning(
f'Training run name {name_suffix} may be too long, truncating to {MAX_RUN_NAME_LENGTH} characters'
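For reference, the naming scheme (exercised by the tests added in this commit) produces eval<count>-<base name>, truncating the base name when it exceeds MAX_RUN_NAME_LENGTH:

get_run_name('foo-1234', 0)             # -> 'eval0-foo'
get_run_name(50 * 'foo' + '-1234', 1)   # -> 'eval1-foofoofoo...' (base name truncated)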
@@ -68,7 +75,7 @@ def get_eval_models_dict(
) -> List[Dict[str, Any]]:
name = model.get('name')

cfg_overrides = model.pop('cfg_overrides', {})
cfg_overrides = model.pop('config_overrides', {})
for key in cfg_overrides:
model[key] = cfg_overrides[key]

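The rename from cfg_overrides to config_overrides presumably matches the config_overrides key used for model configs elsewhere in the repo. Conceptually, the overrides are flattened into the top level of each eval model dict; a small illustration with hypothetical values (mirroring the updated test):

model = {
    'name': 'model_example',
    'config_overrides': {'attn_config': {'foo': 'bar'}},
}
overrides = model.pop('config_overrides', {})
model.update(overrides)
# model is now {'name': 'model_example', 'attn_config': {'foo': 'bar'}}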
@@ -83,6 +90,8 @@ def get_eval_models_dict(
class AsyncEval(Callback):
"""Run the eval loop asynchronously as part of a MosaicML platform run.
This callback is currently experimental. The API may change in the future.
Args:
interval: Union[str, int, Time]: The interval describing how often eval runs should be
launched. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
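A minimal usage sketch, assuming the constructor takes just the interval documented above (a real setup typically wires the callback in through the training YAML and requires the job to run on the MosaicML platform):

from llmfoundry.callbacks.async_eval_callback import AsyncEval

# Launch an eval run every 1000 batches; an int is treated as epochs,
# and a composer Time object is also accepted.
async_eval = AsyncEval(interval='1000ba')
# The callback is then passed to the Composer Trainer via callbacks=[async_eval].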
@@ -117,8 +126,8 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
del logger
if all([
state.get_elapsed_duration() is not None,
self.check_interval(state, event),
self.last_launch != state.timestamp.batch,
self.check_interval(state, event), self.last_launch
!= state.timestamp.batch,
dist.get_global_rank() == 0
]):
self.launch_run()
@@ -168,12 +177,13 @@ def get_eval_parameters(
subset_keys.pop('save_folder'),
parameters.get('save_latest_filename', None))

# Update the loggers to use the training run name
for logger, config in subset_keys.get('loggers', {}).items():
if logger == 'wandb':
config['name'] = config.get('name', run_name)
elif logger == 'mlflow':
config['run_name'] = config.get('run_name', run_name)
# TODO: Update this and parametrize step when the composer loggers support
# it. For now, eval runs will be logged to separate experiment tracker runs
# for logger, config in subset_keys.get('loggers', {}).items():
# if logger == 'wandb':
# config['name'] = config.get('name', run_name)
# elif logger == 'mlflow':
# config['run_name'] = config.get('run_name', run_name)

# Create new eval models list
subset_keys['models'] = get_eval_models_dict(
3 changes: 2 additions & 1 deletion tests/callbacks/test_async_eval_callback.py
@@ -16,6 +16,7 @@ def test_get_run_name():
a = get_run_name('foo-1234', 0)
assert a == 'eval0-foo'

# Run name should be truncated
b = get_run_name(50 * 'foo' + '-1234', 1)
assert b == 'eval1-foofoofoofoofoofoofoofoofoofoofoofoofoof'

@@ -58,7 +59,7 @@ def test_fails_when_no_run_name():
'max_seq_len': 3,
'model': {
'name': 'model_example',
'cfg_overrides': {
'config_overrides': {
'attn_config': {
'foo': 'bar'
}
