Dynamic group blocks in marlin MoE
ElizaWszola committed Sep 19, 2024
1 parent 02c9afa commit c0c13ec
Showing 1 changed file with 67 additions and 68 deletions.
135 changes: 67 additions & 68 deletions csrc/moe/marlin_moe_ops.cu
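The change replaces the compile-time group_blocks template parameter of MarlinMoESingle and MarlinMoE (previously const int group_blocks = -1, with one kernel instantiation per supported group size) with a runtime int group_blocks argument passed alongside expert_offsets; the former if constexpr checks on group_blocks become ordinary branches, and the __CALL_IF_MOE / CALL_IF_MOE dispatch macros stop enumerating group sizes. Below is a minimal sketch of the before/after pattern, with hypothetical names rather than the actual vLLM kernels:

// Sketch only: illustrates moving a quantization group size from a template
// parameter to a runtime argument; names are placeholders, not vLLM code.

// Before: group size fixed at compile time, one instantiation per value.
template <int group_blocks>
__global__ void kernel_static(const float* scales, float* out) {
  if constexpr (group_blocks == -1) {
    out[threadIdx.x] = scales[0];                                  // channelwise scale
  } else {
    out[threadIdx.x] = scales[threadIdx.x / (16 * group_blocks)];  // grouped scale
  }
}

// After: group size read at runtime, a single instantiation covers all values.
__global__ void kernel_dynamic(const float* scales, float* out, int group_blocks) {
  if (group_blocks == -1) {
    out[threadIdx.x] = scales[0];
  } else {
    out[threadIdx.x] = scales[threadIdx.x / (16 * group_blocks)];
  }
}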
@@ -344,9 +344,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool has_act_order // whether act_order is enabled
>
__device__ inline void MarlinMoESingle(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -358,6 +356,8 @@ __device__ inline void MarlinMoESingle(
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets,
int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
int num_groups, // number of scale groups per output channel
int expert_idx, // idx of current expert
int num_experts, // number of experts
@@ -386,8 +386,8 @@ __device__ inline void MarlinMoESingle(
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);

if constexpr (!has_act_order && group_blocks != -1) {
if (group_blocks >= thread_k_blocks) {
if constexpr (!has_act_order) {
if (group_blocks != -1 && group_blocks >= thread_k_blocks) {
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
@@ -481,11 +481,11 @@ __device__ inline void MarlinMoESingle(
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
// Scale size/strides with act_order
constexpr int tb_k = 16 * thread_k_blocks;
@@ -529,7 +529,7 @@ __device__ inline void MarlinMoESingle(
// No act_order
int s_gl_rd;
if constexpr (!has_act_order) {
if constexpr (group_blocks == -1) {
if (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
@@ -543,7 +543,7 @@ __device__ inline void MarlinMoESingle(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1)
if (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else
@@ -709,10 +709,10 @@ __device__ inline void MarlinMoESingle(
}
}
} else {
if constexpr (group_blocks != -1) {
if (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;

if constexpr (group_blocks >= thread_k_blocks) {
if (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
@@ -800,8 +800,8 @@ __device__ inline void MarlinMoESingle(

if constexpr (!has_act_order) {
// No act-order case
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
if (group_blocks != -1) {
if (group_blocks >= thread_k_blocks) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
@@ -921,7 +921,7 @@ __device__ inline void MarlinMoESingle(
scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
} else {
if constexpr (group_blocks != -1) {
if (group_blocks != -1) {
scale(frag_b0, frag_s[k % 2][j], 0);
}
}
@@ -932,7 +932,7 @@ __device__ inline void MarlinMoESingle(
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);

} else {
if constexpr (group_blocks != -1) {
if (group_blocks != -1) {
scale(frag_b1, frag_s[k % 2][j], 1);
}
}
@@ -1106,9 +1106,10 @@ __device__ inline void MarlinMoESingle(

// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4) {
res = __hmul2(res, s[0]);
if constexpr (!has_act_order && w_type.size_bits() == 4) {
if (group_blocks == -1) {
res = __hmul2(res, s[0]);
}
}

((half2*)sh)[idx] = res;
@@ -1237,52 +1238,58 @@ __device__ inline void MarlinMoESingle(
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
if constexpr (!has_act_order && group_blocks == -1) {
if constexpr (!has_act_order) {
if constexpr (w_type.size_bits() == 8) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
} else {
// For 4-bit per-column scales, we only fetch them here in the
// final step before write-out
if (last) {
if (group_blocks == -1) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
}
} else {
if (group_blocks == -1) {
// For 4-bit per-column scales, we only fetch them here in the
// final step before write-out
if (last) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
}
}
}
}

thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1) {
if constexpr (w_type.size_bits() == 8) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}

} else {
if (last) {
if (group_blocks == -1) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
}
} else {
if (group_blocks == -1) {
if (last) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
}
}
}
}

// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
if constexpr (!has_act_order && w_type.size_bits() == 8) {
if (group_blocks == -1 && threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
@@ -1346,9 +1353,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool has_act_order // whether act_order is enabled
>
__global__ void MarlinMoE(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -1360,6 +1365,8 @@ __global__ void MarlinMoE(
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets,
int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
int num_groups, // number of scale groups per output channel
int expert_idx, // idx of current expert
int num_experts, // number of experts
@@ -1406,30 +1413,30 @@ __global__ void MarlinMoE(

if (max_block == 1) {
MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>(
stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else if (max_block == 2) {
MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>(
stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else if (max_block == 3) {
MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>(
stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else {
MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>(
stages, has_act_order>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
}
Expand Down Expand Up @@ -1460,9 +1467,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool has_act_order // whether act_order is enabled
>
__global__ void MarlinMoE(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@@ -1474,6 +1479,8 @@ __global__ void MarlinMoE(
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets,
int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
int num_groups, // number of scale groups per output channel
int expert_idx, // idx of current expert
int num_experts, // number of experts
@@ -1510,20 +1517,19 @@ static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
GROUP_BLOCKS, NUM_THREADS) \
NUM_THREADS) \
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
has_act_order == HAS_ACT_ORDER && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
STAGES, HAS_ACT_ORDER>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
STAGES, HAS_ACT_ORDER> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
g_idx_ptr, expert_offsets_ptr, group_blocks, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks); \
@@ -1704,15 +1710,8 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
}

#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS)

void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights,

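On the dispatch side, dropping GROUP_BLOCKS from __CALL_IF_MOE means the host-side else-if chain now matches only the weight type, thread configuration, and has_act_order, so CALL_IF_MOE collapses from one branch per (act_order, group size) combination to two branches per thread configuration. A rough sketch of that shape, with placeholder names standing in for the real macros and launch code:

// Sketch only: placeholder dispatch, not the actual __CALL_IF_MOE macro.
#include <cstdio>

template <bool has_act_order>
void launch_marlin_moe(int group_blocks) {
  // The real code sets the shared-memory attribute and launches MarlinMoE here.
  std::printf("has_act_order=%d group_blocks=%d\n", int(has_act_order), group_blocks);
}

void dispatch(bool has_act_order, int group_blocks) {
  // group_blocks is forwarded as a plain argument instead of selecting a template.
  if (has_act_order) {
    launch_marlin_moe<true>(group_blocks);
  } else {
    launch_marlin_moe<false>(group_blocks);
  }
}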