diff --git a/CMakeLists.txt b/CMakeLists.txt
index a2ec6c0..87e9c81 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -348,6 +348,8 @@ endif()
 
 if (LLAMA_HIPBLAS)
     list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+    # enable fast (unsafe) floating-point atomic operations
+    add_compile_options(-munsafe-fp-atomics)
 
     if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
         message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
diff --git a/README.md b/README.md
index edd551f..5cb4c89 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ PowerInfer is a CPU/GPU LLM inference engine leveraging **activation locality**
 - [2024/6/11] We are thrilled to introduce [PowerInfer-2](https://arxiv.org/abs/2406.06282), our highly optimized inference framework designed specifically for smartphones. With TurboSparse-Mixtral-47B, it achieves an impressive speed of 11.68 tokens per second, which is up to 22 times faster than other state-of-the-art frameworks.
 - [2024/6/11] We are thrilled to present [Turbo Sparse](https://arxiv.org/abs/2406.05955), our TurboSparse models for fast inference. With just $0.1M, we sparsified the original Mistral and Mixtral model to nearly 90% sparsity while maintaining superior performance! For a Mixtral-level model, our TurboSparse-Mixtral activates only **4B** parameters!
 - [2024/5/20] **Competition Recruitment: CCF-TCArch Customized Computing Challenge 2024**. The CCF TCARCH CCC is a national competition organized by the Technical Committee on Computer Architecture (TCARCH) of the China Computer Federation (CCF). This year's competition aims to optimize the PowerInfer inference engine using the open-source ROCm/HIP. More information about the competition can be found [here](https://ccf-tcarch-ccc.github.io/2024/).
-- [2024/5/17] We now provide support for AMD devices with ROCm. (WIP for models exceeding 40B).
+- [2024/5/17] We now provide support for AMD devices with ROCm.
 - [2024/3/28] We are thrilled to present [Bamboo LLM](https://github.com/SJTU-IPADS/Bamboo) that achieves both top-level performance and unparalleled speed with PowerInfer! Experience it with Bamboo-7B [Base](https://huggingface.co/PowerInfer/Bamboo-base-v0.1-gguf) / [DPO](https://huggingface.co/PowerInfer/Bamboo-DPO-v0.1-gguf).
 - [2024/3/14] We supported ProSparse Llama 2 ([7B](https://huggingface.co/SparseLLM/prosparse-llama-2-7b)/[13B](https://huggingface.co/SparseLLM/prosparse-llama-2-13b)), ReLU models with ~90% sparsity, matching original Llama 2's performance (Thanks THUNLP & ModelBest)!
 - [2024/1/11] We supported Windows with GPU inference!
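The -munsafe-fp-atomics flag added above tells the HIP/Clang compiler that it may lower floating-point atomicAdd to native hardware FP atomics on AMD GPUs instead of a compare-and-swap retry loop; the sparse AXPY kernel introduced below relies on exactly this kind of contended float atomicAdd for its final write-back. A minimal sketch of the kind of code the flag affects (hypothetical example, not part of this patch; assumes a ROCm install with hipcc on PATH):

// fp_atomics_sketch.cpp -- build with: hipcc -O3 -munsafe-fp-atomics -c fp_atomics_sketch.cpp
#include <hip/hip_runtime.h>

// Many threads add into a small output vector. With -munsafe-fp-atomics the
// compiler may emit a hardware FP atomic for atomicAdd(float*, float); without
// the flag it typically falls back to a compare-and-swap loop on AMD targets.
__global__ void scatter_add(const float * __restrict__ src, float * __restrict__ dst, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        atomicAdd(&dst[i % 64], src[i]);   // contended float atomics
    }
}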
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index ca932ca..a5f2c4b 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -93,6 +93,11 @@
 
 #define GGML_CUDA_MAX_NODES 8192
 
+// block/grid geometry for the sparse AXPY kernel
+#define AXPY_BLOCK_Y 1
+#define AXPY_BLOCK_X 512
+#define AXPY_BLOCK_Z 256
+
 // define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
 // on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
 // for large computational tasks. the drawback is that this requires some extra amount of VRAM:
@@ -4470,6 +4475,68 @@ static __global__ void dequantize_mul_mat_axpy(const void * __restrict__ vx, con
     }
 }
 
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_axpy_sparse_pro(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) {
+    const int thread_col = blockIdx.y * blockDim.x + threadIdx.x;
+    const int col = 2 * thread_col;
+    const int tid = threadIdx.x;
+    const int wid = threadIdx.y;
+    const int iqs = (col%qk) / qr; // x quant index
+    const int row_offset = blockIdx.z * AXPY_BLOCK_Z;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    __shared__ dfloat dst_tmp[AXPY_BLOCK_Y][AXPY_BLOCK_X*2];
+    __shared__ dfloat y_tmp[AXPY_BLOCK_Z];
+
+    dst_tmp[wid][tid] = 0.0;
+    dst_tmp[wid][tid+AXPY_BLOCK_X] = 0.0;
+
+    // stage this block's slice of activations into shared memory,
+    // translating sparse row indices through lst when it is provided
+    if (wid == 0) {
+        if (lst) {
+            for (int i = tid; i < AXPY_BLOCK_Z; i += blockDim.x) {
+                y_tmp[i] = y[lst[row_offset + i]];
+            }
+        } else {
+            for (int i = tid; i < AXPY_BLOCK_Z; i += blockDim.x) {
+                y_tmp[i] = y[row_offset + i];
+            }
+        }
+    }
+    __syncthreads();
+
+    // accumulate partial AXPY results for this thread's two output columns
+    dfloat2 v;
+    for (int row = 0; row < nrows; row++) {
+        dfloat y_row = y_tmp[row];
+        if (y_row == 0.0) {
+            continue;
+        }
+        const int ib = ((row_offset + row) * ncols + col) / qk;
+        dequantize_kernel(vx, ib, iqs, v);
+        dst_tmp[wid][tid] += v.x * y_row;
+        dst_tmp[wid][tid+AXPY_BLOCK_X] += v.y * y_row;
+    }
+    __syncthreads();
+
+    for (int offset = AXPY_BLOCK_Y/2; offset > 0; offset >>= 1) {
+        if (wid < offset) {
+            dst_tmp[wid][tid] += dst_tmp[wid+offset][tid];
+            dst_tmp[wid][tid+AXPY_BLOCK_X] += dst_tmp[wid+offset][tid+AXPY_BLOCK_X];
+        }
+        __syncthreads();
+    }
+
+    if (wid == 0) {
+        const int iybs = col - col%qk; // y block start index
+        const int y_offset = qr == 1 ? 1 : qk/2;
+        atomicAdd(&dst[iybs + iqs], dst_tmp[wid][tid]);
+        atomicAdd(&dst[iybs + iqs + y_offset], dst_tmp[wid][tid+AXPY_BLOCK_X]);
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_axpy_sparse(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) {
     // qk = quantized weights per x block
@@ -4614,44 +4681,6 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
     }
 }
 
-// nrows: 11008(or 32 * x < 11008), ncols: 4096
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_axpy_sparse_lessatom(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) {
-    int warp_id = threadIdx.y;
-    int tid = threadIdx.x + blockIdx.x * 32;
-    int col = tid * 2;
-    dfloat2 v;
-    int iqs = (col % qk) / qr;
-    float tmp[2];
-    tmp[0] = 0.0;
-    tmp[1] = 0.0;
-    __shared__ float res[64];
-    res[threadIdx.x] = 0.0;
-    res[threadIdx.x + 32] = 0.0;
-
-#pragma unroll 32
-    for (int row = warp_id; row < nrows; row += 32) {
-        int raw_row = lst ? lst[row] : row;
-        // int raw_row = row;
-        dfloat y_row = y[raw_row];
-        if (y_row == 0.0) {
-            continue;
-        }
-        const int ib = (row * ncols + col) / qk;
-        dequantize_kernel(vx, ib, iqs, v);
-        tmp[0] += v.x * y_row;
-        tmp[1] += v.y * y_row;
-    }
-    const int adder_loc = threadIdx.x % 16 + threadIdx.x / 16 * 32;
-    atomicAdd(res + adder_loc, tmp[0]);
-    atomicAdd(res + adder_loc + 16, tmp[1]);
-    __syncthreads();
-    if (warp_id < 1) {
-        int write_back_loc = warp_id * 32 + threadIdx.x;
-        dst[write_back_loc + blockIdx.x * 64] = res[write_back_loc];
-    }
-}
-
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec_sparse(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int * lst, float * idx) {
     // qk = quantized weights per x block
@@ -5636,10 +5665,15 @@ static void dequantize_axpy_vec_q4_0_cuda(const void * vx, const dfloat * y, flo
 }
 static void dequantize_axpy_sparse_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    const dim3 block_dim = dim3(32, 32);
-    const int block_num = (ncols + 63) / 64;
-    dequantize_mul_mat_axpy_sparse_lessatom<QK4_0, QR4_0, dequantize_q4_0>
-        <<<block_num, block_dim, 0, stream>>>(vx, y, dst, ncols, nrows, lst, idx);
+    const int block_num_y = (ncols + AXPY_BLOCK_X*2 - 1) / AXPY_BLOCK_X / 2;
+    const int block_num_z = nrows / AXPY_BLOCK_Z;
+    const dim3 block_nums(1, block_num_y, block_num_z);
+    const dim3 block_dims(AXPY_BLOCK_X, AXPY_BLOCK_Y, 1);
+    // dequantize_mul_mat_axpy<QK4_0, QR4_0, dequantize_q4_0>
+    //     <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+    // printf("launch kernel: (%d, %d)\n", block_num_x, block_num_y);
+    dequantize_mul_mat_axpy_sparse_pro<QK4_0, QR4_0, dequantize_q4_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, AXPY_BLOCK_Z, lst, idx);
 }
 
 static void dequantize_axpy_sparse_batch_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
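For a sense of scale, the removed lessatom kernel's comment gives the FFN shape it targeted (ncols: 4096, nrows: 11008, i.e. a LLaMA-7B-class down-projection). With the new launch code above, each thread covers two output columns, each block covers AXPY_BLOCK_X*2 = 1024 columns and a slice of AXPY_BLOCK_Z = 256 rows, and the per-slice partial sums are combined by the float atomicAdds at the end of dequantize_mul_mat_axpy_sparse_pro. A small standalone check of that arithmetic (a sketch under those assumptions, not part of the patch):

// launch_geometry_check.cpp -- mirrors the grid math in dequantize_axpy_sparse_vec_q4_0_cuda
#include <cstdio>

#define AXPY_BLOCK_Y 1
#define AXPY_BLOCK_X 512
#define AXPY_BLOCK_Z 256

int main() {
    const int ncols = 4096;   // FFN hidden dim (LLaMA-7B-class model)
    const int nrows = 11008;  // FFN intermediate dim

    const int block_num_y = (ncols + AXPY_BLOCK_X*2 - 1) / AXPY_BLOCK_X / 2; // = 4
    const int block_num_z = nrows / AXPY_BLOCK_Z;          // = 43; assumes nrows % 256 == 0

    printf("grid (1, %d, %d), block (%d, %d, 1), %d threads per block\n",
           block_num_y, block_num_z, AXPY_BLOCK_X, AXPY_BLOCK_Y,
           AXPY_BLOCK_X * AXPY_BLOCK_Y);
    // shared memory per block, assuming dfloat == float (GGML_CUDA_F16 not set):
    // dst_tmp = 1 * 1024 * 4 B = 4 KiB, y_tmp = 256 * 4 B = 1 KiB
    return 0;
}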
diff --git a/llama.cpp b/llama.cpp
index 121062f..3ae9e94 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4402,6 +4402,16 @@ static struct ggml_tensor * llm_build_sparse_mul_mat(
     std::string full_name = "ffn_" + std::string(name) + "_sparse";
     ggml_tensor * out = nullptr;
 
+#ifdef GGML_USE_HIPBLAS
+    // WARNING: THIS IS A HACK!
+    // If up_gpu->data is null, inference fails on ROCm devices
+    // once the model exceeds 40B parameters, so we let
+    // up_gpu->data point to the tensor itself as a non-null placeholder.
+
+    up_gpu->data = up_gpu;
+
+#endif
+
 #ifdef GGML_USE_CUBLAS
     // Full offloading fast path
     if (full_gpu) {
@@ -4445,6 +4455,16 @@ static struct ggml_tensor * llm_build_sparse_axpy(
     std::string full_name = "ffn_" + std::string(name) + "_sparse";
     ggml_tensor * out = nullptr;
 
+#ifdef GGML_USE_HIPBLAS
+    // WARNING: THIS IS A HACK!
+    // If wt_gpu->data is null, inference fails on ROCm devices
+    // once the model exceeds 40B parameters, so we let
+    // wt_gpu->data point to the tensor itself as a non-null placeholder.
+
+    wt_gpu->data = wt_gpu;
+
+#endif
+
 #ifdef GGML_USE_CUBLAS
     // Full offloading fast path
     if (full_gpu) {
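The two llama.cpp hunks apply the pointer workaround unconditionally whenever GGML_USE_HIPBLAS is defined. Since the accompanying comment attributes the failure to up_gpu->data / wt_gpu->data being null, a narrower variant would only patch the pointer when it actually is null, leaving tensors that already have a backing buffer untouched. A sketch of that variant (a suggestion under the same assumption stated in the comment, not what this patch does):

#ifdef GGML_USE_HIPBLAS
    // Only substitute the placeholder when the pointer is actually null;
    // the value is assumed to act purely as a non-null sentinel on this
    // path and is never read back as weight data.
    if (up_gpu->data == nullptr) {
        up_gpu->data = up_gpu;
    }
#endif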