From 8c887b62c6ca797e1866a50d39335c59a7b1f34d Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 10 Oct 2024 18:31:35 +0000 Subject: [PATCH 1/2] Sage diff --- CMakeLists.txt | 5 +- csrc/activation_kernels.cu | 3 +- csrc/ops.h | 5 +- .../layernorm_kernels/activation_kernels.cu | 91 ++++++++++++++++++ csrc/reduction_utils.cuh | 95 +++++++++++++++++++ csrc/torch_bindings.cpp | 22 +++-- pyproject.toml | 2 +- vllm/compilation/backends.py | 20 +++- vllm/compilation/fusion.py | 33 +++++++ vllm/compilation/wrapper.py | 1 - vllm/transformers_utils/config.py | 4 +- vllm/transformers_utils/configs/__init__.py | 4 +- 12 files changed, 267 insertions(+), 18 deletions(-) create mode 100644 csrc/quantization/layernorm_kernels/activation_kernels.cu create mode 100644 csrc/reduction_utils.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index d884ac4c4ce7a..58b7c737254f7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,7 +49,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # requirements.txt files and should be kept consistent. The ROCm torch # versions are derived from Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.5.0") set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0") # @@ -240,7 +240,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gguf/gguf_kernel.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/quantization/layernorm_kernels/activation_kernels.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 5ed1dc3b8f792..aff605a735fa2 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -69,9 +69,10 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { input.data_ptr(), d); \ }); -void silu_and_mul(torch::Tensor& out, // [..., d] +void silu_and_mul(torch::Tensor& result, // [..., d] torch::Tensor& input) // [..., 2 * d] { + torch::Tensor& out = result; LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } diff --git a/csrc/ops.h b/csrc/ops.h index 21d2c37c20383..a34717b735769 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -49,7 +49,10 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets); -void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +void silu_and_mul(torch::Tensor& result, torch::Tensor& input); + +void silu_and_mul_quant(torch::Tensor& result, torch::Tensor const& input, + torch::Tensor const& scale); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/quantization/layernorm_kernels/activation_kernels.cu b/csrc/quantization/layernorm_kernels/activation_kernels.cu new file mode 100644 index 0000000000000..3f63f49ba3267 --- /dev/null +++ b/csrc/quantization/layernorm_kernels/activation_kernels.cu @@ -0,0 +1,91 @@ +#include +#include +#include + +#include "../../cuda_compat.h" +#include "../../dispatch_utils.h" +#include "../../reduction_utils.cuh" +// #include "quant_utils.cuh" +#ifndef USE_ROCM +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = + std::numeric_limits::max(); +#else + #include "amd/hip_float8.h" +using FP8_TYPE = c10::Float8_e4m3fnuz; +// Using the default max value from pytorch (240.0) will cause accuracy +// issue when running dynamic quantization. Here use 224.0f for rocm. +constexpr auto FP8_E4M3_MAX = 224.0f; +#endif +namespace vllm { + +template +__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, + float const scale) { + float x = 0.0f; + if constexpr (is_scale_inverted) { + x = val * scale; + } else { + x = val / scale; + } + + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); +#ifndef USE_ROCM + return static_cast(r); +#else + // Use hardware cvt instruction for fp8 on rocm + return c10::Float8_e4m3fnuz(hip_fp8(r).data, + c10::Float8_e4m3fnuz::from_bits()); +#endif +} + +static inline __device__ int8_t float_to_int8_rn(float x) { + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} + +template +__device__ __forceinline__ T silu(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +template +__global__ void silu_and_mul_quant_kernel( + FP8_TYPE* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2 * d] + const int d, + float* __restrict__ scale) { + const int64_t token_idx = blockIdx.x; + + float inverted_scale = 1 / *scale; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const float x = (float)VLLM_LDG(&input[token_idx * 2 * d + idx]); + const float y = (float)VLLM_LDG(&input[token_idx * 2 * d + d + idx]); + float t = silu(x) * y; + out[token_idx * d + idx] = scaled_fp8_conversion( + t, inverted_scale); + } + +} +} // namespace vllm + +void silu_and_mul_quant(torch::Tensor& result, // [..., d] + torch::Tensor const& input, // [..., 2 * d] + torch::Tensor const& scale // [num_tokens] +) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + dim3 block(std::min(d, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + float* scale_ptr = scale.data_ptr(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "silu_and_mul_quant_kernel", [&] { + vllm::silu_and_mul_quant_kernel<<>>( + result.data_ptr(), input.data_ptr(), d, + scale_ptr); + }); +} diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh new file mode 100644 index 0000000000000..fafda3e8a958a --- /dev/null +++ b/csrc/reduction_utils.cuh @@ -0,0 +1,95 @@ +/* +* Adapted from +* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh +* Copyright (c) 2023, The vLLM team. +* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#pragma once + +#include "cuda_compat.h" + +namespace vllm { + +namespace detail { + +template +__inline__ __device__ T _max(T a, T b) { + return max(a, b); +} + +template +__inline__ __device__ T _sum(T a, T b) { + return a + b; +} + +} // namespace detail + +template +using ReduceFnType = T (*)(T, T); + +// Helper function to return the next largest power of 2 +static constexpr int _nextPow2(unsigned int num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +template +__inline__ __device__ T warpReduce(T val, ReduceFnType fn) { + static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, + "numLanes is not a positive power of 2!"); + static_assert(numLanes <= WARP_SIZE); +#pragma unroll + for (int mask = numLanes >> 1; mask > 0; mask >>= 1) + val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask)); + + return val; +} + +template +__inline__ __device__ T blockReduce(T val, ReduceFnType fn) { + static_assert(maxBlockSize <= 1024); + if constexpr (maxBlockSize > WARP_SIZE) { + val = warpReduce(val, fn); + // Calculates max number of lanes that need to participate in the last + // warpReduce + constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; + static __shared__ T shared[maxActiveLanes]; + int lane = threadIdx.x % WARP_SIZE; + int wid = threadIdx.x / WARP_SIZE; + if (lane == 0) shared[wid] = val; + + __syncthreads(); + + val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] + : (T)(0.0f); + val = warpReduce(val, fn); + } else { + // A single warpReduce is equal to blockReduce + val = warpReduce(val, fn); + } + return val; +} + +template +__inline__ __device__ T blockReduceMax(T val) { + return blockReduce(val, detail::_max); +} + +template +__inline__ __device__ T blockReduceSum(T val) { + return blockReduce(val, detail::_sum); +} + +} // namespace vllm \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 37dc0dd21fece..12fcce2434578 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -49,7 +49,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Activation ops // Activation function used in SwiGLU. - ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); // Activation function used in GeGLU with `none` approximation. @@ -107,15 +107,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Layernorm-quant // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( - "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, Tensor scale, float epsilon) -> " + "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " + "Tensor scale, float epsilon) -> " "()"); - ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, &rms_norm_static_fp8_quant); + ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, + &rms_norm_static_fp8_quant); // In-place fused Add and RMS Normalization. ops.def( - "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor! residual, Tensor weight, " + "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " + "Tensor! residual, Tensor weight, " "Tensor scale, float epsilon) -> ()"); - ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, &fused_add_rms_norm_static_fp8_quant); + ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, + &fused_add_rms_norm_static_fp8_quant); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. @@ -281,6 +285,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // capability ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"); ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); + ops.def("silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); + ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); // Mamba selective scan kernel ops.def( @@ -330,12 +336,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute FP8 quantized tensor for given scaling factor. ops.def( - "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); + "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> " + "()"); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. ops.def( - "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) -> " + "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " + "-> " "()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); diff --git a/pyproject.toml b/pyproject.toml index c9057b061aad9..8b8c27e042974 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,8 @@ requires = [ "packaging", "setuptools>=61", "setuptools-scm>=8.0", - "torch == 2.4.0", "wheel", + "torch", "jinja2", ] build-backend = "setuptools.build_meta" diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index f68059f311e0d..c4b9fdc9a4edf 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -170,7 +170,7 @@ def fix_functionalization(graph: fx.Graph): kwargs = node.kwargs input = kwargs['input'] - out = kwargs['out'] + out = kwargs['result'] # TODO # Create a new call to torch.ops._C.rotary_embedding.default # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa @@ -189,6 +189,24 @@ def fix_functionalization(graph: fx.Graph): user.replace_all_uses_with(replace_node) nodes_to_remove.append(user) nodes_to_remove.append(node) + elif node.args[0] == torch.ops.neuralmagic.silu_mul_quant.default: + # + kwargs = node.kwargs + + replace_node = kwargs['result'] + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.silu_mul_quant.default, kwargs=kwargs) + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) # Remove the nodes all at once for node in nodes_to_remove: diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e375d200d620b..5297f4ac1b703 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -10,6 +10,29 @@ logger = init_logger(__name__) +@torch.library.custom_op("neuralmagic::silu_mul_quant", mutates_args=()) +def silu_mul_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + result = torch.empty(x.shape[0], x.shape[1] // 2, device=x.device, dtype=torch.float8_e4m3fn) + + torch.ops._C.silu_and_mul_quant(result, x, scale) + return result + +@silu_mul_quant.register_fake +def silu_mul_quant(x: torch.Tensor, scale: torch.Tensor): + return torch.empty(x.shape[0], x.shape[1] // 2, device=x.device, dtype=torch.float8_e4m3fn) + +def silu_mul_quant_replacement(x: torch.Tensor, scale:torch.Tensor) -> torch.tensor: + # print("MATCH QUANT") + return torch.ops.neuralmagic.silu_mul_quant(x, scale) + +def silu_mul_quant_pattern(input_: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + silu_mul_result = torch.empty([input_.shape[0], input_.shape[1] // 2], dtype=torch.float16, device=input_.device) + silu_mul_func = torch.ops.higher_order.auto_functionalized(torch.ops._C.silu_and_mul.default, result = silu_mul_result, input = input_) + result = torch.empty([input_.shape[0], input_.shape[1] // 2], dtype=torch.float8_e4m3fn, device=input_.device) + static_fp8_quant_func = torch.ops.higher_order.auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, result=result, input=silu_mul_func[1], scale=scale) + return static_fp8_quant_func[1] + + def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): at1 = auto_functionalized(torch.ops._C.rms_norm.default, result=result_rms, input=input, weight=weight, @@ -75,6 +98,16 @@ def record_match_fn(match: Match): register_replacement(rms_pattern_residual_static, rms_replacement_residual_static, inputs, fwd_only, my_patterns, extra_check=record_match_fn) + # silu-mul quant + x = torch.empty((128, 256), device="cuda", dtype=torch.float16) + scale = torch.empty((1,1), device="cuda" , dtype=torch.float32) + + register_replacement(silu_mul_quant_pattern, + silu_mul_quant_replacement, + [x, scale], + fwd_only, + [my_patterns]) + return my_patterns, matches diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 1594b64a61b94..62f0c4af5ea9e 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -26,7 +26,6 @@ class TorchCompileWrapperWithCustomDispatcher: """ def __init__(self, compiled_callable: Optional[Callable] = None): - if compiled_callable is None: # default compilation settings # compiling the forward method diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index b33449c42ecf5..236170c7b3ab4 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -20,7 +20,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, InternVLChatConfig, JAISConfig, - MedusaConfig, MllamaConfig, + MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, NVLM_D_Config, Qwen2VLConfig, RWConfig, @@ -38,7 +38,7 @@ logger = init_logger(__name__) _CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = { - "mllama": MllamaConfig + # "mllama": MllamaConfig } _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8d6385d42d002..5d2cc8e7f2d2c 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -9,7 +9,7 @@ from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig -from vllm.transformers_utils.configs.mllama import MllamaConfig +# from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig @@ -29,7 +29,7 @@ "MedusaConfig", "EAGLEConfig", "ExaoneConfig", - "MllamaConfig", + # "MllamaConfig", "MLPSpeculatorConfig", "NemotronConfig", "NVLM_D_Config", From 272326b58256acfb4090dc3cef9e8bbd39abe8aa Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 11 Oct 2024 01:48:09 +0000 Subject: [PATCH 2/2] enable custom_ops --- vllm/model_executor/custom_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index d0e90245ad010..d0b1b18656cd4 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -55,9 +55,9 @@ def forward_gaudi(self, *args, **kwargs): def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - - if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR: - return self.forward_native + # + # if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR: + # return self.forward_native if is_hip(): return self.forward_hip