diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py
index c8581049d9..8a9bf7b4cc 100644
--- a/llmfoundry/callbacks/async_eval_callback.py
+++ b/llmfoundry/callbacks/async_eval_callback.py
@@ -73,11 +73,6 @@ def get_eval_models_dict(
     tokenizer: Dict[str, Any],
 ) -> List[Dict[str, Any]]:
     name = model.get('name')
-
-    cfg_overrides = model.pop('config_overrides', {})
-    for key in cfg_overrides:
-        model[key] = cfg_overrides[key]
-
     new_model = {'model_name': name, 'model': model}
 
     if tokenizer:
@@ -125,8 +120,8 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
         del logger
         if all([
                 state.get_elapsed_duration() is not None,
-                self.check_interval(state, event),
-                self.last_launch != state.timestamp.batch,
+                self.check_interval(state, event), self.last_launch
+                != state.timestamp.batch,
                 dist.get_global_rank() == 0
         ]):
             self.launch_run()
diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py
index acdb5f5375..d5169d2327 100644
--- a/tests/callbacks/test_async_eval_callback.py
+++ b/tests/callbacks/test_async_eval_callback.py
@@ -96,8 +96,10 @@ def test_get_eval_parameters():
         'model_name': 'model_example',
         'model': {
             'name': 'model_example',
-            'attn_config': {
-                'foo': 'bar'
+            'config_overrides': {
+                'attn_config': {
+                    'foo': 'bar'
+                },
             },
         },
         'tokenizer': {
@@ -139,8 +141,10 @@ def test_get_eval_parameters():
         'model_name': 'model_example',
         'model': {
             'name': 'model_example',
-            'attn_config': {
-                'foo': 'bar'
+            'config_overrides': {
+                'attn_config': {
+                    'foo': 'bar'
+                },
             },
         },
         'tokenizer': {
@@ -226,9 +230,11 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock,
         'model_name': 'model_example',
         'model': {
             'name': 'model_example',
-            'attn_config': {
-                'foo': 'bar'
-            }
+            'config_overrides': {
+                'attn_config': {
+                    'foo': 'bar'
+                },
+            },
         },
         'tokenizer': {
             'tokenizer_example': 'tokenizer_example'