Update vllm kernels for ROCM (#2826)
* (vllm) updated vllm rocm kernels

* revert silu

* update partition size

* remove grouped_topk

* (nit) remove log

* update moe-kernels commit
mht-sharma authored Dec 18, 2024
1 parent 7eeefa3 commit 8f66d32
Showing 17 changed files with 95 additions and 117 deletions.
13 changes: 13 additions & 0 deletions Dockerfile_amd
@@ -234,6 +234,7 @@ FROM kernel-builder AS vllm-builder
WORKDIR /usr/src

COPY server/Makefile-vllm Makefile
RUN pip install setuptools_scm

# Build specific version of vllm
RUN make build-vllm-rocm
@@ -267,6 +268,15 @@ COPY server/exllamav2_kernels/ .

RUN python setup.py build

FROM kernel-builder AS moe-kernels
WORKDIR /usr/src
ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
ENV VLLM_TARGET_DEVICE=rocm
RUN git clone https://github.com/danieldk/moe-kernels.git && \
cd moe-kernels && \
git checkout ${MOE_KERNELS_BRANCH} && \
python setup.py install

FROM install_deps AS base-copy

# Text Generation Inference base env
@@ -289,6 +299,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from moe kernels
COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Install server
COPY proto proto
COPY server server
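The new moe-kernels stage builds danieldk/moe-kernels for ROCm (`VLLM_TARGET_DEVICE=rocm`) and copies its build artifacts into the conda site-packages of the final image, alongside the vLLM ones. A minimal smoke test for the resulting image could look like the sketch below; the module paths are taken from this diff, while treating this as a sufficient check is an assumption:

```python
# Hedged smoke test (assumed sufficient): verify the ROCm kernel modules this
# image relies on can be imported and expose the entry points used in this PR.
import vllm._custom_ops as ops
from moe_kernels.fused_moe import fused_moe, fused_topk, grouped_topk

# Entry points the updated TGI code paths call on ROCm.
for name in ("reshape_and_cache", "paged_attention_rocm", "rms_norm", "fused_add_rms_norm"):
    assert hasattr(ops, name), f"missing vllm._custom_ops.{name}"

print("vllm custom ops OK:", ops.__name__)
print("moe kernels OK:", fused_moe.__name__, fused_topk.__name__, grouped_topk.__name__)
```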
2 changes: 1 addition & 1 deletion server/Makefile-vllm
@@ -1,4 +1,4 @@
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
commit_rocm := de990cd12537f78f74e40b5c8ee1a62d63d734dd

build-vllm-rocm:
if [ ! -d 'vllm' ]; then \
4 changes: 3 additions & 1 deletion server/text_generation_server/layers/attention/kv_cache.py
@@ -215,7 +215,9 @@ def paged_reshape_and_cache(
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex

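The only change here is an extra trailing argument: the updated ROCm kernel takes two scale values instead of one. A hedged sketch of the call (reading the trailing floats as k_scale and v_scale follows upstream vLLM and is an assumption for this fork):

```python
import vllm._custom_ops as ops


def paged_reshape_and_cache_rocm(key, value, key_cache, value_cache, slots):
    # TGI does not quantize the KV cache on this path, so both scales are 1.0
    # and "auto" keeps the cache dtype unchanged.
    ops.reshape_and_cache(
        key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
    )
```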
57 changes: 36 additions & 21 deletions server/text_generation_server/layers/attention/rocm.py
@@ -6,26 +6,42 @@
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger
import vllm._custom_ops as ops

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5

_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_V1V2 = 1024
_PARTITION_SIZE_CUSTOM = 256

_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_MI250_MI300 = any(
arch in _GPU_ARCH for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"]
)

use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"

use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
from vllm._custom_C import paged_attention_custom
except ImportError as e:
log_master(
logger.info,
f"Custom Paged Attention not available. Complete error: {e}",


def _use_rocm_custom_paged_attention(
qtype: torch.dtype,
head_size: int,
block_size: int,
gqa_ratio: int,
max_seq_len: int,
) -> bool:
# ROCm custom paged attention is not supported on Navi (gfx1*)
return (
use_rocm_custom_paged_attn
and _ON_MI250_MI300
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 131072
)
use_rocm_custom_paged_attn = False


def paged_attention(
@@ -66,13 +82,8 @@ def paged_attention(

num_kv_heads = kv_cache.key.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
use_rocm_custom_paged_attn
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
use_custom = _use_rocm_custom_paged_attention(
query.dtype, head_size, block_size, gqa_ratio, max_s
)

if not use_custom:
@@ -90,8 +101,6 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
import vllm._custom_ops as ops

use_v1 = (
max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
@@ -103,7 +112,7 @@ def paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
@@ -112,6 +121,7 @@ def paged_attention(
None,
"auto",
1.0,
1.0,
)
else:
# Run PagedAttention V2.
@@ -137,7 +147,7 @@ def paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
@@ -146,9 +156,10 @@ def paged_attention(
None,
"auto",
1.0,
1.0,
)
else:
paged_attention_custom(
ops.paged_attention_rocm(
out,
exp_sums,
max_logits,
@@ -164,6 +175,10 @@ def paged_attention(
max_s,
None,
"auto",
1.0,
1.0,
None,
_PARTITION_SIZE,
)

return out
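Taken together, the bumped partition sizes and the new `_use_rocm_custom_paged_attention` gate decide between paged attention V1, V2, and the custom ROCm kernel. Below is a sketch of the dispatch implied by this hunk; the helper name is hypothetical, and giving the custom kernel precedence over V1 is an assumption, since part of the branch is elided above:

```python
_PARTITION_SIZE_V1V2 = 1024
_PARTITION_SIZE_CUSTOM = 256


def select_paged_attention_kernel(max_s: int, num_seqs: int, num_heads: int, use_custom: bool):
    # Hypothetical helper; the real code inlines this logic in paged_attention().
    partition_size = _PARTITION_SIZE_CUSTOM if use_custom else _PARTITION_SIZE_V1V2
    max_num_partitions = (max_s + partition_size - 1) // partition_size
    # V1 skips the cross-partition reduction when sequences are short, or when
    # there is already enough parallel work (many sequences * heads).
    use_v1 = (
        not use_custom
        and max_s <= 8192
        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
    )
    if use_v1:
        return "paged_attention_v1", max_num_partitions
    if use_custom:
        return "paged_attention_rocm", max_num_partitions
    return "paged_attention_v2", max_num_partitions
```

Note that both V1/V2 and `paged_attention_rocm` now take `num_kv_heads` instead of a `kv_head_mapping` tensor, plus the extra scale arguments, matching the newer vLLM ROCm kernel signatures.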
37 changes: 22 additions & 15 deletions server/text_generation_server/layers/layernorm.py
@@ -72,7 +72,7 @@ def forward(self, hidden_states, residual=None):
return normed_hidden_states, residual

elif SYSTEM == "rocm":
from vllm._C import ops
import vllm._custom_ops as ops

class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
@@ -121,6 +121,27 @@ def forward(self, hidden_states, residual=None):
residual is not None,
)
return out, residual if residual is not None else hidden_states
elif SYSTEM == "rocm":
# We use the vLLM RMSNorm kernel, which can be compiled for ROCm, instead of the Flash Attention one, which cannot.
if residual is not None:
ops.fused_add_rms_norm(
hidden_states,
residual,
self.weight.data,
self.variance_epsilon,
)
return hidden_states, residual

residual = hidden_states

out = torch.empty_like(hidden_states)
ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
elif hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
@@ -164,20 +185,6 @@ def forward(self, hidden_states, residual=None):
res = hidden_states

return normed_hidden_states, res
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states

out = torch.empty_like(hidden_states)
ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
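On ROCm the fast RMSNorm path now calls `ops.fused_add_rms_norm` when a residual is present, instead of adding the residual in PyTorch and then calling `rms_norm`. A rough unfused reference for what the two vLLM kernels compute; the in-place update of `hidden_states` and `residual` is inferred from the diff returning them directly:

```python
import torch


def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Reference for ops.rms_norm: normalize by the root mean square over the
    # last dimension (accumulated in fp32 for stability), then scale.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * normed.to(hidden_states.dtype)


def fused_add_rms_norm_reference(hidden_states, residual, weight, eps):
    # Reference for ops.fused_add_rms_norm: residual add and normalization in
    # a single pass; the fused kernel writes the sum back into `residual` and
    # the normalized result back into `hidden_states`.
    residual = residual + hidden_states
    return rms_norm_reference(residual, weight, eps), residual
```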
8 changes: 4 additions & 4 deletions server/text_generation_server/layers/linear.py
@@ -11,10 +11,10 @@

if ROCM_USE_SKINNY_GEMM:
try:
from vllm import _custom_C
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
f"Could not load `vllm._custom_ops` for ROCm skinny gemm. Full error: {e}"
)


@@ -95,12 +95,12 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
ops.wvSpltK(weight, inp, out, n, self.cu_count)
elif m % 4 == 0 and n == 1 and k <= 8192:
out = torch.empty(
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
)
_custom_C.LLMM1(weight, inp, out, 4)
ops.LLMM1(weight, inp, out, 4)
else:
out = F.linear(inp, weight)

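The skinny GEMM path only changes where the custom ops come from: `vllm._custom_ops` instead of the removed `vllm._custom_C` module, with the same `wvSpltK` and `LLMM1` entry points. A hedged sketch of the fallback pattern around the `LLMM1` branch visible in this hunk; the wrapper is hypothetical and the `wvSpltK` shape conditions outside this hunk are not reproduced:

```python
import torch
import torch.nn.functional as F

try:
    import vllm._custom_ops as ops  # new module path for the ROCm custom ops
except ImportError:
    ops = None


def skinny_linear(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper: use the custom skinny-GEMM kernel only for the
    # single-token shape it supports, otherwise fall back to a regular matmul.
    m, k = weight.shape  # (out_features, in_features)
    n = inp.shape[0]     # tokens in the flattened batch
    if ops is not None and n == 1 and m % 4 == 0 and k <= 8192:
        out = torch.empty(n, m, dtype=inp.dtype, device=weight.device)
        ops.LLMM1(weight, inp, out, 4)
        return out
    return F.linear(inp, weight)
```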
5 changes: 1 addition & 4 deletions server/text_generation_server/layers/moe/__init__.py
@@ -24,10 +24,7 @@
UnquantizedWeight,
)

if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM == "ipex":
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
from moe_kernels.fused_moe import fused_topk, grouped_topk
52 changes: 0 additions & 52 deletions server/text_generation_server/layers/moe/fused_moe_rocm.py

This file was deleted.
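This PyTorch port of grouped top-k routing is no longer needed: with the updated moe-kernels build, ROCm uses `grouped_topk` from `moe_kernels.fused_moe` just like CUDA (see `layers/moe/__init__.py` above). For reference, a pure-PyTorch sketch of what this routing computes, following the vLLM-style argument order; the exact signature of the `moe_kernels` version is an assumption:

```python
import torch


def grouped_topk_reference(hidden_states, gating_output, topk, renormalize,
                           num_expert_group, topk_group):
    # Route each token to `topk` experts, but only within the `topk_group`
    # expert groups whose best expert scores highest (DeepSeek-style routing).
    # `hidden_states` is unused here; it is part of the assumed signature.
    scores = torch.softmax(gating_output, dim=-1)
    num_tokens, num_experts = scores.shape
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
        .reshape(num_tokens, -1)
    )
    masked_scores = scores.masked_fill(score_mask == 0, 0.0)
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```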

4 changes: 1 addition & 3 deletions server/text_generation_server/layers/moe/unquantized.py
@@ -6,9 +6,7 @@
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight, Weights

if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM == "ipex":
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
from moe_kernels.fused_moe import fused_moe
2 changes: 1 addition & 1 deletion server/text_generation_server/layers/rotary.py
@@ -7,7 +7,7 @@
if SYSTEM == "cuda":
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops
import vllm._custom_ops as ops
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex

@@ -75,7 +75,7 @@ def forward(

rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
from vllm._C import ops
import vllm._custom_ops as ops

# NOTE: On ROCm systems, we use a RoPE implementation adapted from vLLM that launches a single kernel for both query and key, unlike the flash-attn implementation used on NVIDIA systems.
# When compiling the flash-attn rotary kernel on ROCm, hipcc appears unable to unroll loops, resulting in even slower inference than eager mode: https://github.com/pytorch/pytorch/issues/113773
@@ -23,9 +23,7 @@
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM == "ipex":
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
from moe_kernels.fused_moe import fused_moe
@@ -43,9 +43,9 @@

if SYSTEM == "rocm":
try:
from vllm import _custom_C
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")


class DeepseekV2Config(PretrainedConfig):
@@ -408,7 +408,7 @@ def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce)
else:
gate_up_states = self.gate_up_proj(hidden_states)
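`LLMM_Silu` fuses the packed gate/up projection with the SiLU-and-multiply activation on the dense MLP path; only the module it is resolved from changes here. An unfused reference of what the call computes, assuming the packed `gate_up_proj` weight stores the gate projection in the first half of its output dimension and the up projection in the second, as in the unfused fallback of this file:

```python
import torch.nn.functional as F


def llmm_silu_reference(gate_up_weight, hidden_states):
    # Reference for ops.LLMM_Silu: one kernel that does the packed projection
    # and then SiLU(gate) * up, producing the input to down_proj.
    gate_up = F.linear(hidden_states, gate_up_weight)  # (tokens, 2 * intermediate)
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up
```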
@@ -91,7 +91,7 @@ def forward(

rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
from vllm._C import ops
import vllm._custom_ops as ops

# NOTE: On ROCm systems, we use a RoPE implementation adapted from vLLM that launches a single kernel for both query and key, unlike the flash-attn implementation used on NVIDIA systems.
# When compiling the flash-attn rotary kernel on ROCm, hipcc appears unable to unroll loops, resulting in even slower inference than eager mode: https://github.com/pytorch/pytorch/issues/113773
Expand Down
(Remaining file diffs not loaded.)
