Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
[not for land] standalone repro of memory leak on float8 + compile + …
Browse files Browse the repository at this point in the history
…tp + sp

Summary:

Test Plan:

```
./test/test_dtensor.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
vkuzo committed May 15, 2024
1 parent cb55df2 commit eac181a
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 0 deletions.
44 changes: 44 additions & 0 deletions benchmarks/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Can we benchmark a transformer block and meaningfully reason
about what inductor is doing?
"""

import torch
import torch.nn as nn

from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
)
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

class M(nn.Module):
    """Minimal module that applies an elementwise cosine to its input."""

    def forward(self, x):
        """Return ``torch.cos`` of ``x``."""
        cosine = torch.cos(x)
        return cosine

def run():
    """Build a tiny Linear + ReLU model, swap it to float8, compile it and
    run a single forward and backward pass on CUDA, printing as it goes."""
    # Inner block of the model; layers tried out while debugging are kept
    # below as comments.
    inner = nn.Sequential(
        # M(),
        nn.Linear(32, 32, bias=False),
        nn.ReLU(),
    )
    model = nn.Sequential(
        # nn.Linear(32, 32),
        # nn.ReLU(),
        inner,
    ).cuda()

    # Called for its effect on `model`; the return value is not used.
    swap_linear_with_float8_linear(model, Float8DynamicLinear)
    print(model)

    compiled = torch.compile(model)
    inp = torch.randn(32, 32, device='cuda').requires_grad_()
    out = compiled(inp)

    print('\n', out, '\n')

    # Reduce to a scalar so backward() can be called without a grad argument.
    out.sum().backward()

    print('done')


# Execute the debug scenario only when this file is run as a script.
if __name__ == '__main__':
    run()
Binary file added dtensor_test_memory.pickle
Binary file not shown.
79 changes: 79 additions & 0 deletions test/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
Float8ColwiseParallel,
Float8RowwiseParallel,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)
from float8_experimental.float8_utils import tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
Expand Down Expand Up @@ -221,6 +228,73 @@ def test_fp8_mlp_tensor_parallelism_base(
tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
)

def get_cuda_mem_allocated_gb():
    """Return ``torch.cuda.max_memory_allocated()`` divided by 1e9.

    Note that this reads the *max* (peak) counter, not the amount that is
    allocated at the moment of the call.
    """
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / 1e9

class EmbLNLinear(nn.Module):
    """Embedding -> LayerNorm -> Linear stack.

    Args:
        dim0: number of embeddings (first argument of ``nn.Embedding``).
        dim1: embedding width; also the LayerNorm shape and the Linear
            input width.
        dim2: Linear output width.
    """

    def __init__(self, dim0, dim1, dim2):
        super().__init__()
        # Submodule names are referenced by name elsewhere (e.g. in a
        # parallelize plan), so they must stay `emb`, `ln` and `fc`.
        self.emb = nn.Embedding(dim0, dim1)
        self.ln = nn.LayerNorm(dim1)
        self.fc = nn.Linear(dim1, dim2)

    def forward(self, x):
        """Apply ``emb``, then ``ln``, then ``fc`` to ``x``."""
        return self.fc(self.ln(self.emb(x)))

def test_fp8_compile_tp_sp_oom(
    mesh: DeviceMesh, size=16, compile: bool = False
):
    """
    A standalone repro of the OOM we observed on LLaMa 3 8B in torchtitan
    with float8, compile, TP and SP on.

    When you run this test you should see a memory leak, as evidenced by
    the printouts of CUDA memory used as well as by tensors not being freed
    in the dumped memory snapshot (``dtensor_test_memory.pickle``).

    Args:
        mesh: device mesh the model is parallelized over.
        size: not read by this function.
        compile: not read by this function; the model is always wrapped in
            ``torch.compile`` because compile is part of the repro.

    TODO: root cause the issue and write a better test once we fix it.
    """

    vocab_size = 128256
    model_dim = 4096
    device = mesh.device_type
    bsz = 1
    world_size = mesh.size()

    m = EmbLNLinear(vocab_size, model_dim, model_dim).cuda()
    m = swap_linear_with_float8_linear(
        m, Float8DynamicLinear, emulate=True
    )

    # Integer token ids; the second dimension scales with the mesh size.
    tokens = torch.ones(
        bsz, model_dim * world_size, device=device, dtype=torch.int64
    )

    m = parallelize_module(
        m,
        mesh,
        {
            "emb": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "ln": SequenceParallel(),
            "fc": Float8ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate(),
                use_local_output=True,
            ),
        },
    )

    m = torch.compile(m, dynamic=False)
    torch.cuda.memory._record_memory_history()
    try:
        for i in range(100):
            print(i, get_cuda_mem_allocated_gb())
            y = m(tokens)
            y.sum().backward()
    finally:
        # Dump the snapshot even when the loop raises (e.g. the OOM this
        # repro is chasing), then stop recording so the history does not
        # keep accumulating for whatever runs after this function.
        torch.cuda.memory._dump_snapshot("dtensor_test_memory.pickle")
        torch.cuda.memory._record_memory_history(enabled=None)
    print('done')


def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
    """Run ``test_fp8_mlp_tensor_parallelism_base`` with ``compile=True``."""
    test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
Expand All @@ -231,6 +305,11 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
# other test files to not use TestCase but instead just add the test
# cases in the main func.
device_mesh = setup_distributed()
test_fp8_compile_tp_sp_oom(device_mesh)
# TODO(before land): remove early return, this is for debugging only
import sys; sys.exit(0)


tests = [
test_scaled_mm,
test_fp8_redistribute,
Expand Down

0 comments on commit eac181a

Please sign in to comment.