From c0c13ec26f64b504dc1bde36ee13d7b79a012290 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Thu, 19 Sep 2024 10:41:36 -0400
Subject: [PATCH] Dynamic group blocks in marlin MoE

---
 csrc/moe/marlin_moe_ops.cu | 135 ++++++++++++++++++-------------------
 1 file changed, 67 insertions(+), 68 deletions(-)

diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu
index 49cc03f827f68..d22ba7fe4335a 100644
--- a/csrc/moe/marlin_moe_ops.cu
+++ b/csrc/moe/marlin_moe_ops.cu
@@ -344,9 +344,7 @@ template shared
                               // fetch pipeline
-          const bool has_act_order,    // whether act_order is enabled
-          const int group_blocks = -1  // number of consecutive 16x16 blocks
-                                       // with a separate quantization scale
+          const bool has_act_order     // whether act_order is enabled
           >
 __device__ inline void MarlinMoESingle(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
@@ -358,6 +356,8 @@ __device__ inline void MarlinMoESingle(
                                      // (k/groupsize)xn
     const int* __restrict__ g_idx,   // int32 group indices of shape k
     const int* __restrict__ expert_offsets,
+    int group_blocks,                // number of consecutive 16x16 blocks
+                                     // with a separate quantization scale
     int num_groups,                  // number of scale groups per output channel
     int expert_idx,                  // idx of current expert
     int num_experts,                 // number of experts
@@ -386,8 +386,8 @@ __device__ inline void MarlinMoESingle(
   int n_tiles = prob_n / 16 / thread_n_blocks;
   int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
 
-  if constexpr (!has_act_order && group_blocks != -1) {
-    if (group_blocks >= thread_k_blocks) {
+  if constexpr (!has_act_order) {
+    if (group_blocks != -1 && group_blocks >= thread_k_blocks) {
       // Ensure that the number of tiles in each stripe is a multiple of the
       // groupsize; this avoids an annoying special case where a stripe starts
       // in the middle of group.
@@ -481,11 +481,11 @@ __device__ inline void MarlinMoESingle(
   // Scale sizes/strides without act_order
   int s_gl_stride = prob_n / 8;
   constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
-  constexpr int s_tb_groups =
+  int s_tb_groups =
       !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
           ? thread_k_blocks / group_blocks
           : 1;
-  constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
+  int s_sh_stage = s_tb_groups * s_sh_stride;
   int s_gl_rd_delta = s_gl_stride;
   // Scale size/strides with act_order
   constexpr int tb_k = 16 * thread_k_blocks;
@@ -529,7 +529,7 @@ __device__ inline void MarlinMoESingle(
   // No act_order
   int s_gl_rd;
   if constexpr (!has_act_order) {
-    if constexpr (group_blocks == -1) {
+    if (group_blocks == -1) {
       s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
     } else {
       s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
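The hunks above are the core of the change: group_blocks stops being a template parameter of MarlinMoESingle (previously defaulted to -1) and becomes an ordinary runtime argument, so branches on it drop `if constexpr` in favour of plain `if`, and derived locals such as s_tb_groups and s_sh_stage lose `constexpr`. Below is a minimal sketch of that pattern only, not the Marlin kernel itself; the helper names are made up for illustration. Plain C++17, builds with g++ -std=c++17 or nvcc -std=c++17.

// Sketch of the pattern in the hunks above: group_blocks moves from a
// template parameter to a runtime argument, so the branch on it becomes a
// plain `if` and the derived value is no longer constexpr.
#include <cstdio>

// Before: one instantiation per (thread_k_blocks, group_blocks) pair,
// the branch is resolved at compile time.
template <int kThreadKBlocks, int kGroupBlocks>
int tb_groups_static() {
  if constexpr (kGroupBlocks != -1 && kGroupBlocks < kThreadKBlocks) {
    return kThreadKBlocks / kGroupBlocks;
  } else {
    return 1;
  }
}

// After: a single instantiation per thread_k_blocks handles every group size;
// the same decision is an ordinary runtime check on the argument.
template <int kThreadKBlocks>
int tb_groups_dynamic(int group_blocks) {
  if (group_blocks != -1 && group_blocks < kThreadKBlocks) {
    return kThreadKBlocks / group_blocks;
  }
  return 1;
}

int main() {
  printf("%d %d\n", tb_groups_static<8, 2>(), tb_groups_dynamic<8>(2));    // 4 4
  printf("%d %d\n", tb_groups_static<8, -1>(), tb_groups_dynamic<8>(-1));  // 1 1
  return 0;
}

The usual trade-off of this pattern is fewer template instantiations and less compiled code, at the cost of evaluating the group-size branches and the related stride arithmetic at run time.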
@@ -543,7 +543,7 @@ __device__ inline void MarlinMoESingle(
   // we scale a `half2` tile in column-major layout in the former and in
   // row-major in the latter case.
   int s_sh_rd;
-  if constexpr (group_blocks != -1)
+  if (group_blocks != -1)
     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               (threadIdx.x % 32) / 4;
   else
@@ -709,10 +709,10 @@ __device__ inline void MarlinMoESingle(
         }
       }
     } else {
-      if constexpr (group_blocks != -1) {
+      if (group_blocks != -1) {
         int4* sh_s_stage = sh_s + s_sh_stage * pipe;
 
-        if constexpr (group_blocks >= thread_k_blocks) {
+        if (group_blocks >= thread_k_blocks) {
           // Only fetch scales if this tile starts a new group
           if (pipe % (group_blocks / thread_k_blocks) == 0) {
             if (s_sh_wr_pred) {
@@ -800,8 +800,8 @@ __device__ inline void MarlinMoESingle(
 
     if constexpr (!has_act_order) {
       // No act-order case
-      if constexpr (group_blocks != -1) {
-        if constexpr (group_blocks >= thread_k_blocks) {
+      if (group_blocks != -1) {
+        if (group_blocks >= thread_k_blocks) {
           int4* sh_s_stage =
               sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                                    (pipe / (group_blocks / thread_k_blocks)));
@@ -921,7 +921,7 @@ __device__ inline void MarlinMoESingle(
           scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
                  act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
         } else {
-          if constexpr (group_blocks != -1) {
+          if (group_blocks != -1) {
             scale(frag_b0, frag_s[k % 2][j], 0);
           }
         }
@@ -932,7 +932,7 @@ __device__ inline void MarlinMoESingle(
                  act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
 
         } else {
-          if constexpr (group_blocks != -1) {
+          if (group_blocks != -1) {
             scale(frag_b1, frag_s[k % 2][j], 1);
           }
         }
@@ -1106,9 +1106,10 @@ __device__ inline void MarlinMoESingle(
 
       // For per-column quantization we finally apply the scale here (only for
      // 4-bit)
-      if constexpr (!has_act_order && group_blocks == -1 &&
-                    w_type.size_bits() == 4) {
-        res = __hmul2(res, s[0]);
+      if constexpr (!has_act_order && w_type.size_bits() == 4) {
+        if (group_blocks == -1) {
+          res = __hmul2(res, s[0]);
+        }
       }
 
       ((half2*)sh)[idx] = res;
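Several of the hunks above restructure mixed conditions in the same way: whatever is still compile-time (has_act_order, the weight width) stays inside `if constexpr`, and the test on the now-runtime group_blocks moves into a nested plain `if`, as in the per-column 4-bit scaling at old line 1106. The following is a compilable sketch of that shape with illustrative names and a toy kernel rather than the real one; it uses half2 math, so build with something like nvcc -arch=sm_80 -std=c++17.

// Sketch only: compile-time predicates stay in `if constexpr`, the runtime
// group_blocks test is nested inside. Names here are made up.
#include <cuda_fp16.h>
#include <cstdio>

template <bool kHasActOrder, int kSizeBits>
__global__ void scale_if_channelwise(const float2* in, const float2* scale,
                                     float2* out, int group_blocks) {
  half2 v = __floats2half2_rn(in[0].x, in[0].y);
  half2 s = __floats2half2_rn(scale[0].x, scale[0].y);
  // Old form: if constexpr (!has_act_order && group_blocks == -1 && bits == 4)
  if constexpr (!kHasActOrder && kSizeBits == 4) {
    if (group_blocks == -1) {  // -1 means per-column (channelwise) scales
      v = __hmul2(v, s);
    }
  }
  out[0] = make_float2(__low2float(v), __high2float(v));
}

int main() {
  float2 h_in = make_float2(2.f, 3.f), h_s = make_float2(10.f, 100.f), h_out;
  float2 *d_in, *d_s, *d_out;
  cudaMalloc(&d_in, sizeof(float2));
  cudaMalloc(&d_s, sizeof(float2));
  cudaMalloc(&d_out, sizeof(float2));
  cudaMemcpy(d_in, &h_in, sizeof(float2), cudaMemcpyHostToDevice);
  cudaMemcpy(d_s, &h_s, sizeof(float2), cudaMemcpyHostToDevice);
  scale_if_channelwise<false, 4><<<1, 1>>>(d_in, d_s, d_out, /*group_blocks=*/-1);
  cudaMemcpy(&h_out, d_out, sizeof(float2), cudaMemcpyDeviceToHost);
  printf("%.0f %.0f\n", h_out.x, h_out.y);  // expected: 20 300
  cudaFree(d_in);
  cudaFree(d_s);
  cudaFree(d_out);
  return 0;
}

Functionally the filter is unchanged; the group_blocks == -1 test simply happens per launch instead of being baked into a separate instantiation.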
@@ -1237,36 +1238,32 @@ __device__ inline void MarlinMoESingle(
     if (slice_iters == 0) {
       cp_async_wait<0>();
       bool last = slice_idx == slice_count - 1;
-      if constexpr (!has_act_order && group_blocks == -1) {
+      if constexpr (!has_act_order) {
         if constexpr (w_type.size_bits() == 8) {
-          if (s_sh_wr_pred) {
-            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
-          }
-          cp_async_fence();
-        } else {
-          // For 4-bit per-column scales, we only fetch them here in the
-          // final step before write-out
-          if (last) {
+          if (group_blocks == -1) {
             if (s_sh_wr_pred) {
               cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
             }
             cp_async_fence();
           }
+        } else {
+          if (group_blocks == -1) {
+            // For 4-bit per-column scales, we only fetch them here in the
+            // final step before write-out
+            if (last) {
+              if (s_sh_wr_pred) {
+                cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+              }
+              cp_async_fence();
+            }
+          }
         }
       }
 
       thread_block_reduce();
-      if constexpr (!has_act_order && group_blocks == -1) {
+      if constexpr (!has_act_order) {
         if constexpr (w_type.size_bits() == 8) {
-          cp_async_wait<0>();
-          __syncthreads();
-          if (threadIdx.x / 32 < thread_n_blocks / 4) {
-            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
-            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
-          }
-
-        } else {
-          if (last) {
+          if (group_blocks == -1) {
            cp_async_wait<0>();
             __syncthreads();
             if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1274,15 +1271,25 @@ __device__ inline void MarlinMoESingle(
               reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
             }
           }
+        } else {
+          if (group_blocks == -1) {
+            if (last) {
+              cp_async_wait<0>();
+              __syncthreads();
+              if (threadIdx.x / 32 < thread_n_blocks / 4) {
+                reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+                reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+              }
+            }
+          }
         }
       }
 
       // For 8-bit channelwise, we apply the scale before the global reduction
       // that converts the fp32 results to fp16 (so that we avoid possible
       // overflow in fp16)
-      if constexpr (!has_act_order && group_blocks == -1 &&
-                    w_type.size_bits() == 8) {
-        if (threadIdx.x / 32 < thread_n_blocks / 4) {
+      if constexpr (!has_act_order && w_type.size_bits() == 8) {
+        if (group_blocks == -1 && threadIdx.x / 32 < thread_n_blocks / 4) {
 #pragma unroll
         for (int i = 0; i < thread_m_blocks; i++) {
 #pragma unroll
@@ -1346,9 +1353,7 @@ template shared
                               // fetch pipeline
-          const bool has_act_order,    // whether act_order is enabled
-          const int group_blocks = -1  // number of consecutive 16x16 blocks
-                                       // with a separate quantization scale
+          const bool has_act_order     // whether act_order is enabled
           >
 __global__ void MarlinMoE(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
@@ -1360,6 +1365,8 @@ __global__ void MarlinMoE(
                                      // (k/groupsize)xn
     const int* __restrict__ g_idx,   // int32 group indices of shape k
     const int* __restrict__ expert_offsets,
+    int group_blocks,                // number of consecutive 16x16 blocks
+                                     // with a separate quantization scale
     int num_groups,                  // number of scale groups per output channel
     int expert_idx,                  // idx of current expert
     int num_experts,                 // number of experts
@@ -1406,30 +1413,30 @@ __global__ void MarlinMoE(
 
   if (max_block == 1) {
     MarlinMoESingle<
-                    stages, has_act_order, group_blocks>(
+                    stages, has_act_order>(
         A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
-        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
+        expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
         prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
         current_m_block);
   } else if (max_block == 2) {
     MarlinMoESingle<
-                    stages, has_act_order, group_blocks>(
+                    stages, has_act_order>(
         A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
-        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
+        expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
         prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
         current_m_block);
   } else if (max_block == 3) {
     MarlinMoESingle<
-                    stages, has_act_order, group_blocks>(
+                    stages, has_act_order>(
         A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
-        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
+        expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
         prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
         current_m_block);
   } else {
     MarlinMoESingle<
-                    stages, has_act_order, group_blocks>(
+                    stages, has_act_order>(
         A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
-        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
+        expert_offsets, group_blocks, num_groups, expert_idx, num_experts, topk, prob_m,
         prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
         current_m_block);
   }
@@ -1460,9 +1467,7 @@ template shared
                               // fetch pipeline
-          const bool has_act_order,    // whether act_order is enabled
-          const int group_blocks = -1  // number of consecutive 16x16 blocks
-                                       // with a separate quantization scale
+          const bool has_act_order     // whether act_order is enabled
           >
 __global__ void MarlinMoE(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
@@ -1474,6 +1479,8 @@ __global__ void MarlinMoE(
                                      // (k/groupsize)xn
     const int* __restrict__ g_idx,   // int32 group indices of shape k
     const int* __restrict__ expert_offsets,
+    int group_blocks,                // number of consecutive 16x16 blocks
+                                     // with a separate quantization scale
     int num_groups,                  // number of scale groups per output channel
     int expert_idx,                  // idx of current expert
     int num_experts,                 // number of experts
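In the wrappers above, group_blocks is now simply forwarded from the __global__ MarlinMoE entry points into MarlinMoESingle as a regular argument, and only max_block (the thread_m_blocks choice) still selects a template instantiation. Below is a small stand-alone sketch of that dispatch shape, with made-up names and a trivial kernel in place of MarlinMoESingle.

// Sketch only: max_block picks the instantiation, group_blocks is forwarded
// as a plain kernel argument. worker() stands in for MarlinMoESingle.
#include <cstdio>

template <int kThreadMBlocks, bool kHasActOrder>
__global__ void worker(int group_blocks, int num_groups) {
  if (threadIdx.x == 0) {
    printf("m_blocks=%d act_order=%d group_blocks=%d num_groups=%d\n",
           kThreadMBlocks, (int)kHasActOrder, group_blocks, num_groups);
  }
}

template <bool kHasActOrder>
void launch(int max_block, int group_blocks, int num_groups) {
  if (max_block == 1) {
    worker<1, kHasActOrder><<<1, 32>>>(group_blocks, num_groups);
  } else if (max_block == 2) {
    worker<2, kHasActOrder><<<1, 32>>>(group_blocks, num_groups);
  } else if (max_block == 3) {
    worker<3, kHasActOrder><<<1, 32>>>(group_blocks, num_groups);
  } else {
    worker<4, kHasActOrder><<<1, 32>>>(group_blocks, num_groups);
  }
}

int main() {
  // group_blocks can now vary per call without creating new instantiations.
  launch<false>(/*max_block=*/2, /*group_blocks=*/4, /*num_groups=*/32);
  launch<false>(/*max_block=*/2, /*group_blocks=*/-1, /*num_groups=*/1);
  cudaDeviceSynchronize();
  return 0;
}

Under the old scheme, each supported group size would have multiplied the number of compiled kernels selected by this chain; now only the compile-time tile and act_order choices do.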
@@ -1510,20 +1517,19 @@ static constexpr int min_thread_n = 64;
 static constexpr int min_thread_k = 64;
 
 #define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
-                      GROUP_BLOCKS, NUM_THREADS)                               \
+                      NUM_THREADS)                                             \
   else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS &&           \
            thread_k_blocks == THREAD_K_BLOCKS &&                               \
-           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&   \
-           num_threads == NUM_THREADS) {                                       \
+           has_act_order == HAS_ACT_ORDER && num_threads == NUM_THREADS) {     \
     cudaFuncSetAttribute(                                                      \
         MarlinMoE<                                                             \
-                  STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>,                        \
+                  STAGES, HAS_ACT_ORDER>,                                      \
         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
     MarlinMoE<                                                                 \
-              STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>                             \
+              STAGES, HAS_ACT_ORDER>                                           \
         <<>>(                                                                  \
             A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr,      \
-            g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,             \
+            g_idx_ptr, expert_offsets_ptr, group_blocks, num_groups, expert_idx, \
             num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,           \
             replicate_input, apply_weights, m_block, max_par,                  \
             exec_cfg.max_m_blocks);                                            \
@@ -1704,15 +1710,8 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
 }
 
 #define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)          \
-  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)     \
-  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)     \
-  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)     \
-  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)     \
-                                                                      \
-  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)   \
-  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)    \
-  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)    \
-  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
+  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, NUM_THREADS)        \
+  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, NUM_THREADS)
 
 void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
                          const void* sorted_ids, const void* topk_weights,
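The macro change above has the matching effect on host-side dispatch: __CALL_IF_MOE no longer takes GROUP_BLOCKS, so CALL_IF_MOE only enumerates the act_order variants and group_blocks is passed straight through to the launch. The following is a simplified, compilable stand-in for that else-if chain, not the real __CALL_IF_MOE signature.

// Simplified stand-in for the dispatch macro above: the chain matches only
// the compile-time knobs and forwards group_blocks at run time.
#include <cstdio>

template <int kThreadNBlocks, int kThreadKBlocks, bool kHasActOrder>
void run_kernel(int group_blocks) {
  // Placeholder for the cudaFuncSetAttribute(...) call plus the kernel launch
  // that the real macro expands to.
  printf("n_blocks=%d k_blocks=%d act_order=%d group_blocks=%d\n",
         kThreadNBlocks, kThreadKBlocks, (int)kHasActOrder, group_blocks);
}

#define CALL_IF(N_BLOCKS, K_BLOCKS, HAS_ACT_ORDER)                        \
  else if (thread_n_blocks == N_BLOCKS && thread_k_blocks == K_BLOCKS &&  \
           has_act_order == HAS_ACT_ORDER) {                              \
    run_kernel<N_BLOCKS, K_BLOCKS, HAS_ACT_ORDER>(group_blocks);          \
  }

void dispatch(int thread_n_blocks, int thread_k_blocks, bool has_act_order,
              int group_blocks) {
  if (false) {
  }
  CALL_IF(8, 8, true)
  CALL_IF(8, 8, false)
  CALL_IF(16, 4, true)
  CALL_IF(16, 4, false)
  else {
    printf("unsupported configuration\n");
  }
}

int main() {
  dispatch(8, 8, /*has_act_order=*/false, /*group_blocks=*/8);
  dispatch(8, 8, /*has_act_order=*/false, /*group_blocks=*/-1);
  return 0;
}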