This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

enumerate breakages of torch.compile + Float8Linear + FSDP/TP/SP #168

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion float8_experimental/float8_tensor.py
@@ -122,7 +122,7 @@ def __tensor_flatten__(self):
         return ["_data", "_scale"], ctx
 
     @staticmethod
-    def __tensor_unflatten__(inner_tensors: Dict, metadata):
+    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
         assert len(inner_tensors) == 2
         return Float8Tensor(
             inner_tensors["_data"],
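Context for the change above: when torch.compile traces a wrapper tensor subclass, newer PyTorch versions call `__tensor_unflatten__` with the outer tensor's size and stride in addition to the inner tensors and metadata, so the old two-argument signature no longer matches. A minimal sketch of the updated protocol (hypothetical `ScaledTensor` wrapper shown for illustration, not the real `Float8Tensor`):

```python
import torch
from typing import Dict


class ScaledTensor(torch.Tensor):
    """Hypothetical wrapper subclass holding a data tensor and a scale."""

    @staticmethod
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.size(), dtype=data.dtype, device=data.device
        )

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        # names of the inner tensor attributes, plus any non-tensor metadata
        return ["_data", "_scale"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
        # outer_size / outer_stride describe the outer (wrapper) tensor; this
        # simple sketch ignores them and rebuilds the wrapper from the inner tensors
        return ScaledTensor(inner_tensors["_data"], inner_tensors["_scale"])
```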
38 changes: 38 additions & 0 deletions test/fsdp_utils.py
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import os

import torch
import torch.distributed as dist
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
)

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
    m = nn.Sequential(
        nn.Linear(K, N, dtype=base_dtype),
        nn.Linear(N, N, dtype=base_dtype),
    )
    if is_fp8:
        swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
    return m
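For readers following along, the `swap_linear_with_float8_linear` call above is a module-tree rewrite: it walks the model and replaces each `nn.Linear` with the float8 variant in place. A generic sketch of that pattern (hypothetical `to_float8` factory; this is not the library's actual implementation):

```python
import torch.nn as nn


def swap_linears(module: nn.Module, to_float8) -> nn.Module:
    """Recursively replace every nn.Linear child with to_float8(child)."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, to_float8(child))
        else:
            swap_linears(child, to_float8)
    return module
```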

1 change: 1 addition & 0 deletions test/test_everything.sh
@@ -7,6 +7,7 @@ pytest test/test_base.py
 pytest test/test_sam.py
 pytest test/test_compile.py
 ./test/test_fsdp.sh
+./test/test_fsdp_compile.sh
 ./test/test_tp.sh
 
 echo "all tests successful"
24 changes: 2 additions & 22 deletions test/test_fsdp.py
@@ -33,6 +33,8 @@
     StateDictType,
 )
 
+from fsdp_utils import setup, cleanup, get_model
+
 torch.manual_seed(0)
 
 # assumes user is running the script from /data/users/{user}/float8_experimental
@@ -50,28 +52,6 @@
 N_ITER = 1
 
 
-def setup(rank, world_size):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12355"
-
-    # initialize the process group
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-
-
-def cleanup():
-    dist.destroy_process_group()
-
-
-def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
-    m = nn.Sequential(
-        nn.Linear(K, N, dtype=base_dtype),
-        nn.Linear(N, N, dtype=base_dtype),
-    )
-    if is_fp8:
-        swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
-    return m
-
-
 # taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
 # and modified
 def fsdp_main(rank, world_size, args):
138 changes: 138 additions & 0 deletions test/test_fsdp_compile.py
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Smoke tests of FSDP + compile + Float8Linear
"""

import os
import warnings

import fire

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from float8_experimental.float8_linear_utils import (
    sync_float8_amax_and_scale_history,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from fsdp_utils import setup, cleanup, get_model

torch.manual_seed(0)

B, M, K, N = 8, 8, 32, 32
lr = 0.01
N_ITER = 3


def test_no_compile(world_size, emulate, base_dtype, rank, ref_input_local):
    model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=base_dtype).to(
        rank
    )
    model = FSDP(model, use_orig_params=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)

    for _ in range(N_ITER):
        optimizer.zero_grad()
        y_local = model(ref_input_local)
        y_local.sum().backward()
        sync_float8_amax_and_scale_history(model)
        optimizer.step()

    dist.barrier()

def test_fsdp_then_compile_with_workaround(world_size, emulate, base_dtype, rank, ref_input_local):
    model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=base_dtype).to(
        rank
    )
    model = FSDP(model, use_orig_params=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
    sync_func = torch.compile(sync_float8_amax_and_scale_history)

    for i in range(N_ITER):
        optimizer.zero_grad()
        y_local = model(ref_input_local)
        y_local.sum().backward()
        sync_func(model)
        optimizer.step()

        if i == 0:
            # right now things only work if we compile after the first iteration
            # otherwise, we get https://gist.github.com/vkuzo/665e27a4d362f3999ad9a9e786acbe02
            # TODO(future): fix this
            model = torch.compile(model)

    dist.barrier()

def test_compile_then_fsdp(world_size, emulate, base_dtype, rank, ref_input_local):
    model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=base_dtype).to(
        rank
    )
    model = torch.compile(model)
    model = FSDP(model, use_orig_params=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
    sync_func = torch.compile(sync_float8_amax_and_scale_history)

    for _ in range(N_ITER):
        optimizer.zero_grad()
        y_local = model(ref_input_local)
        y_local.sum().backward()
        sync_func(model)
        optimizer.step()

    dist.barrier()


# taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
# and modified
def fsdp_main(rank, world_size, args):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    emulate, = args
    base_dtype = torch.bfloat16
    ref_input_global = torch.randn(B, M, K).cuda().to(base_dtype)
    # basic distributed data sampling
    bsz_global = ref_input_global.shape[0]
    assert B % world_size == 0
    bsz_local_start = int(rank / world_size * B)
    bsz_local_end = int((rank + 1) / world_size * B)
    ref_input_local = ref_input_global[bsz_local_start:bsz_local_end].to(rank)

    test_args = world_size, emulate, base_dtype, rank, ref_input_local

    test_no_compile(*test_args)
    # TODO(future): remove the workaround
    test_fsdp_then_compile_with_workaround(*test_args)
    # TODO(future): unbreak this if needed
    # test_compile_then_fsdp(*test_args)
    # fails with https://gist.github.com/vkuzo/d7c65a073ebf47d64aa5b1a56df171c6

    cleanup()


def run():

    emulate = False
    if not torch.cuda.is_available():
        warnings.warn("CUDA not available, running in emulation mode")
        emulate = True
    elif torch.cuda.get_device_capability() < (9, 0):
        warnings.warn(
            f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
        )
        emulate = True

    WORLD_SIZE = torch.cuda.device_count()
    args = (emulate,)
    mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)


if __name__ == "__main__":
    fire.Fire(run)
7 changes: 7 additions & 0 deletions test/test_fsdp_compile.sh
@@ -0,0 +1,7 @@
#!/bin/bash

# the NCCL_DEBUG setting is to avoid log spew
# the CUDA_VISIBLE_DEVICES setting is for easy debugging
NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/test_fsdp_compile.py

echo "done!"
22 changes: 17 additions & 5 deletions test/test_tp.py
@@ -37,11 +37,14 @@ def setup_model_parallel():
     return local_rank, world_size
 
 
-def test_column_parallel_linear():
+def test_column_parallel_linear(use_float8=True, use_compile=False):
     M, K, N = 128, 64, 256
     m_ref = nn.Sequential(ColumnParallelLinear(K, N)).cuda()
     m = copy.deepcopy(m_ref)
-    swap_tp_linear_with_float8_linear(m)
+    if use_float8:
+        swap_tp_linear_with_float8_linear(m)
+    if use_compile:
+        m = torch.compile(m)
 
     x = torch.randn(M, K, device="cuda")
     y_ref = m_ref(x)
@@ -51,7 +54,7 @@ def test_column_parallel_linear():
     y.sum().backward()
 
     sqnr_y = compute_error(y_ref, y)
-    sqnr_w_grad = compute_error(m_ref[0].weight.grad, m[0].weight.grad)
+    sqnr_w_grad = compute_error(m_ref[0].weight.grad, getattr(m, '0').weight.grad)
 
     assert sqnr_y >= 20.0, f"sqnr_y {sqnr_y} is too low"
     assert sqnr_w_grad >= 20.0, f"sqnr_w_grad {sqnr_w_grad} is too low"
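Aside on the `getattr(m, '0')` change above: once a module is wrapped by `torch.compile`, integer indexing like `m[0]` is generally not available on the returned wrapper, while attribute lookup by child name still resolves through to the wrapped `nn.Sequential` (which registers its children under string names), so the rewritten form works both with and without compilation. A small sketch of the behavior this accounts for (my reading of the diff, not verified against every PyTorch version):

```python
import torch
import torch.nn as nn

m = nn.Sequential(nn.Linear(8, 8))

# eager nn.Sequential supports both access styles
assert m[0] is getattr(m, "0")

mc = torch.compile(m)
first = getattr(mc, "0")  # attribute lookup forwards to the wrapped module
# mc[0] would typically fail, since the compile wrapper is not itself a Sequential
```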
@@ -109,7 +112,7 @@ def test_ffn():
     sqnr_w2_grad = compute_error(m_ref.w2.weight.grad, m.w2.weight.grad)
 
     assert sqnr_y >= 20.0, f"sqnr_y {sqnr_y} is too low"
-    assert sqnr_w1_grad >= 14.0, f"sqnr_w1_grad {sqnr_w1_grad} is too low"
+    assert sqnr_w1_grad >= 13.0, f"sqnr_w1_grad {sqnr_w1_grad} is too low"
    assert sqnr_w2_grad >= 30.0, f"sqnr_w2_grad {sqnr_w2_grad} is too low"
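For context on the relaxed threshold above: `compute_error` is an SQNR-style metric in decibels, so dropping the bound from 14.0 to 13.0 tolerates slightly more quantization noise in `w1.grad`. A typical definition of such a metric (a sketch of the general formula, not necessarily this repo's exact helper):

```python
import torch


def sqnr_db(ref: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Signal-to-quantization-noise ratio in dB; higher means x is closer to ref."""
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - x))
```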


@@ -150,7 +153,16 @@ def test_ffn_sp(local_rank, world_size):

 if __name__ == "__main__":
     local_rank, world_size = setup_model_parallel()
-    test_column_parallel_linear()
+    test_column_parallel_linear(use_float8=True, use_compile=False)
+
+    # below passes, but a lot of graph breaks:
+    # https://gist.github.com/vkuzo/670b2806e222bef04da5f173c758a165
+    # test_column_parallel_linear(use_float8=False, use_compile=True)
+
+    # below fails with
+    # https://gist.github.com/vkuzo/c9891ab38c8f341243b393e8e07d40d4
+    # test_column_parallel_linear(use_float8=True, use_compile=True)
+
     test_row_parallel_linear()
     test_ffn()
     test_ffn_sp(local_rank, world_size)