diff --git a/colossalai/kernel/triton/rms_layernorm.py b/colossalai/kernel/triton/rms_layernorm.py
index b514c7789a02..71a724008513 100644
--- a/colossalai/kernel/triton/rms_layernorm.py
+++ b/colossalai/kernel/triton/rms_layernorm.py
@@ -23,7 +23,6 @@ def _rmsnorm_kernel(
     eps,  # epsilon to avoid division by zero
     BLOCK_SIZE: tl.constexpr,
 ):
-
     # This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
 
     # Map the program id to the row of X and Y it should compute.
@@ -54,18 +53,19 @@ def _rmsnorm_kernel(
 def rms_layernorm(x, weight, eps):
     # allocate output
     y = torch.empty_like(x)
-    # reshape input data into 2D tensor
+    # reshape input data into 2D tensor, (total token, hidden_size)
     x_arg = x.reshape(-1, x.shape[-1])
     M, N = x_arg.shape
     # Less than 64KB per feature: enqueue fused kernel
     MAX_FUSED_SIZE = 65536 // x.element_size()
+
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
-    if N > BLOCK_SIZE:
+    if N > MAX_FUSED_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+
     # heuristics for number of warps
-    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+    num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)
+
     # enqueue kernel
-    _rmsnorm_kernel[(M,)](
-        x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
-    )
+    _rmsnorm_kernel[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
     return y
diff --git a/tests/test_infer_ops/triton/test_rmsnorm_triton.py b/tests/test_infer_ops/triton/test_rmsnorm_triton.py
index 6828151ce083..7cc69657cd85 100644
--- a/tests/test_infer_ops/triton/test_rmsnorm_triton.py
+++ b/tests/test_infer_ops/triton/test_rmsnorm_triton.py
@@ -1,11 +1,12 @@
 import pytest
 import torch
-from packaging import version
 import triton
+from packaging import version
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from vllm.model_executor.layers.layernorm import RMSNorm
 
 from colossalai.kernel.triton import rms_layernorm
 from colossalai.testing.utils import parameterize
-from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
 try:
     pass
@@ -24,7 +25,6 @@
 @parameterize("M", [2, 4, 8, 16])
 @parameterize("N", [64, 128])
 def test_layer_norm(M, N):
-
     dtype = torch.float16
     eps = 1e-5
     x_shape = (M, N)
@@ -39,15 +39,14 @@ def test_layer_norm(M, N):
 
     assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
 
-
 # Triton benchmark plot attributions
 configs = [
     triton.testing.Benchmark(
         x_names=["SEQUENCE_TOTAL"],
         x_vals=[i for i in range(128, 1025, 128)],
         line_arg="provider",
-        line_vals=["llama_rms_layernorm", "triton_rms_layernorm"],
-        line_names=["llama_rms_layernorm", "triton_rms_layernorm"],
+        line_vals=["vllm_rms_layernorm", "triton_rms_layernorm"],
+        line_names=["vllm_rms_layernorm", "triton_rms_layernorm"],
         styles=[("red", "-"), ("blue", "-")],
         ylabel="ms",
         plot_name=f"RMSNorm benchmarking results",
@@ -63,18 +62,17 @@ def benchmark_rms_layernorm(
     HIDDEN_SIZE: int,
 ):
     warmup = 10
-    rep = 100
+    rep = 1000
 
     dtype = torch.float16
     eps = 1e-5
     x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
     w_shape = (x_shape[-1],)
     weight = torch.ones(w_shape, dtype=dtype, device="cuda")
-    rms_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).cuda()
+    vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-
-    if provider == "llama_rms_layernorm":
-        fn = lambda: rms_norm.forward(x).to(dtype)
+    if provider == "vllm_rms_layernorm":
+        fn = lambda: vllm_norm(x)
     elif provider == "triton_rms_layernorm":
         fn = lambda: rms_layernorm(x, weight, eps=eps)
     else:
@@ -83,9 +81,8 @@ def benchmark_rms_layernorm(
 
     ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
     return ms
-
 
 
 if __name__ == "__main__":
     test_layer_norm()
-    # benchmark_rms_layernorm.run(save_path=".")
\ No newline at end of file
+    # benchmark_rms_layernorm.run(save_path=".", print_data=True)