From b5cad8dd74fce30ca2117b04eb37b782bc7c6f7b Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Jul 2024 15:11:27 -0700 Subject: [PATCH 1/7] add unit test for FSDP2 + torch.compile(transformer block) Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/fsdp_utils.py | 6 ++++-- .../{test_fsdp2_eager.py => test_fsdp2.py} | 17 ++++++++++++++--- test/test_fsdp2/test_fsdp2_common.py | 8 ++++++-- 3 files changed, 24 insertions(+), 7 deletions(-) rename test/test_fsdp2/{test_fsdp2_eager.py => test_fsdp2.py} (96%) diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 81d53b5..c7eb2c0 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -64,7 +64,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) scales = torch.split(scale_tensor, 1) # Replicate for scale, float8_linear in zip(scales, float8_linears): - float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor + float8_linear.weight._local_tensor._precomputed_scale = ( + scale._local_tensor.squeeze() + ) # FSDP pads its local tensor on dim-0. The subclass should be preserved such @@ -301,7 +303,7 @@ def __tensor_flatten__(self): ], { "mm_config": self._mm_config, - "is_amax_initialized": is_amax_initialized, + "is_amax_initialized": self.is_amax_initialized, }, ) diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2.py similarity index 96% rename from test/test_fsdp2/test_fsdp2_eager.py rename to test/test_fsdp2/test_fsdp2.py index 91c629f..734652d 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2.py @@ -89,6 +89,13 @@ def test_transformer_parity(self): TensorScalingType.DYNAMIC, TensorScalingType.DELAYED, ], + "compile_transformer_block": [False, True], + # "enable_fsdp_fp8_all_gather": [True], + # "precompute": [True], + # "scaling_type_w": [ + # TensorScalingType.DYNAMIC, + # ], + # "compile_transformer_block": [True], }, self._test_transformer_parity, ) @@ -98,6 +105,7 @@ def _test_transformer_parity( enable_fsdp_fp8_all_gather: bool, precompute: bool, scaling_type_w: TensorScalingType, + compile_transformer_block: bool, ): if not enable_fsdp_fp8_all_gather and precompute: return @@ -114,9 +122,11 @@ def _test_transformer_parity( swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w) with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w) - for submodule in module.modules(): - if isinstance(submodule, TransformerBlock): - fully_shard(submodule) + for layer_id, transformer_block in module.layers.named_children(): + if compile_transformer_block: + transformer_block = torch.compile(transformer_block, dynamic=False) + fully_shard(transformer_block) + module.layers.register_module(layer_id, transformer_block) fully_shard(module) ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) @@ -132,6 +142,7 @@ def _test_transformer_parity( local_inp, precompute, scaling_type_w=scaling_type_w, + compile_transformer_block=compile_transformer_block, ) @skip_if_lt_x_gpu(2) diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 2638401..182904a 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -6,7 +6,7 @@ import torch import 
torch.distributed as dist import torch.nn as nn -from float8_experimental.float8_linear import Float8Linear, TensorScalingType +from float8_experimental.float8_linear import TensorScalingType from float8_experimental.float8_linear_utils import ( linear_requires_sync, sync_float8_amax_and_scale_history, @@ -23,6 +23,7 @@ def check_parity_no_mp( local_inp: torch.Tensor, precompute: bool = False, scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC, + compile_transformer_block: bool = False, ): for iter_idx in range(10): losses: List[torch.Tensor] = [] @@ -46,7 +47,10 @@ def check_parity_no_mp( ): precompute_float8_dynamic_scale_for_fsdp(model) - test_cls.assertEqual(losses[0], losses[1]) + if compile_transformer_block: + torch.testing.assert_close(losses[0], losses[1], atol=9.5e-2, rtol=9.5e-2) + else: + test_cls.assertEqual(losses[0], losses[1]) def check_parity_bf16_mp( From 272e85b42c61f297ebe1d49716e3a61bbf621d5d Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Jul 2024 15:16:38 -0700 Subject: [PATCH 2/7] remove debug lines Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/test_fsdp2/test_fsdp2.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py index 734652d..393346c 100644 --- a/test/test_fsdp2/test_fsdp2.py +++ b/test/test_fsdp2/test_fsdp2.py @@ -90,12 +90,6 @@ def test_transformer_parity(self): TensorScalingType.DELAYED, ], "compile_transformer_block": [False, True], - # "enable_fsdp_fp8_all_gather": [True], - # "precompute": [True], - # "scaling_type_w": [ - # TensorScalingType.DYNAMIC, - # ], - # "compile_transformer_block": [True], }, self._test_transformer_parity, ) From 097ceed44a4fdd039e74958ea6fef37cd57729b8 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Jul 2024 15:26:24 -0700 Subject: [PATCH 3/7] fix linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index 9ad76f7..215a394 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, Optional, Tuple - import torch from float8_experimental.float8_tensor import ( From b6ebf8de0555fe078d89a20d977708538f55e293 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Jul 2024 17:01:53 -0700 Subject: [PATCH 4/7] numeric baseline against compiled model Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/test_fsdp2/test_fsdp2.py | 4 ++++ test/test_fsdp2/test_fsdp2_common.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py index 393346c..1cbec77 100644 --- a/test/test_fsdp2/test_fsdp2.py +++ b/test/test_fsdp2/test_fsdp2.py @@ -114,6 +114,10 @@ def _test_transformer_parity( module = self.init_transformer(weight_tying=weight_tying).cuda() ref_module = copy.deepcopy(module) swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w) + if compile_transformer_block: + for layer_id, transformer_block in ref_module.layers.named_children(): + transformer_block = torch.compile(transformer_block, dynamic=False) + ref_module.layers.register_module(layer_id, transformer_block) with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w) for layer_id, transformer_block in module.layers.named_children(): diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 182904a..61edac9 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -48,7 +48,7 @@ def check_parity_no_mp( precompute_float8_dynamic_scale_for_fsdp(model) if compile_transformer_block: - torch.testing.assert_close(losses[0], losses[1], atol=9.5e-2, rtol=9.5e-2) + test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4) else: test_cls.assertEqual(losses[0], losses[1]) From 2eaa51b7239b1806e246ccca8c81dfcf927027ab Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Jul 2024 17:11:41 -0700 Subject: [PATCH 5/7] update README and CI Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- README.md | 2 +- test/test_everything.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6102f5a..28867f0 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ pytest test/test_numerics_integration.py ./test/test_dtensor.sh # run integration tests on the FSDP2 integration -python test/test_fsdp2/test_fsdp2_eager.py +python test/test_fsdp2/test_fsdp2.py # run all of these tests ./test/test_everything.sh diff --git a/test/test_everything.sh b/test/test_everything.sh index 5eeb17c..72ca42d 100755 --- a/test/test_everything.sh +++ b/test/test_everything.sh @@ -15,7 +15,7 @@ then ./test/test_fsdp.sh ./test/test_fsdp_compile.sh ./test/test_dtensor.sh -pytest test/test_fsdp2/test_fsdp2_eager.py +pytest test/test_fsdp2/test_fsdp2.py fi echo "all tests successful" From cc763ceda7a87105ce698af888b0d68003cd175d Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 24 Jul 2024 00:28:46 -0700 Subject: [PATCH 6/7] fix float8 all-gather in 2d Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/fsdp_utils.py | 18 +++++++++-- test/test_fsdp2/test_fsdp2.py | 53 ++++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 04cd797..876ae93 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -81,6 +81,8 @@ def 
precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: torch.ops.aten.as_strided.default, torch.ops.aten._to_copy.default, torch.ops.aten._pin_memory.default, + torch.ops.aten.split.Tensor, + torch.ops.aten.clone.default, } @@ -188,12 +190,22 @@ def fsdp_post_all_gather( *, out: Optional[torch.Tensor] = None, ): + from torch.distributed._tensor import DTensor + (data,) = all_gather_outputs (scale,) = metadata if out is not None: - assert isinstance(out, Float8Tensor), f"{type(out)}" - out._scale = scale - return + if isinstance(out, Float8Tensor): + out._scale = scale + elif isinstance(out, DTensor) and isinstance( + out._local_tensor, Float8Tensor + ): + out._local_tensor._scale = scale + else: + raise RuntimeError( + f"out must be a Float8Tensor or DTensor with Float8Tensor local tensor, but got {type(out)}" + ) + return out return Float8Tensor( data, scale, diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py index 1cbec77..b846fae 100644 --- a/test/test_fsdp2/test_fsdp2.py +++ b/test/test_fsdp2/test_fsdp2.py @@ -17,7 +17,13 @@ set_enable_fsdp_fp8_all_gather, ) from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy -from torch.distributed._tensor import DTensor +from torch.distributed._tensor import ( + distribute_tensor, + DTensor, + init_device_mesh, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( @@ -516,5 +522,50 @@ def test_delayed_scaling_inplace_update(self): self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item()) +class Test2DFloat8MultiProcess(FSDPTest, TestFloat8Common): + @property + def world_size(self) -> int: + return min(torch.cuda.device_count(), 4) + + def init_global_mesh(self) -> DeviceMesh: + dp_size = 2 if self.world_size > 2 else 1 + return init_device_mesh( + "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp") + ) + + @skip_if_lt_x_gpu(4) + def test_fsdp_tp( + self, + ): + enable_fsdp_fp8_all_gather = True + scaling_type_w = TensorScalingType.DYNAMIC + global_mesh = self.init_global_mesh() + _, tp_mesh = global_mesh["dp"], global_mesh["tp"] + module = self.init_transformer(weight_tying=False).cuda() + with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): + swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w) + + # "attention.wq": Float8ColwiseParallel + colwise_param = distribute_tensor( + module.layers[0].attention.wq.weight, tp_mesh, [Shard(0)] + ) + self.assertTrue( + isinstance(colwise_param, DTensor) + and isinstance( + colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor + ) + ) + # "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)), + rowwise_param = distribute_tensor( + module.layers[0].attention.wo.weight, tp_mesh, [Shard(1)] + ) + self.assertTrue( + isinstance(rowwise_param, DTensor) + and isinstance( + rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor + ) + ) + + if __name__ == "__main__": run_tests() From 7fbb8672c0c5ba54d60e66ecf5be52fa63c5a3d4 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Wed, 31 Jul 2024 19:42:39 -0700 Subject: [PATCH 7/7] tested successfully Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_linear.py | 15 ++- float8_experimental/float8_tensor_parallel.py | 41 +++---- float8_experimental/fsdp_utils.py | 6 +- test/test_fsdp2/test_fsdp2.py | 101 
+++++++++++++++++- 4 files changed, 132 insertions(+), 31 deletions(-) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 37cf0d5..64ed87f 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -357,12 +357,9 @@ def cast_w_to_float8( ) else: assert self.scaling_type_w is TensorScalingType.DYNAMIC - if isinstance(self.weight, Float8Tensor): # cast by FSDP - w_fp8 = self.weight - else: - w_fp8 = cast_to_float8_e4m3_dynamic( - self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W - ) + w_fp8 = cast_to_float8_e4m3_dynamic( + self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W + ) return w_fp8 def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: @@ -407,8 +404,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.has_any_delayed_scaling: self.float8_pre_forward(input) - x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized) - w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized) + with torch.profiler.record_function("cast_x_to_float8"): + x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized) + with torch.profiler.record_function("cast_w_to_float8"): + w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized) y = torch.matmul(x_fp8, w_fp8.t()) diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 4c5297c..899d93b 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -44,12 +44,13 @@ def _prepare_input_fn( input_tensor = DTensor.from_local( input_tensor, device_mesh, input_layouts, run_check=False ) - - input_tensor = cast_to_float8_e4m3_dynamic( - input_tensor, - mod.linear_mm_config, - gemm_input_role=GemmInputRole.X, - ) # DTensor(Float8Tensor) + + with torch.profiler.record_function("colwise_cast_to_float8_e4m3_dynamic"): + input_tensor = cast_to_float8_e4m3_dynamic( + input_tensor, + mod.linear_mm_config, + gemm_input_role=GemmInputRole.X, + ) # DTensor(Float8Tensor) # transform the input layouts to the desired layouts of ColwiseParallel if input_layouts != desired_input_layouts: @@ -67,7 +68,8 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config) + with torch.profiler.record_function("colwise_cast_to_float8_e5m2_dynamic_bw"): + outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -98,11 +100,12 @@ def _prepare_input_fn( input_tensor, device_mesh, input_layouts, run_check=False ) - input_tensor = cast_to_float8_e4m3_dynamic( - input_tensor, - mod.linear_mm_config, - gemm_input_role=GemmInputRole.X, - ) # DTensor(Float8Tensor) + with torch.profiler.record_function("rowwise_cast_to_float8_e4m3_dynamic"): + input_tensor = cast_to_float8_e4m3_dynamic( + input_tensor, + mod.linear_mm_config, + gemm_input_role=GemmInputRole.X, + ) # DTensor(Float8Tensor) if input_layouts != desired_input_layouts: input_tensor = input_tensor.redistribute( @@ -119,7 +122,8 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, 
mod.linear_mm_config)
+        with torch.profiler.record_function("rowwise_cast_to_float8_e5m2_dynamic_bw"):
+            outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)
 
         # back to local tensor if use_local_output is True
         return outputs.to_local() if use_local_output else outputs
@@ -196,11 +200,12 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
                     input, mesh, (input_layout,), run_check=False
                 )
 
-            dt_inp = cast_to_float8_e4m3_dynamic(
-                dt_inp,
-                self.linear_mm_config,
-                gemm_input_role=GemmInputRole.X,
-            )  # DTensor(Float8Tensor)
+            with torch.profiler.record_function("prepareinput_cast_to_float8_e4m3_dynamic"):
+                dt_inp = cast_to_float8_e4m3_dynamic(
+                    dt_inp,
+                    self.linear_mm_config,
+                    gemm_input_role=GemmInputRole.X,
+                )  # DTensor(Float8Tensor)
             if desired_layout is not None and input_layout != desired_layout:
                 dt_inp = dt_inp.redistribute(placements=(desired_layout,))
 
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index 876ae93..d228189 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -203,7 +203,7 @@ def fsdp_post_all_gather(
                 out._local_tensor._scale = scale
             else:
                 raise RuntimeError(
-                    f"out must be a Float8Tensor or DTensor with Float8Tensor local tensor, but got {type(out)}"
+                    f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {type(out)}"
                 )
             return out
         return Float8Tensor(
diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py
index b846fae..6005834 100644
--- a/test/test_fsdp2/test_fsdp2.py
+++ b/test/test_fsdp2/test_fsdp2.py
@@ -16,6 +16,18 @@
     check_parity_no_mp,
     set_enable_fsdp_fp8_all_gather,
 )
+from float8_experimental.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+    PrepareFloat8ModuleInput,
+)
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+    SequenceParallel,
+)
+from torch.distributed._tensor import Replicate, Shard
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import (
     distribute_tensor,
@@ -532,6 +544,77 @@ def init_global_mesh(self) -> DeviceMesh:
         return init_device_mesh(
             "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
         )
+
+    def parallelize(
+        self, module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool
+    ) -> nn.Module:
+        assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
+        # Shard the embeddings and the final norm; the per-block plans below assume
+        # sequence-parallel (Shard(1)) activations between transformer blocks.
+        module_tp = parallelize_module(
+            module,
+            device_mesh,
+            {
+                "tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(1)
+                ),
+                "pos_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(0)
+                ),
+                "norm": SequenceParallel(),
+            },
+        )
+        for transformer_block in module_tp.layers:
+            layer_plan = {
+                # shard the RMSNorms
+                "attention_norm": SequenceParallel(),
+                "ffn_norm": SequenceParallel(),
+                "attention": PrepareFloat8ModuleInput(
+                    input_layouts=(Shard(1), None),
+                    desired_input_layouts=(Replicate(), None),
+                ),
+                "attention.wq": Float8ColwiseParallel(),
+                "attention.wk": Float8ColwiseParallel(),
+                "attention.wv": Float8ColwiseParallel(),
+                "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
+                "feed_forward": PrepareFloat8ModuleInput(
+                    input_layouts=(Shard(1),),
+                    desired_input_layouts=(Replicate(),),
+                ),
+                "feed_forward.w1": Float8ColwiseParallel(),
+                "feed_forward.w2": Float8RowwiseParallel(output_layouts=Shard(1)),
+                "feed_forward.w3": Float8ColwiseParallel(),
+            }
+
+            # Adjust attention module to use the local number of heads
+            attn_layer = transformer_block.attention
+            attn_layer.n_heads = attn_layer.n_heads // device_mesh.size()
+            attn_layer.n_kv_heads = attn_layer.n_kv_heads // device_mesh.size()
+
+            parallelize_module(
+                module=transformer_block,
+                device_mesh=device_mesh,
+                parallelize_plan=layer_plan,
+            )
+
+        # Parallelize the output submodule. If weight tying is enabled, we need to
+        # make sure output.weight is sharded consistently as tok_embeddings.weight,
+        # at the cost of the all_reduce operation using RowwiseParallel.
+        output_parallelize_plan = (
+            ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate(),
+            )
+            if use_seq_parallel
+            else ColwiseParallel(output_layouts=Replicate())
+        )
+        parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)
+
+        # Manually set output.weight so that parameters and gradients are shared.
+        if module_tp.model_args.weight_tying:
+            module_tp.output.weight = module_tp.tok_embeddings.weight
+
+        return module_tp
 
     @skip_if_lt_x_gpu(4)
     def test_fsdp_tp(
@@ -541,13 +624,13 @@ def test_fsdp_tp(
         scaling_type_w = TensorScalingType.DYNAMIC
         global_mesh = self.init_global_mesh()
         _, tp_mesh = global_mesh["dp"], global_mesh["tp"]
-        module = self.init_transformer(weight_tying=False).cuda()
+        model = self.init_transformer(weight_tying=False).cuda()
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
+            swap_linear_with_float8_linear(model, scaling_type_w=scaling_type_w)
 
         # "attention.wq": Float8ColwiseParallel
         colwise_param = distribute_tensor(
-            module.layers[0].attention.wq.weight, tp_mesh, [Shard(0)]
+            model.layers[0].attention.wq.weight, tp_mesh, [Shard(0)]
         )
         self.assertTrue(
             isinstance(colwise_param, DTensor)
             and isinstance(
                 colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
             )
         )
         # "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
         rowwise_param = distribute_tensor(
-            module.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
+            model.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
         )
         self.assertTrue(
             isinstance(rowwise_param, DTensor)