From fa2db9211667eaed4f75c4c8c44501b672d3e706 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 15 May 2024 12:30:33 -0700
Subject: [PATCH] [not for land] standalone repro of memory leak on float8 + compile + tp + sp

Summary:

Test Plan:

```
./test/test_dtensor.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/test_dtensor.py | 79 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 79 insertions(+)

diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index 1f741b8..990e82d 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -23,6 +23,13 @@
     Float8ColwiseParallel,
     Float8RowwiseParallel,
 )
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
 from float8_experimental.float8_utils import tensor_to_scale
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
@@ -221,6 +228,73 @@ def test_fp8_mlp_tensor_parallelism_base(
         tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
     )
 
+def get_cuda_max_memory_allocated_gb():
+    return torch.cuda.max_memory_allocated() / 1e9
+
+class EmbLNLinear(nn.Module):
+    def __init__(self, dim0, dim1, dim2):
+        super().__init__()
+        self.emb = nn.Embedding(dim0, dim1)
+        self.ln = nn.LayerNorm(dim1)
+        self.fc = nn.Linear(dim1, dim2)
+
+    def forward(self, x):
+        x = self.emb(x)
+        x = self.ln(x)
+        x = self.fc(x)
+        return x
+
+def test_fp8_compile_tp_sp_oom(
+    mesh: DeviceMesh,
+):
+    """
+    A standalone repro of the OOM we observed on Llama 3 8B in torchtitan
+    with float8, compile, TP and SP enabled. When you run this test you should
+    see a memory leak, evidenced both by the printed CUDA memory usage growing
+    over iterations and by tensors not being freed in the dumped memory snapshot.
+
+    TODO: root cause the issue and write a better test once we fix it.
+    """
+
+    vocab_size = 128256
+    model_dim = 4096
+    device = mesh.device_type
+    bsz = 1
+    world_size = mesh.size()
+
+    m = EmbLNLinear(vocab_size, model_dim, model_dim).cuda()
+    m = swap_linear_with_float8_linear(
+        m, Float8DynamicLinear, emulate=True
+    )
+
+    tokens = torch.ones(bsz, model_dim * world_size, device=device, dtype=torch.int64)
+
+    m = parallelize_module(
+        m,
+        mesh,
+        {
+            "emb": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "ln": SequenceParallel(),
+            "fc": Float8ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate(),
+                use_local_output=True,
+            ),
+        },
+    )
+
+    m = torch.compile(m, dynamic=False)
+    torch.cuda.memory._record_memory_history()
+    for i in range(100):
+        print(i, get_cuda_max_memory_allocated_gb())
+        y = m(tokens)
+        y.sum().backward()
+    torch.cuda.memory._dump_snapshot("dtensor_test_memory.pickle")
+    print("done")
+
 
 def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
     test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
@@ -231,6 +305,11 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
     # other test files to not use TestCase but instead just add the test
     # cases in the main func.
     device_mesh = setup_distributed()
+    test_fp8_compile_tp_sp_oom(device_mesh)
+    # TODO(before land): remove this early exit, it is for debugging only
+    import sys; sys.exit(0)
+
+
     tests = [
         test_scaled_mm,
         test_fp8_redistribute,
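
Note, not part of the patch itself: the test dumps its allocator trace to `dtensor_test_memory.pickle`. Below is a minimal sketch of one way to turn that snapshot into an interactive HTML allocation timeline, assuming PyTorch's private `torch.cuda._memory_viz` helpers (whose API may change between releases); the pickle can also be loaded directly at https://pytorch.org/memory_viz.

```
# Sketch only: assumes the private torch.cuda._memory_viz module, whose API may
# differ across PyTorch versions. Renders the snapshot dumped by
# test_fp8_compile_tp_sp_oom into an HTML allocation timeline; the leak shows
# up as memory that keeps growing across the 100 iterations and is never freed.
import pickle

from torch.cuda._memory_viz import trace_plot

with open("dtensor_test_memory.pickle", "rb") as f:
    snapshot = pickle.load(f)

with open("dtensor_test_memory.html", "w") as f:
    f.write(trace_plot(snapshot))
```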