
[1xMI300X] GPT-2 XL 1.5B FP8 Training ~30% slower than H100 FP8 #72

Open
OrenLeung opened this issue Oct 13, 2024 · 22 comments · Fixed by #81
@OrenLeung
Contributor

OrenLeung commented Oct 13, 2024

Problem Description

Hi AMD team,

When trying to do FP8 training on MI300X, it is extremely slow due to very high CPU overhead taking up more than 81% of the time. As you can see from the profile, most of the time is spent on the CPU and in hipFree. On GPT-2 XL 1.5B, throughput is at 22 TFLOP/s. This is 10x slower than MI300X BF16.

For comparison, on H100 GPT-2 XL 1.5B, FP8 is 1.3x faster than BF16, not slower.

The repro script is attached below and can be run using NVTE_FUSED_ATTN_CK=0 python3 ./train.py

[profiler screenshots: the trace is dominated by CPU time and hipFree calls]

cc: @hliuca

Steps to Reproduce

Versions

root@NODENAME:/workspace/llm-train-bench# pip list | grep torch
pytorch-triton-rocm     3.1.0+cf34004b8a
torch                   2.6.0.dev20241012+rocm6.2
torchvision             0.18.0a0+68ba7ec
root@NODENAME:/workspace/llm-train-bench# pip list | grep transformer
transformer_engine      1.8.0.dev0+691dc23

Install Instructions

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt-get update && apt-get install -y nano

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2


WORKDIR /workspace/

RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git
ENV NVTE_FRAMEWORK=pytorch
ENV PYTORCH_ROCM_ARCH=gfx942

RUN cd TransformerEngine && pip install .

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]
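
For reference, one way to build and enter this container (a sketch: the llm-train-bench image name and the bind mount are just examples, and the device flags are the usual ROCm GPU passthrough):

docker build -t llm-train-bench .
docker run -it --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 16G -v "$PWD":/workspace/llm-train-bench llm-train-bench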

Repro GPT-2 XL 1.5B Training Script

import contextlib

import torch
import torch.nn.functional as F
import torch.nn as nn

from pydantic.dataclasses import dataclass

@dataclass
class GPTConfig:
    n_layers: int    # L
    n_heads: int     # H
    d_embd: int      # E
    max_seq_len: int = 1024
    vocab_size: int  = 50304 # V
    arch_name: str = 'gpt'

    @staticmethod
    def estimate_flops_per_token(model, config):
        # get param count
        N = sum(p.numel() for p in model.parameters())
        
        # print param count in B
        print(f"Param count: {N/1e9}B")
                 
        head_dim = config['d_embd'] // config['n_heads'] 
         
        flops_per_token = 6 * N + 12 * config['n_layers'] * config['n_heads'] * head_dim * config['max_seq_len']
        
        return flops_per_token

    def __post_init__(self):
        assert self.d_embd % self.n_heads == 0, 'd_embd must be a multiple of n_heads.'

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        
        
        # self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        import transformer_engine.pytorch as te
        self.tsfmr_blks = nn.ModuleList(te.TransformerLayer(
                    d_embd,
                    d_embd * 4,
                    kwargs['n_heads'],
                    layer_number=i+1,
                    # Optional, for speedups
                    fuse_qkv_params=True,
                    attn_input_format='bshd'
                ) 
                for i in range(n_layers)                       
                )
        
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

def train(
    gpu_id: int = 0,
    bsz: int = 8,
    grad_acc_steps: int = 8,
):
    torch.manual_seed(3985)
    torch.cuda.set_device(gpu_id)

    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
    }

    cfg_m = GPTConfig(**cfg_json)
    model = GPT(**cfg_json).to(gpu_id)

    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1.0)

    flops_per_token = cfg_m.estimate_flops_per_token(model, cfg_json)
    flops_per_iter = flops_per_token * (bsz * cfg_m.max_seq_len)

    flops_promised = 2600e12

    model.train()
    
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
    fp8_format = Format.HYBRID
    # Reasonable default setting
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    # Note: wrapped ctx in a function because the te.fp8_autocast object cannot be reused as a context for some reason.
    @contextlib.contextmanager
    def ctx():
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                yield

    with ctx():
         for step_idx in range(100):
            input_BT = torch.randint(50304, [8, 1024], dtype=torch.int64).to('cuda:0')
            label_BT = torch.randint(50304, [8, 1024], dtype=torch.int64).to('cuda:0')
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
            loss /= grad_acc_steps
            loss.backward()

            if (step_idx + 1) % grad_acc_steps == 0:  # Assume n_steps % grad_acc_steps == 0
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            end.record()
            torch.cuda.synchronize()

            t = start.elapsed_time(end) / 1e3
            flops_per_sec = flops_per_iter / t
            mfu = flops_per_sec / flops_promised

            print(f'{(flops_per_sec/1e12):.2f} TFLOP/s  MFU={mfu:.2%}')

if __name__ == '__main__':
    import fire
    fire.Fire(train)

Operating System

Ubuntu

CPU

AMD CPU

GPU

AMD Instinct MI300X

ROCm Version

ROCm 6.2.0

@OrenLeung OrenLeung changed the title GPT-2 XL 1.5B FP8 Training at 22 TFLOP/s MI300X GPT-2 XL 1.5B FP8 Training at 22 TFLOP/s Oct 13, 2024
@hliuca

hliuca commented Oct 14, 2024

Hi @OrenLeung, this has been reported. Thank you.

@OrenLeung
Contributor Author

OrenLeung commented Oct 15, 2024

For reference, GPT-2 XL 1.5B BF16 with torch.compile on PyTorch nightly at the same batch size as the repro script gets (in TFLOP/s):

  • MI300X: 215.79
  • H100: 478.73
  • H100 (fp8 transformer engine): 638.4

H100 Transformer Engine FP8 is 1.3x faster than H100 BF16.

Even when building with NVTE_USE_HIPBLASLT=1, the repro script only gets:

  • 1xMI300X single-GPU FP8 Transformer Engine: 156.92 TFLOP/s

This means that MI300X FP8 is still more than 20% slower than MI300X BF16!

@hliuca

hliuca commented Oct 15, 2024

Hi @OrenLeung, we have more people on the issues you reported and we will drive the fixes. Thank you.

@OrenLeung OrenLeung changed the title MI300X GPT-2 XL 1.5B FP8 Training at 22 TFLOP/s Single GPU MI300X GPT-2 XL 1.5B FP8 Training at 22 TFLOP/s Oct 15, 2024
@wangye805
Contributor

Hi @OrenLeung, I tried to run the same script on our single-GPU H100 machine but I only got around 100 TFLOP/s:
[screenshot of training output showing ~100 TFLOP/s]
Can you provide the reproduction instructions for H100 (to get 638.4 TFLOP/s)?

Thanks.

@LiGuihong

@OrenLeung Hi, I tried to reproduce your issue and got 166 TFLOP/s on MI300. The major difference is that I am using torch 2.6.0.dev20241014, since I cannot find the 20241012 version. Could you please provide the system configuration on your side? Thanks.

@OrenLeung
Contributor Author

OrenLeung commented Oct 16, 2024

hi @wangye805 ,

It seems like you are on an H100 PCIe 350W card while I am on an H100 SXM 700W card.

638.4 TFLOP/s/GPU was from H200 at batch size=38. I apologize for the error.

I have updated the script below to use batch size = 14. Feel free to use a larger batch size if that helps MI300X TFLOP/s/GPU.

  • 1xMI300X (fp8): 174.87 TFLOP/s
  • 1xH100 (fp8): 490 TFLOP/s
  • 1xH200 (fp8): 532.68 TFLOP/s

To change the batch size, just run python ./reprod.py --bsz=BATCH_SIZE

H100/H200 Dockerfile

FROM nvcr.io/nvidia/pytorch:24.09-py3

RUN pip install uv
RUN uv pip install --system ipython pytest fire pydantic

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]
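
For reference, a typical way to build and run this container (a sketch: the image name and mount are examples, and --gpus all assumes the NVIDIA Container Toolkit is installed):

docker build -t llm-train-bench-cuda .
docker run -it --gpus all --ipc=host -v "$PWD":/workspace/llm-train-bench llm-train-bench-cuda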

Repro Script

import contextlib

import torch
import torch.nn.functional as F
import torch.nn as nn

from pydantic.dataclasses import dataclass

@dataclass
class GPTConfig:
    n_layers: int    # L
    n_heads: int     # H
    d_embd: int      # E
    max_seq_len: int = 1024
    vocab_size: int  = 50304 # V
    arch_name: str = 'gpt'

    @staticmethod
    def estimate_flops_per_token(model, config):
        # get param count
        N = sum(p.numel() for p in model.parameters())
        
        # print param count in B
        print(f"Param count: {N/1e9}B")
                 
        head_dim = config['d_embd'] // config['n_heads'] 
         
        flops_per_token = 6 * N + 12 * config['n_layers'] * config['n_heads'] * head_dim * config['max_seq_len']
        
        return flops_per_token

    def __post_init__(self):
        assert self.d_embd % self.n_heads == 0, 'd_embd must be a multiple of n_heads.'

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        
        
        # self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        import transformer_engine.pytorch as te
        self.tsfmr_blks = nn.ModuleList(te.TransformerLayer(
                    d_embd,
                    d_embd * 4,
                    kwargs['n_heads'],
                    layer_number=i+1,
                    # Optional, for speedups
                    fuse_qkv_params=True,
                    attn_input_format='bshd'
                ) 
                for i in range(n_layers)                       
                )
        
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

def train(
    gpu_id: int = 0,
    bsz: int = 14,
    grad_acc_steps: int = 8,
):
    torch.manual_seed(3985)
    torch.cuda.set_device(gpu_id)

    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
    }

    cfg_m = GPTConfig(**cfg_json)
    model = GPT(**cfg_json).to(gpu_id)

    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1.0)

    flops_per_token = cfg_m.estimate_flops_per_token(model, cfg_json)
    flops_per_iter = flops_per_token * (bsz * cfg_m.max_seq_len)

    flops_promised = 2600e12

    model.train()
    
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
    fp8_format = Format.HYBRID
    # Reasonable default setting
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    # Note: wrapped ctx in a function because the te.fp8_autocast object cannot be reused as a context for some reason.
    @contextlib.contextmanager
    def ctx():
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                yield

    with ctx():
         for step_idx in range(100):
            input_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
            label_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
            loss /= grad_acc_steps
            loss.backward()

            if (step_idx + 1) % grad_acc_steps == 0:  # Assume n_steps % grad_acc_steps == 0
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            end.record()
            torch.cuda.synchronize()

            t = start.elapsed_time(end) / 1e3
            flops_per_sec = flops_per_iter / t
            mfu = flops_per_sec / flops_promised

            print(f'{(flops_per_sec/1e12):.2f} TFLOP/s  MFU={mfu:.2%}')

if __name__ == '__main__':
    import fire
    fire.Fire(train)

@OrenLeung
Contributor Author

OrenLeung commented Oct 16, 2024

@OrenLeung Hi, I tried to reproduce your issue and got 166 TFLOP/s on MI300. The major difference is that I am using torch 2.6.0.dev20241014, since I cannot find the 20241012 version. Could you please provide the system configuration on your side? Thanks.

Hi, it seems like we have similar results. I am getting 174 TFLOP/s/GPU on MI300X using the recommended hipBLASLt backend, unfortunately still slower than MI300X BF16.

The system setup is already provided in this GitHub issue. Is there any specific information that you would like to know?

@wangye805
Contributor

I found a bug in the use_fused_attention filtering and sent out PR #81 to fix this issue. With CK fused attention enabled, I can see our MI300X achieve 250 TFLOP/s with hipBLASLt, compared to 300 TFLOP/s on H100, if we set the batch size to 8.

We will keep working on other optimizations to match the H100 performance.
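
(Presumably this corresponds to running the repro without the NVTE_FUSED_ATTN_CK=0 workaround from the original report, i.e. simply python3 ./train.py, once PR #81 is applied.)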

@OrenLeung
Contributor Author

Hi @wangye805,

Thanks for the fix! 250 TFLOP/s is much better now.

Though it is still about 100 TFLOP/s off from H100 at the same batch size (bsz=8), and about 2.5x slower than H200 at bsz=38.

Here are my preliminary single-GPU H100/H200 700W SXM results:

  • preliminary 1xH100 700W SXM te fp8 (bsz=8) = 345.41 TFLOP/s
  • preliminary 1xH100 700W SXM te fp8 (bsz=10) = 398.58 TFLOP/s
  • preliminary 1xH100 700W SXM te fp8 (bsz=12) = 441.93 TFLOP/s
  • preliminary 1xH100 700W SXM te fp8 (bsz=14) = 490 TFLOP/s
  • preliminary 1xH200 700W SXM te fp8 (bsz=14) = 532.68 TFLOP/s
  • preliminary 1xH200 700W SXM te fp8 (bsz=38) = 630.1 TFLOP/s

cc: @hliuca

@OrenLeung
Contributor Author

Hi @wangye805,

Can we keep this issue open until single-GPU MI300X is able to match H100?

@hliuca

hliuca commented Oct 17, 2024

@OrenLeung I guess if @wangye805 increases bsz, the perf will increase too.

@OrenLeung
Contributor Author

@OrenLeung I guess if @wangye805 increases bsz, the perf will increase too.

@hliuca

Currently, when both are at batch size 8, MI300X gets 250 TFLOP/s and H100 gets 345.41 TFLOP/s, so about a 30% difference (1 − 250/345.41 ≈ 28%).

But my intuition is that this ~30% difference will stay the same as we increase the batch size for both.

@hliuca

hliuca commented Oct 17, 2024

Hi @OrenLeung, I am not so sure about this. For many LLMs I have seen, as we keep increasing the workload the gap gets smaller and smaller, and we can catch up and pass. Of course, that also depends on optimizations. We will keep working on this.

@wangye805 wangye805 reopened this Oct 17, 2024
@wangye805
Contributor

@OrenLeung reopened this issue until we can match H100

@OrenLeung OrenLeung changed the title Single GPU MI300X GPT-2 XL 1.5B FP8 Training at 22 TFLOP/s [1xMI300X] GPT-2 XL 1.5B FP8 Training 33% slower than H100 FP8 Oct 17, 2024
@OrenLeung OrenLeung changed the title [1xMI300X] GPT-2 XL 1.5B FP8 Training 33% slower than H100 FP8 [1xMI300X] GPT-2 XL 1.5B FP8 Training ~30% slower than H100 FP8 Oct 17, 2024
@wangye805
Contributor

With hipBLASLt auto-tuning, I can get 370 TFLOP/s with batch size 50.

@OrenLeung
Contributor Author

Thanks @wangye805, can you share the exact env flags and what values I should set them to in order to turn on hipBLASLt auto-tuning?

@wangye805
Contributor

@OrenLeung You can use the following environment variables to turn on hipBLASLt tuning: TE_HIPBLASLT_TUNING_RUN_COUNT=20 TE_HIPBLASLT_TUNING_ALGO_COUNT=1000
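
For example, combined with the repro command from earlier in this thread (a sketch; adjust the script name and batch size to your run):

TE_HIPBLASLT_TUNING_RUN_COUNT=20 TE_HIPBLASLT_TUNING_ALGO_COUNT=1000 python3 ./reprod.py --bsz=50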

@OrenLeung
Contributor Author

With hipBLASLt auto-tuning, I can get 370 TFLOP/s with batch size 50.

@wangye805 At a large batch size, H100 can get 490 TFLOP/s and H200 can get 630 TFLOP/s. Good improvement, but it seems like there is still a 120 TFLOP/s gap.

@OrenLeung
Contributor Author

@OrenLeung You can use the following environment variables to turn on hipBLASLt tuning: TE_HIPBLASLT_TUNING_RUN_COUNT=20 TE_HIPBLASLT_TUNING_ALGO_COUNT=1000

Thanks! Will definitely try that. TUNING_ALGO_COUNT=1000, lol. I guess we probably need #67 to store the tuning results?

@hliuca

hliuca commented Nov 16, 2024

I have re-run with the latest TE and got:

python 72-gpt2.py --bsz 38

1112.75 TFLOP/s MFU=42.80%
1113.10 TFLOP/s MFU=42.81%
1109.57 TFLOP/s MFU=42.68%
1114.25 TFLOP/s MFU=42.86%
1112.56 TFLOP/s MFU=42.79%
1114.90 TFLOP/s MFU=42.88%
1045.00 TFLOP/s MFU=40.19%
1143.37 TFLOP/s MFU=43.98%
1112.71 TFLOP/s MFU=42.80%
1116.26 TFLOP/s MFU=42.93%
1115.47 TFLOP/s MFU=42.90%

@OrenLeung
Contributor Author

Hi @hliuca

I think you're using the original script, which had a bug in the bsz argument: the random input batches were hardcoded to batch size 8 while flops_per_iter scaled with bsz, so the reported TFLOP/s is inflated. Apologies on my end.

The updated script posted here has the fix:

#72 (comment)
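
For reference, the fix is just to use the bsz argument instead of the hardcoded 8 when generating the random batches:

input_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
label_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')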

@hliuca

hliuca commented Nov 16, 2024

Thank you @OrenLeung
