Fix segmentation fault for models exceeding 40B on AMD GPUs & optimize mul_mat_axpy operation (#217)

* fix segmentation fault when model exceeds 40B on the ROCm platform

* optimize axpy kernel

* optimize op: mulmat_axpy_sparse

* fix bug when model exceeds 40B on AMD GPU

* optimize op: mulmat_axpy_sparse

---------

Co-authored-by: tworan <[email protected]>
Tworan and tworan authored Sep 6, 2024
1 parent 61cac9b commit 6ae7e06
Showing 4 changed files with 99 additions and 43 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -348,6 +348,8 @@ endif()

if (LLAMA_HIPBLAS)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
# enable fast (unsafe) hardware floating-point atomics; the sparse axpy kernels rely on float atomicAdd
add_compile_options(-munsafe-fp-atomics)

if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
2 changes: 1 addition & 1 deletion README.md
@@ -11,7 +11,7 @@ PowerInfer is a CPU/GPU LLM inference engine leveraging **activation locality**
- [2024/6/11] We are thrilled to introduce [PowerInfer-2](https://arxiv.org/abs/2406.06282), our highly optimized inference framework designed specifically for smartphones. With TurboSparse-Mixtral-47B, it achieves an impressive speed of 11.68 tokens per second, which is up to 22 times faster than other state-of-the-art frameworks.
- [2024/6/11] We are thrilled to present [Turbo Sparse](https://arxiv.org/abs/2406.05955), our TurboSparse models for fast inference. With just $0.1M, we sparsified the original Mistral and Mixtral models to nearly 90% sparsity while maintaining superior performance! For a Mixtral-level model, our TurboSparse-Mixtral activates only **4B** parameters!
- [2024/5/20] **Competition Recruitment: CCF-TCArch Customized Computing Challenge 2024**. The CCF TCARCH CCC is a national competition organized by the Technical Committee on Computer Architecture (TCARCH) of the China Computer Federation (CCF). This year's competition aims to optimize the PowerInfer inference engine using the open-source ROCm/HIP. More information about the competition can be found [here](https://ccf-tcarch-ccc.github.io/2024/).
- [2024/5/17] We now provide support for AMD devices with ROCm. (WIP for models exceeding 40B).
- [2024/5/17] We now provide support for AMD devices with ROCm.
- [2024/3/28] We are thrilled to present [Bamboo LLM](https://github.com/SJTU-IPADS/Bamboo) that achieves both top-level performance and unparalleled speed with PowerInfer! Experience it with Bamboo-7B [Base](https://huggingface.co/PowerInfer/Bamboo-base-v0.1-gguf) / [DPO](https://huggingface.co/PowerInfer/Bamboo-DPO-v0.1-gguf).
- [2024/3/14] We supported ProSparse Llama 2 ([7B](https://huggingface.co/SparseLLM/prosparse-llama-2-7b)/[13B](https://huggingface.co/SparseLLM/prosparse-llama-2-13b)), ReLU models with ~90% sparsity, matching original Llama 2's performance (Thanks THUNLP & ModelBest)!
- [2024/1/11] We supported Windows with GPU inference!
118 changes: 76 additions & 42 deletions ggml-cuda.cu
@@ -93,6 +93,11 @@

#define GGML_CUDA_MAX_NODES 8192

// launch geometry for the sparse axpy kernel (dequantize_mul_mat_axpy_sparse_pro)
#define AXPY_BLOCK_Y 1
#define AXPY_BLOCK_X 512
#define AXPY_BLOCK_Z 256

// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
@@ -4470,6 +4475,68 @@ static __global__ void dequantize_mul_mat_axpy(const void * __restrict__ vx, con
}
}

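// Sparse axpy kernel: accumulates dst += sum over rows of (dequantized row of vx) * y[row].
// Each block covers AXPY_BLOCK_X*2 output columns (blockIdx.y) and a slice of
// AXPY_BLOCK_Z input rows (blockIdx.z); rows whose activation is zero are skipped,
// and partial sums from different row slices are combined in dst via atomicAdd.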
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_axpy_sparse_pro(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) {
const int thread_col = blockIdx.y * blockDim.x + threadIdx.x;
const int col = 2 * thread_col;
const int tid = threadIdx.x;
const int wid = threadIdx.y;
const int iqs = (col%qk) / qr; // x quant index
const int row_offset = blockIdx.z * AXPY_BLOCK_Z;

if (col >= ncols) {
return;
}

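// dst_tmp: per-thread partial sums for the two values of each dequantized pair;
// y_tmp:   the activations of this block's row slice, staged through shared memory.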
__shared__ dfloat dst_tmp[AXPY_BLOCK_Y][AXPY_BLOCK_X*2];
__shared__ dfloat y_tmp[AXPY_BLOCK_Z];

dst_tmp[wid][tid] = 0.0;
dst_tmp[wid][tid+AXPY_BLOCK_X] = 0.0;

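// threads with threadIdx.y == 0 stage the activations for this row slice,
// gathering them through the sparsity index list `lst` when it is provided.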
if (wid == 0) {
if (lst) {
for(int i=tid; i<AXPY_BLOCK_Z; i += AXPY_BLOCK_X) {
y_tmp[i] = y[lst[i+row_offset]];
}
} else {
for(int i=tid; i<AXPY_BLOCK_Z; i += AXPY_BLOCK_X) {
// ((dfloat4*)y_tmp)[i] = *(dfloat4*)(&y[i*4+row_offset]);
y_tmp[i] = y[i+row_offset];
}
}
}
__syncthreads();

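// accumulate the contribution of every row in this slice whose activation is non-zero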
#pragma unroll 8
for(int gpu_row = wid; gpu_row < nrows; gpu_row += AXPY_BLOCK_Y) {
if(y_tmp[gpu_row] == 0.0) continue;

const int ib = ((gpu_row + row_offset)*ncols + col) / qk; // x block index

dfloat2 v;
dequantize_kernel(vx, ib, iqs, v);

dst_tmp[wid][tid] += v.x * y_tmp[gpu_row];
dst_tmp[wid][tid+AXPY_BLOCK_X] += v.y * y_tmp[gpu_row];
}

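// tree reduction over the y-dimension of the block (a no-op while AXPY_BLOCK_Y == 1)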
for (int offset = AXPY_BLOCK_Y / 2; offset > 0; offset >>= 1) {
if (wid < offset) {
dst_tmp[wid][tid] += dst_tmp[wid+offset][tid];
dst_tmp[wid][tid+AXPY_BLOCK_X] += dst_tmp[wid+offset][tid+AXPY_BLOCK_X];
}
__syncthreads();
}

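// commit the partial sums; atomicAdd is required because every blockIdx.z slice
// contributes to the same output columns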
if (wid == 0) {
const int iybs = col - col%qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk/2;
atomicAdd(&dst[iybs + iqs], dst_tmp[wid][tid]);
atomicAdd(&dst[iybs + iqs + y_offset], dst_tmp[wid][tid+AXPY_BLOCK_X]);
}
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_axpy_sparse(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) {
// qk = quantized weights per x block
@@ -4614,44 +4681,6 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
}
}

// nrows: 11008(or 32 * x < 11008), ncols: 4096
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_axpy_sparse_lessatom(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) {
int warp_id = threadIdx.y;
int tid = threadIdx.x + blockIdx.x * 32;
int col = tid * 2;
dfloat2 v;
int iqs = (col % qk) / qr;
float tmp[2];
tmp[0] = 0.0;
tmp[1] = 0.0;
__shared__ float res[64];
res[threadIdx.x] = 0.0;
res[threadIdx.x + 32] = 0.0;

#pragma unroll 32
for (int row = warp_id; row < nrows; row += 32) {
int raw_row = lst ? lst[row] : row;
// int raw_row = row;
dfloat y_row = y[raw_row];
if (y_row == 0.0) {
continue;
}
const int ib = (row * ncols + col) / qk;
dequantize_kernel(vx, ib, iqs, v);
tmp[0] += v.x * y_row;
tmp[1] += v.y * y_row;
}
const int adder_loc = threadIdx.x % 16 + threadIdx.x / 16 * 32;
atomicAdd(res + adder_loc, tmp[0]);
atomicAdd(res + adder_loc + 16, tmp[1]);
__syncthreads();
if (warp_id < 1) {
int write_back_loc = warp_id * 32 + threadIdx.x;
dst[write_back_loc + blockIdx.x * 64] = res[write_back_loc];
}
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec_sparse(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int * lst, float * idx) {
// qk = quantized weights per x block
@@ -5636,10 +5665,15 @@ static void dequantize_axpy_vec_q4_0_cuda(const void * vx, const dfloat * y, flo
}
static void dequantize_axpy_sparse_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const dim3 block_dim = dim3(32, 32);
const int block_num = (ncols + 63) / 64;
dequantize_mul_mat_axpy_sparse_lessatom<QK4_0, QR4_0, dequantize_q4_0>
<<<block_num, block_dim, 0, stream>>>(vx, y, dst, ncols, nrows, lst, idx);
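// grid: blockIdx.y tiles the ncols outputs in chunks of AXPY_BLOCK_X*2 values,
// blockIdx.z tiles the input rows in slices of AXPY_BLOCK_Z (assumes nrows is a
// multiple of AXPY_BLOCK_Z); each slice is passed AXPY_BLOCK_Z as its row count.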
const int block_num_y = (ncols + AXPY_BLOCK_X*2 - 1) / AXPY_BLOCK_X / 2;
const int block_num_z = nrows / AXPY_BLOCK_Z;
const dim3 block_nums(1, block_num_y, block_num_z);
const dim3 block_dims(AXPY_BLOCK_X, AXPY_BLOCK_Y, 1);
// dequantize_mul_mat_axpy<QK4_0, QR4_0, dequantize_q4_0>
// <<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows);
// printf("launch kernel: (%d, %d)\n", block_num_x, block_num_y);
dequantize_mul_mat_axpy_sparse_pro<QK4_0, QR4_0, dequantize_q4_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, AXPY_BLOCK_Z, lst, idx);
}

static void dequantize_axpy_sparse_batch_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
20 changes: 20 additions & 0 deletions llama.cpp
@@ -4402,6 +4402,16 @@ static struct ggml_tensor * llm_build_sparse_mul_mat(
std::string full_name = "ffn_" + std::string(name) + "_sparse";
ggml_tensor * out = nullptr;

#ifdef GGML_USE_HIPBLAS
// WARNING: THIS IS A HACK!
// If up_gpu->data is null, inference fails when the model exceeds 40B on a ROCm
// device, so we point up_gpu->data at the tensor itself as a non-null placeholder.

up_gpu->data = up_gpu;

#endif

#ifdef GGML_USE_CUBLAS
// Full offloading fast path
if (full_gpu) {
@@ -4445,6 +4455,16 @@ static struct ggml_tensor * llm_build_sparse_axpy(
std::string full_name = "ffn_" + std::string(name) + "_sparse";
ggml_tensor * out = nullptr;

#ifdef GGML_USE_HIPBLAS
// WARNING: THIS IS A HACK!
// If wt_gpu->data is null, inference fails when the model exceeds 40B on a ROCm
// device, so we point wt_gpu->data at the tensor itself as a non-null placeholder.

wt_gpu->data = wt_gpu;

#endif

#ifdef GGML_USE_CUBLAS
// Full offloading fast path
if (full_gpu) {
