Commit 8be34ea

PR comments:
- add a kFp8Type constant for CUDA/HIP-agnostic torch dtype checking
- check that the input and output tensors are contiguous
- avoid int32 overflow when computing token_offset
- reduce the number of test parameterizations

Signed-off-by: luka <[email protected]>
ProExpertProg committed Dec 12, 2024
1 parent a70d496 commit 8be34ea
Showing 4 changed files with 25 additions and 15 deletions.
2 changes: 2 additions & 0 deletions csrc/quantization/fp8/common.cuh
@@ -3,6 +3,7 @@
 #include "quantization/vectorization.cuh"

 #include <cmath>
+#include <c10/core/ScalarType.h>

 #ifndef USE_ROCM
 #include <c10/util/Float8_e4m3fn.h>
@@ -17,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz;
 // issue when running dynamic quantization. Here use 224.0f for rocm.
 constexpr auto FP8_E4M3_MAX = 224.0f;
 #endif
+constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;

 namespace vllm {

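Note on the new constant (explanatory sketch, not part of the diff): c10::CppTypeToScalarType<T> maps a C++ element type to its torch ScalarType at compile time, so kFp8Type resolves to Float8_e4m3fn on CUDA builds and Float8_e4m3fnuz on ROCm builds without any #ifdef at the call sites. A minimal standalone illustration, assuming a CUDA build where FP8_TYPE is c10::Float8_e4m3fn:

#include <c10/core/ScalarType.h>
#include <c10/util/Float8_e4m3fn.h>

using FP8_TYPE = c10::Float8_e4m3fn;  // c10::Float8_e4m3fnuz under ROCm
constexpr auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;

// kFp8Type is an ordinary c10::ScalarType value, so it can be compared
// directly against tensor dtypes, e.g. TORCH_CHECK(out.dtype() == kFp8Type).
static_assert(kFp8Type == c10::ScalarType::Float8_e4m3fn,
              "on a CUDA build kFp8Type is the e4m3fn scalar type");
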
@@ -144,11 +144,11 @@ void rms_norm_dynamic_per_token_quant(
     torch::Tensor& scales,    // [num_tokens]
     double const var_epsilon, // Variance epsilon used in norm calculation
     std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
-  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
-              out.dtype() == torch::kInt8);
+  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
+  TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

   if (scale_ub.has_value()) {
-    TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
+    TORCH_CHECK(out.dtype() == kFp8Type);
   }
   TORCH_CHECK(scales.dtype() == torch::kFloat32);

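Why the new contiguity check matters (explanatory sketch, not part of the diff): the fused kernels address each row with a flat offset of the form blockIdx.x * hidden_size (see the layernorm_utils.cuh hunks below), which is only valid for a dense [num_tokens, hidden_size] layout, so non-contiguous views are now rejected up front instead of being indexed incorrectly. A hypothetical caller, with the op invocation elided because its full signature is not shown here:

#include <torch/torch.h>

int main() {
  auto input = torch::randn(
      {128, 4096}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));

  auto strided = input.transpose(0, 1);  // non-contiguous view of the same data
  // Passing `strided` to the op would now fail the new
  // TORCH_CHECK(out.is_contiguous() && input.is_contiguous())
  // instead of silently reading the wrong elements.

  auto dense = strided.contiguous();  // dense copy; the contiguity check passes
  return 0;
}
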
16 changes: 10 additions & 6 deletions csrc/quantization/fused_kernels/layernorm_utils.cuh
@@ -20,7 +20,7 @@ template <typename scalar_t, bool has_residual = false>
 __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                             int32_t const hidden_size, float const epsilon,
                             scalar_t const* __restrict__ residual = nullptr) {
-  int64_t const token_offset = blockIdx.x * hidden_size;
+  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
   // sum of squares
   float ss = 0.0f;

@@ -53,7 +53,8 @@ __device__ void compute_dynamic_per_token_scales(
     float const rms, float const* __restrict__ scale_ub,
     float const min_scaling_factor, int32_t const hidden_size,
     scalar_t const* __restrict__ residual = nullptr) {
-  int64_t const token_offset = blockIdx.x * hidden_size;
+  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
+  ;
   constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};

   float block_absmax_val_maybe = 0.0f;
@@ -99,7 +100,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                                float const rms, float const scale,
                                int32_t const hidden_size,
                                scalar_t* __restrict__ residual = nullptr) {
-  int64_t const token_offset = blockIdx.x * hidden_size;
+  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
+  ;

   for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     float x = static_cast<float>(input[token_offset + i]);
@@ -123,7 +125,7 @@ template <typename scalar_t, bool has_residual = false>
 __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                             int32_t const hidden_size, float const epsilon,
                             scalar_t const* __restrict__ residual = nullptr) {
-  int64_t const token_offset = blockIdx.x * hidden_size;
+  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

   // Vectorized input/output to better utilize memory bandwidth.
   vec4_t<scalar_t> const* vec_input =
@@ -184,7 +186,8 @@ __device__ void compute_dynamic_per_token_scales(
     float const rms, float const* __restrict__ scale_ub,
     float const min_scaling_factor, int32_t const hidden_size,
     scalar_t const* __restrict__ residual = nullptr) {
-  int64_t const token_offset = blockIdx.x * hidden_size;
+  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
+  ;

   // Vectorized input/weight/residual to better utilize memory bandwidth.
   vec4_t<scalar_t> const* vec_input =
@@ -263,7 +266,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                                float const rms, float const scale,
                                int32_t const hidden_size,
                                scalar_t* __restrict__ residual = nullptr) {
-  int64_t const token_offset = blockIdx.x * hidden_size;
+  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
+  ;

   // Vectorized input/output/weight/residual to better utilize memory bandwidth.
   vec4_t<scalar_t> const* vec_input =

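Note on the static_cast (explanatory sketch, not part of the diff): blockIdx.x is a 32-bit unsigned value and hidden_size is an int32_t, so without the cast the product is evaluated in 32-bit arithmetic and wraps before being widened to int64_t on assignment; casting one operand forces a 64-bit multiply. A standalone host-side illustration, with a made-up block index standing in for blockIdx.x:

#include <cstdint>
#include <iostream>

int main() {
  unsigned int block_idx = 600000;  // stand-in for blockIdx.x (token index)
  int32_t hidden_size = 8192;       // 600000 * 8192 = 4,915,200,000 > 2^32

  int64_t wrapped = block_idx * hidden_size;  // 32-bit multiply, wraps mod 2^32
  int64_t widened = block_idx * static_cast<int64_t>(hidden_size);  // 64-bit

  std::cout << wrapped << " vs " << widened << '\n';  // 620232704 vs 4915200000
  return 0;
}
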
16 changes: 10 additions & 6 deletions tests/kernels/test_fused_quant_layernorm.py
@@ -9,10 +9,15 @@

 DTYPES = [torch.bfloat16, torch.float]
 QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
-NUM_TOKENS = [1, 7, 83, 2048, 4096]  # Arbitrary values for testing
-HIDDEN_SIZES = [1, 3, 4, 16, 64, 2048, 5120,
-                5137]  # Arbitrary values for testing
-HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
+VEC_HIDDEN_SIZES = range(1024, 1030)
+# Avoid combinatorial explosion with full Cartesian product
+NUM_TOKENS_HIDDEN_SIZES = [
+    *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
+    *[(83, i) for i in [1, 1033, 2048, 5120]],
+    *[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]],
+    *[(4096, i) for i in [1, 64, 5137]],
+]

 ADD_RESIDUAL = [False, True]
 SCALE_UBS = [True, False]
 SEEDS = [0]
@@ -100,8 +105,7 @@ def ops_impl(weight: torch.Tensor,
                     scale_ub)


-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
-@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("scale_ub", SCALE_UBS)
 @pytest.mark.parametrize("dtype", DTYPES)
