[float8] DDP GPT1.5B Torch.compile dynamo error #1308

Open

OrenLeung opened this issue Nov 19, 2024 · 1 comment

Comments


OrenLeung commented Nov 19, 2024

Hi Torch Team,

I am currently experimenting with native torch float8 distributed training using the delayed scaling recipe on GPT 1.5B with DDP at batch=12 seq=1024 on an HGX 8xH100 (700W H100 SXM 80G SKU).

Currently, I am running into a DDP + torch.compile + float8 bug. Without torch.compile enabled, I do not run into this error. I have tried #1306 as well as main@latest. Attached below are a self-contained repro script and the error trace.

Commands

python3 test.py --enable_compile=False
python3 test.py --enable_compile=True

Error Trace

    submod_compiler.run(*example_inputs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 175, in run
    raise RuntimeError(*e.args) from e
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
RuntimeError: val

While executing %submod_479 : [num_users=1] = call_module[target=submod_479](args = (%l_self_modules_tsfmr_blks_modules_47_modules_ffn_modules_2_parameters_weight_, %l_self_modules_tsfmr_blks_modules_47_modules_ffn_modules_2_buffers_fp8_amax_weight_), kwargs = {})
Original traceback:
None

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
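
For reference, the verbose logging suggested in the trace can be enabled when rerunning the repro (test.py here stands for the repro script below):

TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1 python3 test.py --enable_compile=True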

Repro Script

import torch
import torch.nn as nn
from torchao.float8 import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
import fire
import torch.multiprocessing as mp
import os
from torch.distributed import init_process_group


class CausalSelfAttention(nn.Module):
    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__()
        self.d_head = d_embd // n_heads  # D
        self.attn_proj = nn.Linear(d_embd, 3*d_embd)
        self.out_proj = nn.Linear(d_embd, d_embd)
 
    def forward(self, x_BTE):
        qkv = self.attn_proj(x_BTE).split(x_BTE.size(-1), -1)
        split_attn_head = lambda z: z.unflatten(-1, [-1, self.d_head]).transpose(1, 2)
        q_BHTD, k_BHTD, v_BHTD = map(split_attn_head, qkv)
        o_BHTD = F.scaled_dot_product_attention(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True)
        o_BTE = o_BHTD.transpose(1, 2).flatten(-2)
        y_BTE = self.out_proj(o_BTE)
        return y_BTE

class GPTBlock(nn.Module):
    def __init__(self, d_embd, **kwargs):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_embd)
        self.attn = CausalSelfAttention(d_embd, **kwargs)
        self.ffn_norm = nn.LayerNorm(d_embd)
        self.ffn = nn.Sequential(
            nn.Linear(d_embd, 4*d_embd),
            nn.GELU(),
            nn.Linear(4*d_embd, d_embd)
        )

    def forward(self, x_BTE):
        x_BTE = x_BTE + self.attn(self.attn_norm(x_BTE))
        y_BTE = x_BTE + self.ffn(self.ffn_norm(x_BTE))
        return y_BTE

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT, **kwargs):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

def main(enable_compile=True):
    train_args = (enable_compile,)
    mp.spawn(train, train_args, nprocs=8)

def train(rank, enable_compile=True):
    world_size = 8
    # configure delayed scaling
    config = Float8LinearConfig(
        cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
        cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
        cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
        # enable_amax_init=False,  # only needed for autocast + compile + FSDP + float8 delayed
        # enable_pre_and_post_forward=False  # only needed for autocast + compile + FSDP + float8 delayed
    )

    torch.manual_seed(3985)
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '30985'})
    torch.cuda.set_device(rank)
    init_process_group(backend='nccl', rank=rank, world_size=world_size)
    
    # GPT 1.5B
    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
        "arch_name": "gpt"
    }
    model = GPT(**cfg_json).to(rank)
    
    N = sum(p.numel() for p in model.parameters())  # parameter count

    flops_per_iter = 6 * N * 12 * 1024  # ~6 * params * tokens per step (batch=12, seq=1024)
    
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)


    convert_to_float8_training(model, config=config)

    model = DDP(model, gradient_as_bucket_view=True)
    if enable_compile:
        model = torch.compile(model)
    
    for step_idx in range(100):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        input_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to(rank)
        label_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to(rank)

        start.record()
        with torch.amp.autocast('cuda', torch.bfloat16):
            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())

        loss.backward()

        sync_float8_amax_and_scale_history(model)

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        end.record()

        torch.cuda.synchronize()
        t = start.elapsed_time(end) / 1e3
        flops_per_sec = flops_per_iter / t
        print(f"finish {step_idx} step: {(flops_per_sec/1e12):.2f} TFLOP/s")

if __name__ == "__main__":
    fire.Fire(main)

Torch Versions

pip list | grep torch
pytorch-triton               3.1.0+cf34004b8a
torch                        2.6.0.dev20241118+cu124
torch-tb-profiler            0.4.3
torchao                      0.7.0+git4402195e       $PATH/ao

vkuzo commented Nov 19, 2024

Hi @OrenLeung, I can also repro this. We haven't worked on enabling float8 + compile + DDP yet, as we found that FSDP is significantly more common in jobs large enough to benefit from float8 training. Would you be open to FSDP with NO_SHARD instead of DDP? Context: https://discuss.pytorch.org/t/difference-between-ddp-vs-fsdp-no-shard/209729
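
A minimal sketch of that alternative, assuming the repro above: swap the DDP wrapper for FSDP with ShardingStrategy.NO_SHARD (use_orig_params=True is generally needed when compiling FSDP-wrapped modules). This is an untested illustration, not a verified fix.

# Hypothetical drop-in replacement for the DDP wrapping in the repro above (untested sketch).
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

convert_to_float8_training(model, config=config)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.NO_SHARD,  # replicate parameters; DDP-like behavior
    device_id=rank,
    use_orig_params=True,  # generally required when using torch.compile with FSDP
)
if enable_compile:
    model = torch.compile(model)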
