Skip to content

Commit

Permalink
[Inference] Update rms norm kernel, benchmark with vLLM (#5315)
Browse files Browse the repository at this point in the history
* add

* xi

* del

* del

* fix
  • Loading branch information
CjhHa1 authored Jan 29, 2024
1 parent 7ddd8b3 commit 1f8a75d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
14 changes: 7 additions & 7 deletions colossalai/kernel/triton/rms_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def _rmsnorm_kernel(
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):

# This triton kernel implements Root Mean Square Layer Norm (RMSNorm).

# Map the program id to the row of X and Y it should compute.
Expand Down Expand Up @@ -54,18 +53,19 @@ def _rmsnorm_kernel(
def rms_layernorm(x, weight, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
# reshape input data into 2D tensor, (total token, hidden_size)
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
if N > MAX_FUSED_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)

# enqueue kernel
_rmsnorm_kernel[(M,)](
x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
_rmsnorm_kernel[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
return y
23 changes: 10 additions & 13 deletions tests/test_infer_ops/triton/test_rmsnorm_triton.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pytest
import torch
from packaging import version
import triton
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm

from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize
from transformers.models.llama.modeling_llama import LlamaRMSNorm

try:
pass
Expand All @@ -24,7 +25,6 @@
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):

dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
Expand All @@ -39,15 +39,14 @@ def test_layer_norm(M, N):
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)



# Triton benchmark plot attributions
configs = [
triton.testing.Benchmark(
x_names=["SEQUENCE_TOTAL"],
x_vals=[i for i in range(128, 1025, 128)],
line_arg="provider",
line_vals=["llama_rms_layernorm", "triton_rms_layernorm"],
line_names=["llama_rms_layernorm", "triton_rms_layernorm"],
line_vals=["vllm_rms_layernorm", "triton_rms_layernorm"],
line_names=["vllm_rms_layernorm", "triton_rms_layernorm"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"RMSNorm benchmarking results",
Expand All @@ -63,18 +62,17 @@ def benchmark_rms_layernorm(
HIDDEN_SIZE: int,
):
warmup = 10
rep = 100
rep = 1000

dtype = torch.float16
eps = 1e-5
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
w_shape = (x_shape[-1],)
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
rms_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).cuda()
vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")

if provider == "llama_rms_layernorm":
fn = lambda: rms_norm.forward(x).to(dtype)
if provider == "vllm_rms_layernorm":
fn = lambda: vllm_norm(x)
elif provider == "triton_rms_layernorm":
fn = lambda: rms_layernorm(x, weight, eps=eps)
else:
Expand All @@ -83,9 +81,8 @@ def benchmark_rms_layernorm(
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

return ms



if __name__ == "__main__":
test_layer_norm()
# benchmark_rms_layernorm.run(save_path=".")
# benchmark_rms_layernorm.run(save_path=".", print_data=True)

0 comments on commit 1f8a75d

Please sign in to comment.