From 364169294e5f9192680123896250694cd7b1dae0 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 31 Jul 2024 05:15:15 +0000 Subject: [PATCH] Multi-GPU works, but could make it faster --- csrc/moe/marlin_moe_ops.cu | 1376 ++++++++++++++++- csrc/moe/marlin_moe_ops.h | 3 +- csrc/moe/torch_bindings.cpp | 11 +- examples/offline_inference.py | 4 +- tests/kernels/test_moe.py | 39 +- .../layers/fused_moe/fused_moe.py | 30 +- vllm/model_executor/models/mixtral_quant.py | 36 +- 7 files changed, 1417 insertions(+), 82 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 69a4ab42da9ac..0670973a81b21 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -30,6 +30,8 @@ inline std::string str(T x) { return std::to_string(x); } +#define CPU_OFFSETS false + namespace marlin_moe { constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } @@ -243,35 +245,1018 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); half* out_half = reinterpret_cast(out_int4_ptr + offset); - int base_k = 0; + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += blockDim.x; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, + int block_size) { + int expert_id = threadIdx.x; + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + expert_offsets[expert_id + 1] = occurrences; + __syncthreads(); + + if (threadIdx.x == 0) { + int tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; + expert_offsets[i + 1] = tot_offset; + } + // for (int i = 0; i < num_experts + 1; ++i) { + // printf("expert offset: %d -> %d (%d %d)\n", + // i, expert_offsets[i], topk_length, block_size); + // } + } + __syncthreads(); + +} + +#if CPU_OFFSETS + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert // TODO must decide based on offsets + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int try_m_block_ctr, // experiment + int* barrier_ctrs +) { + + // int tot_m_blocks = ceildiv(tot_m, 16); + // if (try_m_block_ctr >= tot_m_blocks) { + // return; + // } + + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + constexpr int sorted_sh_stride = threads; + constexpr int sorted_gl_stride = threads; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (group_blocks == -1 || group_blocks == 0) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int* sh_sorted = (int*)(sh_s + shs_size); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // TODO fix + auto fetch_sorted_ids_to_shared = [&]() { + const int mpt = ceildiv(prob_m, threads); + for (int i = 0; i < mpt; i++) { + if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { + sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = + sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; + } + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + + FragB frag_b0 = dequant(b_quant); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + FragB frag_b1 = dequant(b_quant_shift); + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here + if constexpr (!has_act_order && group_blocks == -1) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + // fetch_sorted_ids_to_shared(); + __syncthreads(); + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } - for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } - out_half[cur_k] = a_row_half[src_pos]; + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; - base_k += blockDim.x; + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } } - if (rest) { - if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } - out_half[cur_k] = a_row_half[src_pos]; + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } } - } - }; + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } - for (int i = 0; i < cur_block_rows; i++) { - int cur_row = start_row + i; - if (cur_row < size_m) { - permute_row(cur_row); + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } } } } +#else + +// TODO could just run MarlinMoE? template -__global__ void MarlinMoE( +__device__ inline void RunSingleIter( const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn @@ -293,9 +1278,9 @@ __global__ void MarlinMoE( const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, int num_groups, // number of scale groups per output channel - int num_tokens_post_padded, // scales_ptrs size with padding - int expert_idx, // idx of current expert + int expert_idx, // idx of current expert // TODO must decide based on offsets int num_experts, // number of experts int topk, // topk parameter of moe int prob_m, // batch dimension m @@ -304,19 +1289,13 @@ __global__ void MarlinMoE( int tot_m, // total number of rows in A and C int* locks, // extra global storage for barrier synchronization bool replicate_input, // do we use the same input for each expert? - bool apply_weights // apply weights to output + bool apply_weights, // apply weights to output + int try_m_block_ctr // experiment ) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. + + // if (threadIdx.x == 0 && blockIdx.x == 0) { + // printf("%d, %d\n", thread_m_blocks, prob_m); + // } // For larger GEMMs we run multiple batchsize 64 versions in parallel for a // better partitioning with less reductions @@ -382,6 +1361,8 @@ __global__ void MarlinMoE( } if (slice_col == n_tiles) { sorted_ids += 16 * thread_m_blocks; + // sorted_off += 16 * thread_m_blocks; + // printf("advance 2: %d (%d %d)\n", sorted_off, blockIdx.x, threadIdx.x); locks += n_tiles; slice_col = 0; } @@ -614,6 +1595,9 @@ __global__ void MarlinMoE( int row = a_idx / a_gl_stride; int sorted_row = replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + // if (expert_idx == 0) { + // printf("row A: %d (%d %d), iter %d\n", row, blockIdx.x, threadIdx.x, i); + // } int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; if (sorted_row < tot_m * (replicate_input ? 1 : topk) && new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { @@ -949,6 +1933,8 @@ __global__ void MarlinMoE( int c_idx = c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); int sorted_row = sorted_ids[c_idx / c_gl_stride]; + // printf("row C reduce:\n"); + // printf("row C reduce: %d (%d %d)\n", c_idx / c_gl_stride, blockIdx.x, threadIdx.x); int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], sorted_row < tot_m * topk && @@ -1054,6 +2040,9 @@ __global__ void MarlinMoE( i++) { if (c_gl_wr < c_gl_wr_end) { int row = sorted_ids[c_gl_wr / c_gl_stride]; + // if (blockIdx.x == 8 && threadIdx.x == 95) { + // printf("row C write: %d (%d %d)\n", c_gl_wr / c_gl_stride, blockIdx.x, threadIdx.x); + // } if (row < tot_m * topk) { int off = row * c_gl_stride + c_gl_wr % c_gl_stride; if (!apply_weights) { @@ -1103,6 +2092,7 @@ __global__ void MarlinMoE( // Main loop. while (slice_iters) { + // printf("slice\n"); // We unroll over both the global fetch and the register load pipeline to // ensure all shared memory accesses are static. Note that both pipelines // have even length meaning that the next iteration will always start at @@ -1175,6 +2165,7 @@ __global__ void MarlinMoE( } if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice + // TODO we deadlock here barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); @@ -1212,6 +2203,181 @@ __global__ void MarlinMoE( } } +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert // TODO must decide based on offsets + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int try_m_block_ctr, // experiment + int* barrier_ctrs +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + int m_block_ctr = try_m_block_ctr; + + constexpr int max_par = 4; // TODO should be passed as arg + const int* sorted_ids_expert = sorted_ids_base + expert_offsets[expert_idx] + + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - + expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + // TODO try no padding? + int tot_m_blocks = ceildiv(tot_its, 16); + // int pad = 16 * tot_m_blocks - tot_its; + + // Main loop + for (int m_block_ctr = 0; m_block_ctr < tot_m_blocks; m_block_ctr += 4) { + + const int* sorted_ids = sorted_ids_expert; + // if (m_block_ctr >= tot_m_blocks) { + // return; + // } + + // int* locks = locks_base; //+ (prob_n / 64 * 16) * (m_block_ctr / 4); + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + int full_prob_m = prob_m; + + // int m_offset = m_block_ctr * 16; + // printf("call with m_offset: %d / %d\n", m_offset, tot_its); + + int par = 1; + if (max_block > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + // par = (16 * max_block - pad) / 64; + par = min((16 * max_block) / 64, max_par); + prob_m = 64 * par; + m_block_ctr += 4 * (par - 1); + max_block = 4; + } + + if (max_block == 1) { + RunSingleIter( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, + apply_weights, try_m_block_ctr); + } + else if (max_block == 2) { + RunSingleIter( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, + apply_weights, try_m_block_ctr); + } + else if (max_block == 3) { + RunSingleIter( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, + apply_weights, try_m_block_ctr); + } + else { + RunSingleIter( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, + prob_m, prob_n, prob_k, tot_m, locks, replicate_input, + apply_weights, try_m_block_ctr); + } + + // sorted_ids_expert += 16 * max_block * par; + // break; + // cooperative_groups::this_grid().sync(); + // __atomic__ int ctr; + if (threadIdx.x == 0) { + printf("start bar0 %d %d %d | %d\n", barrier_ctrs[0], barrier_ctrs[1], + barrier_ctrs[2], gridDim.x); + atomicAdd(&barrier_ctrs[0], 1); + // if (barrier_ctrs[2] == gridDim.x) { + // barrier_ctrs[2] = 0; + // } + // else { + while(barrier_ctrs[0] != gridDim.x); + // } + if (blockIdx.x == 0) { + barrier_ctrs[2] = 0; + } + printf("start bar1 %d %d %d | %d\n", barrier_ctrs[0], barrier_ctrs[1], + barrier_ctrs[2], gridDim.x); + atomicAdd(&barrier_ctrs[1], 1); + // if (barrier_ctrs[0] == gridDim.x) { + // barrier_ctrs[0] = 0; + // } + // else { + while(barrier_ctrs[1] != gridDim.x); + // } + if (blockIdx.x == 0) { + barrier_ctrs[0] = 0; + } + printf("start bar2 %d %d %d | %d\n", barrier_ctrs[0], barrier_ctrs[1], + barrier_ctrs[2], gridDim.x); + atomicAdd(&barrier_ctrs[2], 1); + // if (barrier_ctrs[1] == gridDim.x) { + // barrier_ctrs[1] = 0; + // } + // else { + while(barrier_ctrs[2] != gridDim.x); + // } + if (blockIdx.x == 0) { + barrier_ctrs[1] = 0; + } + printf("end bar %d\n", gridDim.x); + } + + // barrier_acquire(&locks2[blockIdx.x], gridDim.x, 0, 0); + // barrier_release(&locks2[blockIdx.x], gridDim.x, 0, 0); + + } +} + +#endif + #else __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -1223,6 +2389,15 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, return; } +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, + int block_size) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + printf("blocks: %d\n", blocks); \ MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, num_groups, num_tokens_post_padded, expert_idx, \ + g_idx_ptr, expert_offsets2_ptr, num_groups, expert_idx, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights); \ + replicate_input, apply_weights, m_block, barrier_ctrs_ptr); \ } typedef struct { @@ -1401,15 +2579,16 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, const void* perm, - void* a_tmp, const void* expert_offsets, int prob_m, + void* a_tmp, void* expert_offsets, void* expert_offsets2, int prob_m, int prob_n, int prob_k, void* workspace, bool has_act_order, bool is_k_full, int num_groups, - int group_size, int num_tokens_post_padded, + int group_size, int num_experts, int topk, int moe_block_size, int dev, cudaStream_t stream, int thread_k, int thread_n, int sms, int max_par, bool replicate_input, - bool apply_weights) { + bool apply_weights, void* barrier_ctrs) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1442,6 +2621,7 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, int thread_n_blocks = thread_n / 16; int blocks = sms; + printf("sms: %d\n", sms); TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); @@ -1477,7 +2657,16 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, int tot_m = prob_m; + #if CPU_OFFSETS const long* expert_offsets_ptr = (const long*)expert_offsets; + int* expert_offsets2_ptr = (int*)expert_offsets2; + #else + const int* topk_ids_ptr = (const int*)topk_ids; + int* expert_offsets2_ptr = (int*)expert_offsets2; + compute_expert_offsets<<<1, num_experts, 0, stream>>>( + topk_ids_ptr, expert_offsets2_ptr, tot_m * topk, moe_block_size); + #endif + int* barrier_ctrs_ptr = (int*)barrier_ctrs; bool do_permute_a = has_act_order; @@ -1489,6 +2678,7 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, } for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + #if CPU_OFFSETS const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; @@ -1517,6 +2707,7 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, int tot_its = expert_offsets_ptr[expert_idx + 1] - expert_offsets_ptr[expert_idx]; // prob_m; + // printf("%d ", tot_its); if (tot_its == 0) { continue; } @@ -1538,6 +2729,9 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, thread_m_blocks = 4; } + // doesn't matter for this version of the code + int m_block = 0; + // Define kernel configurations if (false) { @@ -1558,8 +2752,74 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, } sorted_ids_ptr += 16 * thread_m_blocks * par; + // break; + } + + ///// + + #else + + ///// + + const int4* A_ptr = (const int4*)A; + int4* a_tmp_ptr = (int4*)a_tmp; + const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + int4* C_ptr = (int4*)C; + const float* topk_weights_ptr = (const float*)topk_weights; + // TODO can't know expert_offsets at this point + const int* sorted_ids_ptr = + (const int*)sorted_ids;// + expert_offsets_ptr[expert_idx]; + const int4* s_ptr = + (const int4*)s + + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * + prob_n / 8) * + expert_idx; + + const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; + const int* perm_ptr = (const int*)perm + prob_k * expert_idx; + int* locks = (int*)workspace; + + // TODO we need an expert identifying mechanism here too + if (do_permute_a) { + // Permute A columns + int topk_rows = replicate_input ? tot_m : tot_m * topk; + int block_rows = ceildiv(topk_rows, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + int max_m_blocks = ceildiv(tot_m, 16); + int m_block = 0; + // for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { + // Define kernel configurations + + // make it max possible value + int thread_m_blocks = 4; + + if (false) { + } + CALL_IF_MOE(16, 4, 256) + CALL_IF_MOE(8, 8, 256) + CALL_IF_MOE(8, 4, 128) + CALL_IF_MOE(4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + // } + + // sorted_ids_ptr += 16 * thread_m_blocks * max_par; + // sorted_ids_ptr += 16 * thread_m_blocks * 4; } + #endif } + // printf("\n"); } } // namespace marlin_moe @@ -1567,21 +2827,34 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, const torch::Tensor& expert_offsets, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_tokens_post_padded, int64_t num_experts, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { int max_par = 4; int dev = a.get_device(); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::zeros({size_m, topk, size_n}, options); + auto options_dtype = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + auto options_int = torch::TensorOptions().dtype(torch::kInt).device(a.device()); + torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); torch::Tensor a_tmp = replicate_input - ? torch::zeros({size_m, size_k}, options) - : torch::zeros({size_m, topk, size_k}, options); + ? torch::zeros({size_m, size_k}, options_dtype) + : torch::zeros({size_m, topk, size_k}, options_dtype); + #if CPU_OFFSETS + torch::Tensor expert_offsets2 = torch::empty({0}, options_dtype); + #else + torch::Tensor expert_offsets2 + = torch::empty({num_experts + 1}, options_int); + // torch::Tensor expert_offsets2 = torch::arange(0, + // num_experts * moe_block_size, moe_block_size, + // torch::TensorOptions().dtype(torch::kInt).device(a.device())); + // torch::Tensor expert_offsets2 = expert_offsets; + #endif + torch::Tensor barrier_ctrs = torch::zeros({3}, options_int); // thread_k: `k` size of a thread_tile in `weights` (can usually be left as // auto -1) @@ -1624,13 +2897,20 @@ torch::Tensor marlin_gemm_moe( } } + // std::stringstream sstream; + // sstream << topk_ids.dtype().name(); + // std::string s = sstream.str(); + // printf("topk dtype: %s\n", s.c_str()); + + // printf("run with %ld, %ld, %ld\n", size_m, size_n, size_k); + marlin_moe::marlin_mm_moe_f16i4( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), - topk_weights.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, + topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), expert_offsets2.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), has_act_order, is_k_full, - num_groups, group_size, num_tokens_post_padded, num_experts, topk, + num_groups, group_size, num_experts, topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_n, sms, max_par, replicate_input, apply_weights); + thread_n, sms, max_par, replicate_input, apply_weights, barrier_ctrs.data_ptr()); return c; } diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 36ad8aac92169..a24ca32a52be7 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -5,9 +5,10 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, const torch::Tensor& expert_offsets, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_tokens_post_padded, int64_t num_experts, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d60b2836fde5c..80bc94d46e28a 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -11,11 +11,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "token_expert_indices, Tensor gating_output) -> ()"); ops.impl("topk_softmax", torch::kCUDA, &topk_softmax); + // ops.def( + // "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " + // "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! g_idx, " + // "Tensor! perm, Tensor! expert_offsets, Tensor! workspace, int size_m, int size_n, " + // "int size_k, bool is_k_full, int num_experts, int topk) " + // "-> Tensor"); + ops.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " - "Tensor! topk_weights, Tensor! b_scales, Tensor! g_idx, Tensor! perm, " + "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! g_idx, Tensor! perm, " "Tensor! expert_offsets, Tensor! workspace, int size_m, int size_n, int " - "size_k, bool is_k_full, int num_tokens_post_padded, int num_experts, " + "size_k, bool is_k_full, int num_experts, " "int topk, int moe_block_size, bool replicate_input, bool apply_weights) " "-> Tensor"); ops.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..dcb75c8bae162 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,9 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +# llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ", revision="gptq-4bit-128g-actorder_True") +llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ", enforce_eager=True) +# llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 9d8802904ebc3..8faed802d4c85 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -137,10 +137,10 @@ def compute_max_diff(output, output_ref): # UPSTREAM SYNC: breaks NM automation. -@pytest.mark.skip("C compiler not installed in NM automation. " - "This codepath follows a triton pathway, which " - "JITs using clang or gcc. Since neither are installed " - "in our test instances, we need to skip this for now.") +# @pytest.mark.skip("C compiler not installed in NM automation. " +# "This codepath follows a triton pathway, which " +# "JITs using clang or gcc. Since neither are installed " +# "in our test instances, we need to skip this for now.") @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) @@ -148,6 +148,13 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) +# @pytest.mark.parametrize("m", [512]) +# @pytest.mark.parametrize("n", [1024]) +# @pytest.mark.parametrize("k", [512]) +# @pytest.mark.parametrize("e", [8]) +# @pytest.mark.parametrize("topk", [6]) +# @pytest.mark.parametrize("group_size", [128]) +# @pytest.mark.parametrize("act_order", [True]) def test_fused_marlin_moe( m: int, n: int, @@ -157,6 +164,8 @@ def test_fused_marlin_moe( group_size: int, act_order: bool, ): + torch.manual_seed(7) + if topk > e: return @@ -239,6 +248,28 @@ def test_fused_marlin_moe( w1_scale=scales1, w2_scale=scales2) + # print("shape: ", marlin_output.shape, triton_output.shape) + + # failctr = 0 + # for i in range(m): + # for j in range(n): + # if abs(marlin_output[i][j].item() - triton_output[i][j].item()) > 0.1: + # print(m, n, marlin_output[i][j].item(), triton_output[i][j].item()) + # failctr += 1 + # if failctr == 50: + # break + + # if compute_max_diff(marlin_output, triton_output) >= 4e-2: + # torch.set_printoptions(profile="full") + # torch.set_printoptions(precision=2) + # diff_tab = ((100 * torch.abs(marlin_output - triton_output)).int()) / 100.0 + # for i in range(diff_tab.shape[0]): + # if torch.count_nonzero(diff_tab[i,:]) > 0: + # print(i, diff_tab[i,:]) + # # print(marlin_output[:128,:]) + # # print(triton_output[:128,:]) + # torch.set_printoptions(profile="default") + assert (compute_max_diff(marlin_output, triton_output) < 4e-2) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 195179df24941..bc70b404f43fd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2,6 +2,7 @@ import functools import json import os +import sys from typing import Any, Dict, Optional, Tuple import torch @@ -604,7 +605,8 @@ def get_expert_offsets(sorted_token_ids: torch.Tensor, topk_ids: torch.Tensor, ex_blocks = (occurrences[i].item() + block_size_m - 1) // block_size_m expert_offsets[i + 1] = ex_blocks * block_size_m + expert_offsets[i] for i in range(len(occurrences), num_experts): - expert_offsets[i] = sorted_token_ids.size()[0] + expert_offsets[i + 1] = sorted_token_ids.size()[0] + print(expert_offsets) return torch.as_tensor(expert_offsets) @@ -678,9 +680,10 @@ def single_marlin_moe( expert_offsets = get_expert_offsets(sorted_token_ids, topk_ids, E, block_size_m) + print("expert offsets:", expert_offsets); intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w, sorted_token_ids, topk_weights, scales, g_idx, + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, g_idx, rand_perm.int(), expert_offsets, workspace, M, N, K, True, num_tokens_post_padded, E, topk, block_size_m, True, False) @@ -759,30 +762,35 @@ def fused_marlin_moe(hidden_states: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, block_size_m, E) - max_workspace_size = (max(N, K) // 64) * 16 + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False) - expert_offsets = get_expert_offsets(sorted_token_ids, topk_ids, E, - block_size_m) + # expert_offsets = get_expert_offsets(sorted_token_ids, topk_ids, E, + # block_size_m) + expert_offsets = torch.empty((0)) + # print("expert offsets:", expert_offsets, topk_ids.flatten().shape, block_size_m) intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, w1_scale, g_idx1, - rand_perm1, expert_offsets, workspace, M, 2 * N, K, True, - num_tokens_post_padded, E, topk, block_size_m, True, False) + hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, g_idx1, + rand_perm1, expert_offsets, workspace, M, 2 * N, + K, True, E, topk, block_size_m, True, False + ) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, w2_scale, - g_idx2, rand_perm2, expert_offsets, workspace, M, K, N, True, - num_tokens_post_padded, E, topk, block_size_m, False, True) + intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, w2_scale, + g_idx2, rand_perm2, expert_offsets, workspace, M, K, N, True, E, topk, + block_size_m, False, True) + # intermediate_cache3 = torch.zeros((M, topk, K), device=hidden_states.device, + # dtype=hidden_states.dtype) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 7cb6a79156d81..1e765bb0d84c8 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -96,12 +96,12 @@ class MixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, - experimental_fused_moe: bool, + use_fused_moe: bool, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.experimental_fused_moe = experimental_fused_moe + self.use_fused_moe = use_fused_moe self.quant_config = quant_config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -118,7 +118,7 @@ def __init__( raise ValueError( f"Rank {self.rank} has no experts assigned to it.") - if self.experimental_fused_moe: + if self.use_fused_moe: params_dtype = torch.float16 self.experts = FusedMoE(num_experts=self.num_total_experts, top_k=self.top_k, @@ -149,8 +149,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) router_logits, _ = self.gate(hidden_states) - if self.experimental_fused_moe: - return self.experts(hidden_states.half(), router_logits).bfloat16() + if self.use_fused_moe: + ret = self.experts(hidden_states.half(), router_logits) + return ret.bfloat16() else: routing_weights = F.softmax(router_logits, dim=1, @@ -260,7 +261,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, - experimental_fused_moe: bool, + use_fused_moe: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -278,7 +279,7 @@ def __init__( quant_config=quant_config) self.block_sparse_moe = MixtralMoE( config=config, - experimental_fused_moe=experimental_fused_moe, + use_fused_moe=use_fused_moe, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -319,7 +320,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, - experimental_fused_moe: bool, + use_fused_moe: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -333,7 +334,7 @@ def __init__( ) self.layers = nn.ModuleList([ MixtralDecoderLayer(config, - experimental_fused_moe, + use_fused_moe, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) @@ -369,13 +370,18 @@ def __init__( ) -> None: super().__init__() - # TODO have a better way to set this. - # Needs some testing/improving? - self.experimental_fused_moe = True + # print(config) + # print(cache_config) + # print(quant_config) + + # FP8 hasn't been tested. Works only with enforce-eager + self.use_fused_moe = True #(config.torch_dtype != torch.float8_e4m3fn and + #config.torch_dtype != torch.float16) + # print("use fused?", config.torch_dtype) self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, self.experimental_fused_moe, + self.model = MixtralModel(config, self.use_fused_moe, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, @@ -437,7 +443,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - if self.experimental_fused_moe: + if self.use_fused_moe: if ("block_sparse_moe.experts." in name and ".w1." not in name and ".w2." not in name and ".w3." not in name @@ -475,7 +481,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] - if self.experimental_fused_moe and shard_id is not None: + if self.use_fused_moe and shard_id is not None: weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight, name, shard_id,