diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py
index 963c596e76..40342f2ddb 100644
--- a/llmfoundry/models/utils/config_moe_args.py
+++ b/llmfoundry/models/utils/config_moe_args.py
@@ -177,9 +177,9 @@ def config_megablocks_moe_args(
                 lbl_process_group = create_set_process_group(lbl_process_group)
             else:
                 lbl_process_group = None
-        elif lbl_process_group is not None:
+        elif not isinstance(lbl_process_group, distributed.ProcessGroup):
             raise ValueError(
-                f'Unknown {lbl_process_group=}. Options are: none | expert_group | global_group | .',
+                f'Unknown {lbl_process_group=}. Options are: none | a process group | ``expert_group`` | ``global_group``.',
             )
         ffn_config['lbl_process_group'] = lbl_process_group
 
diff --git a/tests/models/utils/test_config_moe_args.py b/tests/models/utils/test_config_moe_args.py
new file mode 100644
index 0000000000..426363d2c3
--- /dev/null
+++ b/tests/models/utils/test_config_moe_args.py
@@ -0,0 +1,30 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+import pytest
+
+from llmfoundry.models.utils.config_moe_args import (
+    config_megablocks_moe_args,
+    get_megablocks_device_mesh,
+)
+
+
+@pytest.mark.gpu
+def test_config_megablocks_moe_args_error():
+    ffn_config_base: dict[str, Any] = {
+        'moe_world_size': 1,
+        'lbl_process_group': 'not_real',
+        'ffn_type': 'mb_moe',
+        'fc_type': 'torch',
+    }
+
+    with pytest.raises(ValueError):
+        config_megablocks_moe_args(
+            ffn_config=ffn_config_base,
+            d_model=128,
+            expansion_ratio=4,
+            n_layers=2,
+            get_device_mesh=get_megablocks_device_mesh,
+        )
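
With the isinstance check, ffn_config['lbl_process_group'] may now be a pre-constructed torch.distributed.ProcessGroup rather than only the string options handled above it; only unrecognized values such as 'not_real' still raise ValueError. A minimal sketch of that usage, assuming megablocks is installed and the process is launched under torchrun (the backend choice and group construction here are illustrative, not part of this PR):

# Sketch: passing a pre-built ProcessGroup as lbl_process_group.
# Assumes torch.distributed can be initialized (e.g. via torchrun);
# 'gloo' is chosen only so the sketch does not require GPUs.
from typing import Any

import torch.distributed as dist

from llmfoundry.models.utils.config_moe_args import (
    config_megablocks_moe_args,
    get_megablocks_device_mesh,
)

dist.init_process_group(backend='gloo')

# Any ProcessGroup instance now passes the isinstance check.
custom_group = dist.new_group(ranks=list(range(dist.get_world_size())))

ffn_config: dict[str, Any] = {
    'moe_world_size': 1,
    'lbl_process_group': custom_group,  # previously rejected by `is not None`
    'ffn_type': 'mb_moe',
    'fc_type': 'torch',
}

config_megablocks_moe_args(
    ffn_config=ffn_config,
    d_model=128,
    expansion_ratio=4,
    n_layers=2,
    get_device_mesh=get_megablocks_device_mesh,
)

Checking isinstance against distributed.ProcessGroup instead of `is not None` is what enables this pass-through: the old condition treated any value that survived the string/int branches as an error, including a valid process group.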