From 898fc2f924200271fa2a46894a315f1331c5442c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 24 May 2024 00:58:30 -0700 Subject: [PATCH] configurable submesh --- llmfoundry/models/layers/ffn.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 2b62c77eb6..a28725ee0f 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -382,10 +382,30 @@ def attach_ffn_mb_args( ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group +def get_fsdp_submesh_2d(device_mesh: DeviceMesh): + """Get the submesh for FSDP. + + Args: + device_mesh (DeviceMesh): The full device mesh. + + Returns: + DeviceMesh: The submesh for FSDP. + """ + if device_mesh.mesh.ndim == 2: + submesh = device_mesh['weight_parallel'] + elif device_mesh.mesh.ndim == 3: + raise RuntimeError(f'HSDP + MoE is not supported.') + else: + raise ValueError(f'{device_mesh.mesh.ndim=} not supported for MoE.') + + return submesh + + def set_ffn_device_mesh( ffn: nn.Module, moe_world_size: int, device_mesh: DeviceMesh, + get_fsdp_submesh: Callable[[DeviceMesh], DeviceMesh], ): """Sets the device mesh in FSDP kwargs. @@ -413,12 +433,7 @@ def set_ffn_device_mesh( for name, dtensorified_param in dtensorified_params: ffn.experts.mlp.register_parameter(name, dtensorified_param) - if device_mesh.mesh.ndim == 2: - submesh = device_mesh['weight_parallel'] - elif device_mesh.mesh.ndim == 3: - raise RuntimeError(f'HSDP + MoE is not supported.') - else: - raise ValueError(f'{device_mesh.mesh.ndim=} not supported for MoE.') + submesh = get_fsdp_submesh(device_mesh) ffn.experts._fsdp_kwargs_dict = { 'device_mesh': submesh, @@ -470,6 +485,7 @@ def build_mb_moe( ffn=ffn, moe_world_size=moe_world_size, device_mesh=kwargs['device_mesh'], + get_fsdp_submesh=get_fsdp_submesh_2d, ) return ffn @@ -536,6 +552,7 @@ def build_mb_dmoe( ffn=ffn, moe_world_size=moe_world_size, device_mesh=kwargs['device_mesh'], + get_fsdp_submesh=get_fsdp_submesh_2d, ) return ffn