
[FSDP 8xMI300X]: LLama3 70B 4 Layer Proxy Model GPU Core Dumps #78

Open
OrenLeung opened this issue Oct 15, 2024 · 24 comments
@OrenLeung
Contributor

OrenLeung commented Oct 15, 2024

Problem Description

On the Llama3 70B proxy model, training stalls and the GPUs core dump. The core dumps are 41 GB per GPU, so I am unable to send them; it is probably easier for you to reproduce this error on your end to get the core dump.

I have verified that on H100, TE FP8 for the Llama3 70B FSDP 4-layer proxy model trains perfectly fine, with a 38% TFLOP/s/GPU increase compared to BF16 torch.compile.

cc: @hliuca


Operating System

Ubuntu

CPU

AMD CPU

GPU

MI300X

ROCm Version

ROCm 6.2.0

ROCm Component

No response

Steps to Reproduce

Docker Image

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt-get update && apt-get install -y nano

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]

TE install Instructions (done inside docker container)

cd /workspace
git clone --recursive https://github.com/ROCm/TransformerEngine.git
export NVTE_USE_HIPBLASLT=1
export NVTE_FRAMEWORK=pytorch
export PYTORCH_ROCM_ARCH=gfx942
cd TransformerEngine && pip install .
cd /workspace/llm-train-bench
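
Optional sanity check: before launching the full FSDP job, a minimal single-GPU FP8 smoke test can confirm that the TE build imports and runs a forward/backward. This is only a sketch that reuses the TE APIs appearing in the reprod script below (plus the standard params_dtype module kwarg from upstream TE); the layer sizes are arbitrary.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

# Arbitrary illustrative sizes; FP8 GEMMs want dimensions divisible by 16.
layer = te.LayerNormLinear(
    1024, 1024, bias=False, normalization='RMSNorm', eps=1e-5,
    params_dtype=torch.bfloat16
).cuda()
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo='max')
x = torch.randn(32, 1024, device='cuda', dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)        # FP8 GEMM under delayed scaling
y.sum().backward()      # backward runs outside the autocast, as in the script below
print('TE FP8 forward/backward OK:', tuple(y.shape))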

Reprod Script

from dataclasses import asdict
from typing import Optional
from pydantic.dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

# DDP
import os
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group

# FSDP
from functools import partial
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from tqdm import tqdm

# FP8 Transformer Engine
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

def dprint(rank, *args, **kwargs):
    if rank == 0:
        print(*args, **kwargs)
        
class DummyDataset(Dataset):
    def __init__(self, vocab_size, max_seq_len, ds_len):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.ds_len = ds_len

    def __getitem__(self, idx):
        input_T = torch.randint(self.vocab_size, [self.max_seq_len], dtype=torch.int64)
        label_T = torch.cat([input_T[:-1], torch.randint(self.vocab_size, [1])])
        return input_T, label_T

    def __len__(self):
        return self.ds_len
        
def create_distributed_data_loader(rank, world_size, bsz, n_steps, cfg_m):
    dataset = DummyDataset(cfg_m.vocab_size, cfg_m.max_seq_len, bsz*n_steps)
    data_loader = DataLoader(
        dataset, batch_size=bsz,
        num_workers=8, pin_memory=True, shuffle=False,
        sampler=DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
    )
    
    return data_loader


def configure_train_loop(data_loader, cfg_m, bsz, rank=0):
    if rank != 0:
        for step_idx, data_batch in enumerate(data_loader):
            yield step_idx, data_batch
        return

    flops_per_iter = cfg_m.flops_per_token * (bsz * cfg_m.max_seq_len)

    flops_promised = 2610e12
    
    with tqdm(total=len(data_loader)) as pbar:
        for step_idx, data_batch in enumerate(data_loader):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            yield step_idx, data_batch

            end.record()
            torch.cuda.synchronize()

            t = start.elapsed_time(end) / 1e3
            flops_per_sec = flops_per_iter / t
            mfu = flops_per_sec / flops_promised

            pbar.set_description(f'[rank0]  {(flops_per_sec/1e12):.2f} TFLOP/s  MFU={mfu:.2%}')
            pbar.update()

@dataclass
class LLaMAConfig:
    n_layers: int    # L
    n_heads: int     # H
    n_kv_heads: int  # J
    d_embd: int      # E
    max_seq_len: int # T
    vocab_size: int  # V
    ffn_mult: float
    ffn_factor: int
    rope_base: float
    norm_eps: float
    d_hid: Optional[int] = None  # K
    arch_name: str = 'llama'

    def estimate_flops_per_token(self, model, bsz, rank=0):
        head_dim = self.d_embd // self.n_heads
        N = sum(p.numel() for p in model.parameters())  # get param count

        if rank == 0:
            print(f"Number of parameters: {N/1e9:.2f}B")    # print number of billion parameters 

        # ~6 FLOPs per parameter per token for the GEMMs, plus the attention term (12 * L * H * head_dim * T)
        self.flops_per_token = 6 * N + 12 * self.n_layers * self.n_heads * head_dim * self.max_seq_len

    def __post_init__(self):
        assert self.d_embd % self.n_heads == 0, 'd_embd must be a multiple of n_heads.'
        assert self.d_embd % self.n_kv_heads == 0, 'd_embd must be a multiple of n_kv_heads.'
        assert self.n_kv_heads <= self.n_heads, 'n_kv_heads must not be larger than n_heads.'

        # FFN hidden dimension
        d_hid = int((4 * self.d_embd) * 2 / 3)
        d_hid = int(d_hid * self.ffn_mult)
        self.d_hid = self.ffn_factor * ((d_hid + self.ffn_factor - 1) // self.ffn_factor)                

class Fp8LLaMA(nn.Module):
    def __init__(self, vocab_size, d_embd, n_layers, n_heads, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.tsfmr_blks = nn.ModuleList(
            Fp8LLaMABlock(d_embd, n_heads=n_heads, **kwargs) for _ in range(n_layers)
        )
        self.norm_lm_head = te.LayerNormLinear(
            d_embd, vocab_size, bias=False,
            normalization='RMSNorm', eps=kwargs['norm_eps']
        )

        # Reference: https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
        freq_cis_TE = te.attention.RotaryPositionEmbedding(d_embd//n_heads)(max_seq_len=131072)
        self.register_buffer('freq_cis_TE', freq_cis_TE.to(torch.bfloat16))

    def forward(self, idx_BT, is_first_microbatch):
        x_BTE = self.tok_embd(idx_BT)
        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE, rotary_pos_emb=self.freq_cis_TE, is_first_microbatch=is_first_microbatch)
        logits_BTV = self.norm_lm_head(x_BTE)
        return logits_BTV


class Fp8LLaMABlock(te.TransformerLayer):
    ''' Reference Implementation:
    https://github.com/NVIDIA/TransformerEngine/blob/55dcbb4b02f560d52dc1215a9de348b37487ee3d/docs/examples/te_llama/te_llama.py#L42
    '''
    def __init__(self, d_embd, d_hid, n_heads, n_kv_heads, norm_eps, **kwargs):
        super().__init__(
            hidden_size=d_embd,
            num_attention_heads=n_heads,
            num_gqa_groups=n_heads//n_kv_heads,
            fuse_qkv_params=True,
            attn_input_format='bshd',
            attention_dropout=0.0,
            normalization='RMSNorm',
            layernorm_epsilon=norm_eps,
            ffn_hidden_size=d_hid,
            bias=False,
            activation='swiglu',
            hidden_dropout=0.0
        )

def train(
    bsz: int = 10,
):

    torch.manual_seed(3985)
    world_size = torch.cuda.device_count()
    train_args = (
        world_size,
        bsz
    )
    try:
        mp.spawn(train_fsdp, train_args, nprocs=world_size)
    except:
        destroy_process_group()


def train_fsdp(
    rank, world_size, bsz
):
    # Construct process group
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '30985'})
    torch.cuda.set_device(rank)
    init_process_group(backend='nccl', rank=rank, world_size=world_size)
    
    cfg = {
        "n_layers": 4,
        "n_heads": 64,
        "n_kv_heads": 8,
        "d_embd": 8192,
        "max_seq_len": 4096,
        "vocab_size": 128256,
        "ffn_mult": 1.3,
        "ffn_factor": 1024,
        "rope_base": 500000.0,
        "norm_eps": 1e-05,
        "d_hid": 28672,
        "arch_name": "llama"
    }
    
    use_fp8 = True
    grad_acc_steps = 8
    n_steps = 128*8
    # Configure training setup
    cfg_m, model_cls, blk_cls = LLaMAConfig(**cfg), Fp8LLaMA, Fp8LLaMABlock
    model = model_cls(**asdict(cfg_m)).to(rank)
    dprint(rank, f'Loaded {model_cls} model.', end=' ')
    cfg_m.estimate_flops_per_token(model, bsz, rank)  # Need to do before wrapping in FSDP

    data_loader = create_distributed_data_loader(rank, world_size, bsz, n_steps, cfg_m)
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1.0)

    # FSDP
    model = FSDP(
        model,
        device_id=rank,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16
        ),
        auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls={blk_cls}),
        use_orig_params=True
    )
    dprint(rank, f'Created FSDP model')

    prepare_te_modules_for_fsdp(model)
    dprint(rank, 'Sharded TE modules for FSDP')

    # Training loop
    loop_iter = configure_train_loop(data_loader, cfg_m, bsz, rank)
    model.train()
    
    fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
    all_gpus = dist.new_group(backend='nccl')


    for step_idx, data_batch in loop_iter:
        input_BT, label_BT = map(lambda t: t.pin_memory().to(rank), data_batch)

        with torch.amp.autocast('cuda', torch.bfloat16):
            with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
                weight_cache = use_fp8 and (step_idx % grad_acc_steps == 0)
                logits_BTV = model(input_BT, is_first_microbatch=weight_cache)
                loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
                loss /= grad_acc_steps

        loss.backward()

        if (step_idx + 1) % grad_acc_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)


    dist.barrier()
    destroy_process_group()


if __name__ == '__main__':
    import fire
    fire.Fire(train)

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

@hliuca

hliuca commented Oct 15, 2024

I will report this. Thanks Oren.

@OrenLeung
Contributor Author

OrenLeung commented Oct 15, 2024

Thanks @hliuca ,

For further context, on MI300X BF16 with torch.compile (nightly), I get the following preliminary results:

  • 8xMI300X batch size 10: 512TFLOP/s/GPU
  • 8xMI300X batch size 12: 518.19TFLOP/s/GPU

In the reprod script the batch size is 10; I can confirm that batch size 12 also causes a GPU core dump.

@OrenLeung
Contributor Author

OrenLeung commented Oct 15, 2024

Interestingly, when I use batch size 2 I do not get a GPU core dump, but at this small a batch size the throughput is 491.22 TFLOP/s/GPU, which is about 6% slower than BF16 at batch size 12. (All numbers are preliminary.)

  • preliminary 8xMI300X FP8 TE batch size 2: 491.22 TFLOP/s/GPU
  • preliminary 8xMI300X FP8 TE batch size 4: causes it to stall & gpu core dump
  • preliminary 8xMI300X FP8 TE batch size 10: causes it to stall & gpu core dump

batch size 2 command

 python ./train_fsdp_llama_70_reprod.py --bsz=2

batch size 4 command

python ./train_fsdp_llama_70_reprod.py --bsz=4

batch size 10 command (batch size 10 is the default in the reprod script above, so no args are needed)

python ./train_fsdp_llama_70_reprod.py 

@wenchenvincent
Collaborator

@OrenLeung This issue was due to our dev branch not yet having all the recent optimizations on DDP and FSDP from NVTE. We have a PR in review that should be merged soon and could resolve this issue (#66).
Here are the numbers that I got with this PR:
8xMI300X FP8 TE batch size 2: 572 TFLOP/s/GPU
8xMI300X FP8 TE batch size 10: 696 TFLOP/s/GPU

@wenchenvincent wenchenvincent self-assigned this Oct 17, 2024
@OrenLeung
Contributor Author

OrenLeung commented Oct 17, 2024

hi @wenchenvincent ,

Thanks for looking into this! Do you have an ETA for when #66 will be merged? Since this is such a big PR, I will probably have to wait until it hits the main branch before I re-test. I will probably also wait until #69 & the transpose_cast_opt branch merge.

I was also wondering which Dockerfile image you are using as the base image to obtain these results? And is this base image publicly accessible?

From your results, it does seem like your FP8 gets better results than MI300X BF16. We estimate that the TCO of MI300X is 78% of an H100, so to get competitive perf-per-$ results vs H100, MI300X FP8 will probably need to hit 742.2 TFLOP/s/GPU (a quick check of this arithmetic follows the numbers below).

Are there other PRs or ideas you have that could potentially help improve MI300X TE FP8 performance?

cc: @hliuca

Here are my preliminary numbers on this GitHub issue's model (Llama3 70B 4-layer proxy):

  • preliminary 8xMI300X BF16 batch size 8: 508 TFLOP/s/GPU
  • preliminary 8xMI300X BF16 batch size 10: 512.64 TFLOP/s/GPU
  • preliminary 8xMI300X BF16 batch size 12: 518.19 TFLOP/s/GPU
  • preliminary 8xMI300X BF16 batch size 14: OOM
  • preliminary 8xH100 BF16 batch size 2: 649.02 TFLOP/s/GPU
  • preliminary 8xH100 BF16 batch size 4: 687.13 TFLOP/s/GPU
  • preliminary 8xH100 TE FP8 batch size 2: 951.61 TFLOP/s/GPU
  • preliminary 8xH100 TE FP8 batch size 4: 759.99 TFLOP/s/GPU
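
As a quick sanity check of the perf-per-TCO arithmetic above (the 78% TCO estimate and the H100 FP8 number are the ones quoted in this comment; everything else is just illustration):

# Perf-per-$ break-even: MI300X throughput must reach (MI300X TCO / H100 TCO) * H100 throughput.
h100_te_fp8_bs2 = 951.61   # preliminary 8xH100 TE FP8, TFLOP/s/GPU (from the list above)
tco_ratio = 0.78           # estimated MI300X TCO as a fraction of H100 TCO
print(f'{h100_te_fp8_bs2 * tco_ratio:.1f} TFLOP/s/GPU')  # ~742.3, i.e. the ~742 target above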

@wenchenvincent
Collaborator

@OrenLeung #66 only needs a few minor changes; the bottleneck for merging it has been our CI capacity. But I expect it will be merged this week.

I was using the same docker image that you used for producing the numbers.

I haven't had a chance to dump the traces of this model run yet, but I suspect it might also suffer from the FP8 cast-transpose issue, and some FP8 GEMMs might not be tuned yet. So the FP8 cast-transpose optimization and FP8 GEMM tuning could potentially improve performance further.

@OrenLeung
Contributor Author

OrenLeung commented Oct 17, 2024

Furthermore, here are the preliminary H200 numbers. To be competitive with H200 on a perf-per-TCO basis, AMD needs to be at roughly 910 TFLOP/s/GPU (see the check after the list below).

  • preliminary 8xH200 TE FP8 batch size 2: 991.56 TFLOP/s/GPU
  • preliminary 8xH200 TE FP8 batch size 4: 1107 TFLOP/s/GPU
  • preliminary 8xH200 TE FP8 batch size 8: 1167.52 TFLOP/s/GPU
  • preliminary 8xH200 TE FP8 batch size 10: OOM
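
The same back-of-envelope check against the best preliminary H200 number (assuming the same 78% TCO ratio as in the earlier comment) gives the ~910 TFLOP/s/GPU target:

h200_te_fp8_bs8 = 1167.52  # preliminary 8xH200 TE FP8 batch size 8, TFLOP/s/GPU
print(f'{h200_te_fp8_bs8 * 0.78:.1f} TFLOP/s/GPU')  # ~910.7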

@hliuca

hliuca commented Oct 17, 2024

Thank you Oren for providing H200 data. These data are very valuable and helpful. Our TE team and other teams are actively working on all the issues you have filed.

@OrenLeung
Contributor Author

OrenLeung commented Oct 17, 2024

hi @hliuca ,

I am glad we were able to provide an optimization goal.

Please note that all of the H100 & H200 numbers we shared are preliminary and will probably improve too as I tune them.

Also, please note that we are benchmarking & evaluating AMD/Nvidia on other real-world transformer models and real-world GEMM training shapes that we have not shared with Nvidia or AMD, to ensure that the patches made to PyTorch, TE, hipBLASLt, etc. are generalizable.

@hliuca

hliuca commented Oct 17, 2024

Yes @OrenLeung, totally understood. Thank you for driving us to do a better job.

@OrenLeung
Contributor Author

OrenLeung commented Oct 21, 2024

After #66 was merged to main, I now get a preliminary number of 716.97 TFLOP/s/GPU on my internal codebase.

After 32 Warmup: Mean TFLOP/s: 716.97 Mean MFU: 27.47%

Great work! @wenchenvincent !

I assume that once the Triton cast-transpose fused op & CK attention v3 merge, it will be closer to H100's FP8 951.61 TFLOP/s/GPU.

@wenchenvincent
Collaborator

@OrenLeung Thank you!

@wangye805 had run this model on a different machine and he was getting 747 TFLOP/s. We're investigating why that system could give better performance and hope to make it reproducible.

Yeah, the Triton cast transpose should give a further improvement, and FP8 GEMM tuning in the hipBLASLt library and CK FA v3 should give more improvements. But for the latter two, we will need to check the timeline internally.

@OrenLeung
Contributor Author

@wenchenvincent interesting that a different machine gives a different TFLOP/s.

Note that before step 16, the TFLOP/s in the reprod script usually fluctuates (as it warms up and does grad accum every 8 steps).

In my internal codebase, I usually do warmup of 32 steps then take the mean over 50 steps to get an accurate measurement of what the realistic TFLOP/s would be.
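
For reference, a minimal sketch of that measurement convention (a hypothetical helper, not part of the reprod script; it assumes a list of per-step TFLOP/s values collected the same way the loop above computes flops_per_sec):

def mean_tflops(per_step_tflops, warmup=32, measure=50):
    """Discard the warmup window, then average throughput over the next `measure` steps."""
    window = per_step_tflops[warmup:warmup + measure]
    assert len(window) == measure, 'not enough recorded steps after warmup'
    return sum(window) / len(window)

# e.g. mean_tflops(step_tflops)  ->  the 716.97 TFLOP/s figure quoted above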

@wenchenvincent
Collaborator

It could be that the other machine has a newer version of the kernel driver. There is also some system config tuning that might impact performance: https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/system.html

@OrenLeung
Contributor Author

@wenchenvincent That is quite interesting, though from my understanding most of those knobs in the system tuning guide don't really affect text-only transformer models much, since this class of models has very small DtoH and HtoD transfers and doesn't really use the CPU much. So tuning NUMA (NPS1, NPS4, etc.) doesn't really affect performance.

I can see how those knobs would affect CPU-dataloader-heavy & HtoD-transfer-heavy models like image or video.

@wenchenvincent
Collaborator

@OrenLeung Those knobs are for general MI300X system tuning. The knob most relevant to the GPU would be this one: https://rocm.docs.amd.com/en/latest/how-to/system-optimization/mi300x.html#deterministic-clock Sometimes, using the default frequency of 2100MHz for some workloads can trigger PCC (Peak Current Control) events, lowering the attainable GPU frequency.

@wenchenvincent
Collaborator

Unfortunately, the machine that produced the better perf has been down for the past two days for maintenance and upgrades. Once it is back up, we will continue to investigate why it could produce better numbers.

@wenchenvincent
Collaborator

@OrenLeung Also, I think I might have forgotten to mention that we can use autotuning in TE to select the best-performing kernels from hipBLASLt for a specific GEMM size (if there are multiple kernel candidates for that size): https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#gemm-tuning-with-hipblaslt

@wenchenvincent
Collaborator

The perf number that I got was without autotuning though. Once we get the machine back up, we will try with autotuning to see how much we can get.

@OrenLeung
Contributor Author

@wenchenvincent Nice! I also saw that there is an autotuning storage PR; what is the timeline for that? That way we don't need to autotune on every run and can just cache the optimal GEMM selection.

@wenchenvincent
Collaborator

@OrenLeung The PR is under review and we're looking to merge it end of this week or early next week.

@wenchenvincent
Collaborator

@OrenLeung We have the optimized cast transpose Triton kernel merged in. And with that, I got the following improvement:

8xMI300X FP8 TE batch size 10: 701 TFLOP/s -> 751.88 TFLOP/s

One of my colleagues got a better number, around 795 TFLOP/s, with a different machine and a different docker image. I will check to see whether I can reproduce his numbers.

@OrenLeung
Contributor Author

Hi @wenchenvincent, thanks! Can you send over the Dockerfile?

@hliuca

hliuca commented Nov 17, 2024

Hi @OrenLeung


Please find a Dockerfile attached. I am working with the dev teams to provide a final Dockerfile in the next few days. Meanwhile, if you like, you may try the attached Dockerfile, which provides nice perf. Thank you.

Dockerfile.rocm.txt
