Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[not for land] standalone repro of memory leak on float8 + compile + … #260

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions test/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
Float8ColwiseParallel,
Float8RowwiseParallel,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)
from float8_experimental.float8_utils import tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
Expand Down Expand Up @@ -221,6 +228,73 @@ def test_fp8_mlp_tensor_parallelism_base(
tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
)

def get_cuda_mem_allocated_gb():
    """Return the peak CUDA memory allocated so far, in gigabytes.

    The value comes from ``torch.cuda.max_memory_allocated()``, so it is a
    high-water mark rather than the amount currently live, and it is
    converted with a decimal factor (1 GB == 1e9 bytes).
    """
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / 1e9

class EmbLNLinear(nn.Module):
    """Toy model: an embedding lookup, then LayerNorm, then a linear layer.

    Args:
        dim0: number of rows in the embedding table.
        dim1: embedding width; also the LayerNorm shape and the linear
            layer's input features.
        dim2: output features of the linear layer.
    """

    def __init__(self, dim0, dim1, dim2):
        super().__init__()
        # The attribute names `emb`, `ln` and `fc` are used as keys in the
        # parallelize plan of test_fp8_compile_tp_sp_oom, so keep them stable.
        self.emb = nn.Embedding(num_embeddings=dim0, embedding_dim=dim1)
        self.ln = nn.LayerNorm(normalized_shape=dim1)
        self.fc = nn.Linear(in_features=dim1, out_features=dim2)

    def forward(self, x):
        """Apply ``emb`` -> ``ln`` -> ``fc`` to ``x`` and return the result."""
        return self.fc(self.ln(self.emb(x)))

def test_fp8_compile_tp_sp_oom(
    mesh: DeviceMesh, size=16, compile: bool = False
):
    """
    A standalone repro of the OOM we observed on LLaMa 3 8B in torchtitan
    with float8, compile, TP and SP on. When you run this test you should
    see a memory leak, as evidenced by printouts of CUDA memory used as well
    as tensors not being freed in the dumped memory snapshot.

    Args:
        mesh: device mesh the model is parallelized over; its size also
            scales the sequence length of the input tokens.
        size: unused; kept so the signature matches the other tests in
            this file.
        compile: unused; the model is always wrapped in ``torch.compile``
            because compilation is part of the repro.

    Side effects:
        Prints the iteration index and peak CUDA memory (GB) before each
        step, and writes ``dtensor_test_memory.pickle`` to the current
        working directory after the loop.

    TODO: root cause the issue and write a better test once we fix it.
    """

    vocab_size = 128256
    model_dim = 4096
    device = mesh.device_type
    bsz = 1
    world_size = mesh.size()

    # Place the model on the mesh's device type so that it agrees with where
    # `tokens` is allocated below (previously a hard-coded `.cuda()`).
    m = EmbLNLinear(vocab_size, model_dim, model_dim).to(device)
    m = swap_linear_with_float8_linear(
        m, Float8DynamicLinear, emulate=True
    )

    tokens = torch.ones(
        bsz, model_dim * world_size, device=device, dtype=torch.int64
    )

    m = parallelize_module(
        m,
        mesh,
        {
            "emb": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "ln": SequenceParallel(),
            "fc": Float8ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate(),
                use_local_output=True,
            ),
        },
    )

    m = torch.compile(m, dynamic=False)
    torch.cuda.memory._record_memory_history()
    try:
        for i in range(100):
            # get_cuda_mem_allocated_gb() reports the peak so far, so a
            # steadily growing number here is the symptom being reproduced.
            print(i, get_cuda_mem_allocated_gb())
            y = m(tokens)
            y.sum().backward()
        # NOTE(review): every process that runs this function writes to the
        # same relative path; if ranks share a working directory they may
        # overwrite each other's snapshot -- confirm and add a per-rank
        # suffix if needed.
        torch.cuda.memory._dump_snapshot("dtensor_test_memory.pickle")
    finally:
        # Always stop recording, even if an iteration raises (e.g. the OOM
        # itself), so history tracking does not stay enabled for the rest
        # of the process.
        torch.cuda.memory._record_memory_history(enabled=None)
    print('done')


def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
    """Run the base fp8 MLP tensor-parallelism test with ``compile=True``."""
    test_fp8_mlp_tensor_parallelism_base(
        mesh,
        size,
        compile=True,
    )
Expand All @@ -231,6 +305,11 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
# other test files to not use TestCase but instead just add the test
# cases in the main func.
device_mesh = setup_distributed()
test_fp8_compile_tp_sp_oom(device_mesh)
# TODO(before land): remove early return, this is for debugging only
import sys; sys.exit(0)


tests = [
test_scaled_mm,
test_fp8_redistribute,
Expand Down
Loading