diff --git a/benchmarks/debug.py b/benchmarks/debug.py new file mode 100644 index 0000000..22b7195 --- /dev/null +++ b/benchmarks/debug.py @@ -0,0 +1,44 @@ +""" +Can we benchmark a transformer block and meaningfully reason +about what inductor is doing? +""" + +import torch +import torch.nn as nn + +from float8_experimental.float8_linear_utils import ( + swap_linear_with_float8_linear, +) +from float8_experimental.float8_dynamic_linear import Float8DynamicLinear + +class M(nn.Module): + def forward(self, x): + return torch.cos(x) + +def run(): + + m = nn.Sequential( + # nn.Linear(32, 32), + # nn.ReLU(), + nn.Sequential( + # M(), + nn.Linear(32, 32, bias=False), + nn.ReLU(), + ), + ).cuda() + swap_linear_with_float8_linear(m, Float8DynamicLinear) + print(m) + + m = torch.compile(m) + x = torch.randn(32, 32, device='cuda').requires_grad_() + y = m(x) + + print('\n', y, '\n') + + y.sum().backward() + + print('done') + + +if __name__ == '__main__': + run() diff --git a/dtensor_test_memory.pickle b/dtensor_test_memory.pickle new file mode 100644 index 0000000..7d5afff Binary files /dev/null and b/dtensor_test_memory.pickle differ diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 1f741b8..990e82d 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -23,6 +23,13 @@ Float8ColwiseParallel, Float8RowwiseParallel, ) +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) from float8_experimental.float8_utils import tensor_to_scale from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh, init_device_mesh @@ -221,6 +228,73 @@ def test_fp8_mlp_tensor_parallelism_base( tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad ) +def get_cuda_mem_allocated_gb(): + return torch.cuda.max_memory_allocated() / 1e9 + +class EmbLNLinear(nn.Module): + def __init__(self, dim0, dim1, dim2): + super().__init__() + self.emb = nn.Embedding(dim0, dim1) + self.ln = nn.LayerNorm(dim1) + self.fc = nn.Linear(dim1, dim2) + + def forward(self, x): + x = self.emb(x) + x = self.ln(x) + x = self.fc(x) + return x + +def test_fp8_compile_tp_sp_oom( + mesh: DeviceMesh, size=16, compile: bool = False +): + """ + A standalone repro of the OOM we observed on LLaMa 3 8B in torchtitan + with float8, compile, TP and SP on. When you run this test you should + see a memory leak, as evidenced by printouts of cuda memory used as well + as tensors not beeing freed in dumped the memory snapshot. + + TODO: root cause the issue and write a better test once we fix it. + """ + + vocab_size = 128256 + model_dim = 4096 + device = mesh.device_type + bsz = 1 + world_size = mesh.size() + + m = EmbLNLinear(vocab_size, model_dim, model_dim).cuda() + m = swap_linear_with_float8_linear( + m, Float8DynamicLinear, emulate=True + ) + + tokens = torch.ones(bsz, model_dim * world_size, device=device, dtype=torch.int64) + + m = parallelize_module( + m, + mesh, + { + "emb": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "ln": SequenceParallel(), + "fc": Float8ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Replicate(), + use_local_output=True, + ), + }, + ) + + m = torch.compile(m, dynamic=False) + torch.cuda.memory._record_memory_history() + for i in range(100): + print(i, get_cuda_mem_allocated_gb()) + y = m(tokens) + y.sum().backward() + torch.cuda.memory._dump_snapshot("dtensor_test_memory.pickle") + print('done') + def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) @@ -231,6 +305,11 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): # other test files to not use TestCase but instead just add the test # cases in the main func. device_mesh = setup_distributed() + test_fp8_compile_tp_sp_oom(device_mesh) + # TODO(before land): remove early return, this is for debugging only + import sys; sys.exit(0) + + tests = [ test_scaled_mm, test_fp8_redistribute,