From 9967a68bb0ae8fbab4adcac3f11fab29337e34ac Mon Sep 17 00:00:00 2001 From: turboderp Date: Fri, 9 Jun 2023 22:50:00 +0200 Subject: [PATCH 01/15] Updated todo --- TODO.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/TODO.md b/TODO.md index ea251e6e..23459a74 100644 --- a/TODO.md +++ b/TODO.md @@ -9,15 +9,16 @@ ## GPU compatibility (etc.) -- [ ] Support for ROCm/AMD GPUs +- [x] Support for ROCm/AMD GPUs +- [ ] Optimize more for ROCm - [ ] Test that CUDA code works on GTX 10-series and RTX 20-series at some point - [x] Test performance on P40 (would be a good GPU to support) - [ ] Improve performance on P40 - [x] Tunable kernel parameters - [ ] More tunable kernel parameters - [x] Test on Windows -- [ ] Easier extension loading on Windows -- [ ] Setup instructions for Windows +- [x] Easier extension loading on Windows +- [x] Setup instructions for Windows ## Testing @@ -34,7 +35,7 @@ - [x] Support for de-quantizing select matrices at load time - [x] ~~Better vector-matrix multiplication for de-quantized matrices~~ (dequant was a dead end) -- [ ] Fused QKV projection +- [x] Fused QKV projection - [x] Fused MLP - [x] Fused RoPE - [x] ~~Build attention mask in CUDA rather than PyTorch~~ From a3c352fe21834eb517a1976e9b280117a39d58c2 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sat, 10 Jun 2023 16:30:38 +0200 Subject: [PATCH 02/15] Clean up kernel switching --- exllama_ext/cuda_func/column_remap.cu | 1 - exllama_ext/cuda_func/q4_matmul.cu | 122 ++++++++++++-------------- exllama_ext/cuda_func/q4_mlp.cu | 40 ++++----- exllama_ext/cuda_func/rms_norm.cu | 55 +++++++++--- exllama_ext/cuda_func/rope.cu | 31 +++++-- 5 files changed, 143 insertions(+), 106 deletions(-) diff --git a/exllama_ext/cuda_func/column_remap.cu b/exllama_ext/cuda_func/column_remap.cu index 4beb5b4a..65514c59 100644 --- a/exllama_ext/cuda_func/column_remap.cu +++ b/exllama_ext/cuda_func/column_remap.cu @@ -33,7 +33,6 @@ __global__ void column_remap_kernel } } - // Remap columns in x to correspond to sequential group index before matmul // // perform x -> seq_x such that seq_x @ seq_w == x @ w diff --git a/exllama_ext/cuda_func/q4_matmul.cu b/exllama_ext/cuda_func/q4_matmul.cu index 6310f486..7d793928 100644 --- a/exllama_ext/cuda_func/q4_matmul.cu +++ b/exllama_ext/cuda_func/q4_matmul.cu @@ -11,12 +11,28 @@ const int THREADS_X = 32; // Block size and thread count along columns in w and out const int THREADS_Y = 1; // Block size and thread count along rows in x and out +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool no_zero +); + template __global__ void q4_matmul_kernel ( const half* __restrict__ x, const uint32_t* __restrict__ w, - half* __restrict__ out, // (y) + half* __restrict__ out, const half* __restrict__ w_scales, const uint32_t* __restrict__ w_zeros, const int height, @@ -48,12 +64,6 @@ __global__ void q4_matmul_kernel // Zero output -// if (!no_zero && blockIdx.z == 0) -// { -// out_.set(x_row, w_column, {}); -// __syncthreads(); -// } - if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) { *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; @@ -70,33 +80,23 @@ __global__ void q4_matmul_kernel // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this // could be slightly faster - for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; ) + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) { if constexpr (use_half2) { half2 w_scale = w_scales_.item_half2half2(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - if constexpr (use_x_map) - acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else - acc = dot_product_8(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - - group++; - k += groupsize; + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); } else { half w_scale = w_scales_.item(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - if constexpr (use_x_map) - acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else - acc_h = dot_product_8_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - - group++; - k += groupsize; + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); } } } @@ -104,36 +104,25 @@ __global__ void q4_matmul_kernel { // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache - if constexpr (use_half2) + for (int k = x_column; k < x_column + iterations * 8; k += 8) { - for (int k = x_column; k < x_column + iterations * 8; ) + if constexpr (use_half2) { int group = k / groupsize; half2 w_scale = w_scales_.item_half2half2(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - if constexpr (use_x_map) - acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else - acc = dot_product_8(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - - k += 8; + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); } - } - else - { - for (int k = x_column; k < x_column + iterations * 8; ) + else { int group = k / groupsize; half w_scale = w_scales_.item(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - if constexpr (use_x_map) - acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else - acc_h = dot_product_8_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - - k += 8; + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); } } } @@ -151,6 +140,29 @@ __global__ void q4_matmul_kernel } } +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + // Compute y = x @ w void q4_matmul_cuda @@ -199,32 +211,8 @@ void q4_matmul_cuda (dim + block_size_z - 1) / block_size_z ); - if (tuningParams->matmul_no_half2) - { - if (block_size_z % w->groupsize == 0) - { - if (x_map) q4_matmul_kernel <<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); - else q4_matmul_kernel <<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, NULL, no_zero); - } - else - { - if (x_map) q4_matmul_kernel <<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); - else q4_matmul_kernel <<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, NULL, no_zero); - } - } - else - { - if (block_size_z % w->groupsize == 0) - { - if (x_map) q4_matmul_kernel <<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); - else q4_matmul_kernel <<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, NULL, no_zero); - } - else - { - if (x_map) q4_matmul_kernel <<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); - else q4_matmul_kernel <<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, NULL, no_zero); - } - } + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); } void q4_matmul_recons_cuda @@ -257,4 +245,4 @@ void q4_matmul_recons_cuda const half beta = __float2half(0.0f); cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); -} \ No newline at end of file +} diff --git a/exllama_ext/cuda_func/q4_mlp.cu b/exllama_ext/cuda_func/q4_mlp.cu index 036d8b0c..77fa09a4 100644 --- a/exllama_ext/cuda_func/q4_mlp.cu +++ b/exllama_ext/cuda_func/q4_mlp.cu @@ -34,6 +34,14 @@ __device__ __forceinline__ half2 silu(half2 x) return result; } +typedef void (*fp_silu_mul_cuda_kernel) +( + half*, + const half*, + const int, + const int +); + template __global__ void silu_mul_cuda_kernel ( @@ -78,6 +86,16 @@ __global__ void silu_mul_cuda_kernel } } +fp_silu_mul_cuda_kernel silu_mul_cuda_kernel_pick(ExLlamaTuning* tuningParams) +{ + // + if (tuningParams->matmul_no_half2) { + return silu_mul_cuda_kernel; + } else { + return silu_mul_cuda_kernel; + } +}; + void q4_mlp_cuda ( ExLlamaTuning* tuningParams, @@ -116,26 +134,8 @@ void q4_mlp_cuda 1 ); - if (tuningParams->silu_no_half2) - { - silu_mul_cuda_kernel<<>> - ( - buffers->temp_mlp, - buffers->temp_mlp + height * up->width, - height, - up->width - ); - } - else - { - silu_mul_cuda_kernel<<>> - ( - buffers->temp_mlp, - buffers->temp_mlp + height * up->width, - height, - up->width - ); - } + fp_silu_mul_cuda_kernel kernel = silu_mul_cuda_kernel_pick(tuningParams); + kernel<<>>(buffers->temp_mlp, buffers->temp_mlp + height * up->width, height, up->width); // x += temp1 @ down (implicitly add the residual connection by not zeroing the output in the matmul) diff --git a/exllama_ext/cuda_func/rms_norm.cu b/exllama_ext/cuda_func/rms_norm.cu index 388809bb..d31deba0 100644 --- a/exllama_ext/cuda_func/rms_norm.cu +++ b/exllama_ext/cuda_func/rms_norm.cu @@ -9,6 +9,14 @@ const int BLOCKSIZE_X = 16; // scratch = sum(x * x, dim = -1) +typedef void (*fp_rms_norm_row_product_kernel) +( + half*, + float*, + const int, + const int +); + template __global__ void rms_norm_row_product_kernel ( @@ -72,6 +80,18 @@ __global__ void rms_norm_row_product_kernel // x = x * w / sqrt(scratch / dim + epsilon) +typedef void (*fp_rms_norm_kernel) +( + half*, + const half*, + half*, + float*, + const float, + const float, + const int, + const int +); + template __global__ void rms_norm_kernel ( @@ -131,6 +151,26 @@ __global__ void rms_norm_kernel // if (column >= dim - BLOCKSIZE_X) scratch[row] = 0.0f; } +fp_rms_norm_row_product_kernel rms_norm_row_product_kernel_pick(ExLlamaTuning* tuningParams) +{ + // + if (tuningParams->matmul_no_half2) { + return rms_norm_row_product_kernel; + } else { + return rms_norm_row_product_kernel; + } +}; + +fp_rms_norm_kernel rms_norm_kernel_pick(ExLlamaTuning* tuningParams) +{ + // + if (tuningParams->matmul_no_half2) { + return rms_norm_kernel; + } else { + return rms_norm_kernel; + } +}; + // x = x * w / sqrt(row_mean(x * x) + epsilon) // // works in-place if x == out @@ -163,16 +203,11 @@ void rms_norm_cuda //cudaMemsetAsync(temp, 0, rows * sizeof(float)); - if (tuningParams->rmsnorm_no_half2) - { - rms_norm_row_product_kernel<<>>(x, temp, rows, dim); - rms_norm_kernel<<>>(x, w, out, temp, epsilon, r_dim, rows, dim); - } - else - { - rms_norm_row_product_kernel<<>>(x, temp, rows, dim); - rms_norm_kernel<<>>(x, w, out, temp, epsilon, r_dim, rows, dim); - } + fp_rms_norm_row_product_kernel kernel1 = rms_norm_row_product_kernel_pick(tuningParams); + kernel1<<>>(x, temp, rows, dim); + + fp_rms_norm_kernel kernel2 = rms_norm_kernel_pick(tuningParams); + kernel2<<>>(x, w, out, temp, epsilon, r_dim, rows, dim); //cudaMemsetAsync(temp, 0, rows * sizeof(float)); } diff --git a/exllama_ext/cuda_func/rope.cu b/exllama_ext/cuda_func/rope.cu index b0bc2d94..d8fdeaec 100644 --- a/exllama_ext/cuda_func/rope.cu +++ b/exllama_ext/cuda_func/rope.cu @@ -6,6 +6,17 @@ const int THREADS_X = 32; const int THREADS_Y = 4; const int MAX_POS_EMBEDDINGS = 32768; // Actual number doesn't matter +typedef void (*fp_rope_cuda_kernel) +( + half*, + const half*, + const half*, + int, + int, + int, + int +); + template __global__ void rope_cuda_kernel ( @@ -73,6 +84,16 @@ __global__ void rope_cuda_kernel } } +fp_rope_cuda_kernel rope_cuda_kernel_pick(ExLlamaTuning* tuningParams) +{ + // + if (tuningParams->matmul_no_half2) { + return rope_cuda_kernel; + } else { + return rope_cuda_kernel; + } +}; + void rope_cuda ( ExLlamaTuning* tuningParams, @@ -94,12 +115,6 @@ void rope_cuda 1 ); - if (tuningParams->rope_no_half2) - { - rope_cuda_kernel<<>>(x, sin, cos, rows, head_dim, num_heads, past_len); - } - else - { - rope_cuda_kernel<<>>(x, sin, cos, rows, head_dim, num_heads, past_len); - } + fp_rope_cuda_kernel kernel = rope_cuda_kernel_pick(tuningParams); + kernel<<>>(x, sin, cos, rows, head_dim, num_heads, past_len); } From 94a29080bcd58c5c8732d5150eb3a532869fe4c0 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sat, 10 Jun 2023 16:32:01 +0200 Subject: [PATCH 03/15] Remove some unused code --- exllama_ext/matrix.cuh | 178 ----------------------------------------- 1 file changed, 178 deletions(-) diff --git a/exllama_ext/matrix.cuh b/exllama_ext/matrix.cuh index f62ec1e3..2ee2f729 100644 --- a/exllama_ext/matrix.cuh +++ b/exllama_ext/matrix.cuh @@ -308,182 +308,4 @@ __device__ inline half dot_product_8_x_map_h return result; } -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v1 and v2, constant zero/scale - -__device__ inline half2 dot_product_8_dual -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v1_, - MatrixView_q4_column& v2_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v1_scale_2, - const uint32_t v1_zero, // + 1 (!!) - const half2 v2_scale_2, - const uint32_t v2_zero, // + 1 (!!) - const int count -) -{ - const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); - const uint32_t* v1_ptr = (const uint32_t*) v1_.item_uint32_ptr(v_row, v_column); - const uint32_t* v2_ptr = (const uint32_t*) v2_.item_uint32_ptr(v_row, v_column); - half2 result1 = {}; - half2 result2 = {}; - - for (int i = 0; i < count; i++) - { - uint32_t v1_read = *v1_ptr; v1_ptr += v1_.width; - uint32_t v2_read = *v2_ptr; v2_ptr += v2_.width; - - half v1_0 = __int2half_rn((int)((v1_read ) & 0x0f) - v1_zero); - half v1_1 = __int2half_rn((int)((v1_read >> 4) & 0x0f) - v1_zero); - half v1_2 = __int2half_rn((int)((v1_read >> 8) & 0x0f) - v1_zero); - half v1_3 = __int2half_rn((int)((v1_read >> 12) & 0x0f) - v1_zero); - half v1_4 = __int2half_rn((int)((v1_read >> 16) & 0x0f) - v1_zero); - half v1_5 = __int2half_rn((int)((v1_read >> 20) & 0x0f) - v1_zero); - half v1_6 = __int2half_rn((int)((v1_read >> 24) & 0x0f) - v1_zero); - half v1_7 = __int2half_rn((int)((v1_read >> 28) ) - v1_zero); - - half v2_0 = __int2half_rn((int)((v2_read ) & 0x0f) - v2_zero); - half v2_1 = __int2half_rn((int)((v2_read >> 4) & 0x0f) - v2_zero); - half v2_2 = __int2half_rn((int)((v2_read >> 8) & 0x0f) - v2_zero); - half v2_3 = __int2half_rn((int)((v2_read >> 12) & 0x0f) - v2_zero); - half v2_4 = __int2half_rn((int)((v2_read >> 16) & 0x0f) - v2_zero); - half v2_5 = __int2half_rn((int)((v2_read >> 20) & 0x0f) - v2_zero); - half v2_6 = __int2half_rn((int)((v2_read >> 24) & 0x0f) - v2_zero); - half v2_7 = __int2half_rn((int)((v2_read >> 28) ) - v2_zero); - - half2 v1_01 = __halves2half2(v1_0, v1_1); - half2 v1_23 = __halves2half2(v1_2, v1_3); - half2 v1_45 = __halves2half2(v1_4, v1_5); - half2 v1_67 = __halves2half2(v1_6, v1_7); - - half2 v2_01 = __halves2half2(v2_0, v2_1); - half2 v2_23 = __halves2half2(v2_2, v2_3); - half2 v2_45 = __halves2half2(v2_4, v2_5); - half2 v2_67 = __halves2half2(v2_6, v2_7); - - v1_01 = __hmul2(v1_01, v1_scale_2); - v1_23 = __hmul2(v1_23, v1_scale_2); - v1_45 = __hmul2(v1_45, v1_scale_2); - v1_67 = __hmul2(v1_67, v1_scale_2); - - v2_01 = __hmul2(v2_01, v2_scale_2); - v2_23 = __hmul2(v2_23, v2_scale_2); - v2_45 = __hmul2(v2_45, v2_scale_2); - v2_67 = __hmul2(v2_67, v2_scale_2); - - half2 h_01 = *h_ptr++; - half2 h_23 = *h_ptr++; - half2 h_45 = *h_ptr++; - half2 h_67 = *h_ptr++; - - result1 = __hfma2(h_01, v1_01, result1); - result1 = __hfma2(h_23, v1_23, result1); - result1 = __hfma2(h_45, v1_45, result1); - result1 = __hfma2(h_67, v1_67, result1); - - result2 = __hfma2(h_01, v2_01, result2); - result2 = __hfma2(h_23, v2_23, result2); - result2 = __hfma2(h_45, v2_45, result2); - result2 = __hfma2(h_67, v2_67, result2); - } - - half result1_ = __hadd(result1.x, result1.y); - half result2_ = __hadd(result2.x, result2.y); - - return __hadd2(acc, __halves2half2(result1_, result2_)); -} - -__device__ inline half2 dot_product_8_dual_buffered -( - const half2 acc, - const half* x_row_buffer, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v1_, - MatrixView_q4_column& v2_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v1_scale_2, - const uint32_t v1_zero, // + 1 (!!) - const half2 v2_scale_2, - const uint32_t v2_zero, // + 1 (!!) - const int count -) -{ - const half2* h_ptr = (const half2*) &x_row_buffer[h_column]; - const uint32_t* v1_ptr = (const uint32_t*) v1_.item_uint32_ptr(v_row, v_column); - const uint32_t* v2_ptr = (const uint32_t*) v2_.item_uint32_ptr(v_row, v_column); - half2 result1 = {}; - half2 result2 = {}; - - for (int i = 0; i < count; i++) - { - uint32_t v1_read = *v1_ptr; v1_ptr += v1_.width; - uint32_t v2_read = *v2_ptr; v2_ptr += v2_.width; - - half v1_0 = __int2half_rn((int)((v1_read ) & 0x0f) - v1_zero); - half v1_1 = __int2half_rn((int)((v1_read >> 4) & 0x0f) - v1_zero); - half v1_2 = __int2half_rn((int)((v1_read >> 8) & 0x0f) - v1_zero); - half v1_3 = __int2half_rn((int)((v1_read >> 12) & 0x0f) - v1_zero); - half v1_4 = __int2half_rn((int)((v1_read >> 16) & 0x0f) - v1_zero); - half v1_5 = __int2half_rn((int)((v1_read >> 20) & 0x0f) - v1_zero); - half v1_6 = __int2half_rn((int)((v1_read >> 24) & 0x0f) - v1_zero); - half v1_7 = __int2half_rn((int)((v1_read >> 28) ) - v1_zero); - - half v2_0 = __int2half_rn((int)((v2_read ) & 0x0f) - v2_zero); - half v2_1 = __int2half_rn((int)((v2_read >> 4) & 0x0f) - v2_zero); - half v2_2 = __int2half_rn((int)((v2_read >> 8) & 0x0f) - v2_zero); - half v2_3 = __int2half_rn((int)((v2_read >> 12) & 0x0f) - v2_zero); - half v2_4 = __int2half_rn((int)((v2_read >> 16) & 0x0f) - v2_zero); - half v2_5 = __int2half_rn((int)((v2_read >> 20) & 0x0f) - v2_zero); - half v2_6 = __int2half_rn((int)((v2_read >> 24) & 0x0f) - v2_zero); - half v2_7 = __int2half_rn((int)((v2_read >> 28) ) - v2_zero); - - half2 v1_01 = __halves2half2(v1_0, v1_1); - half2 v1_23 = __halves2half2(v1_2, v1_3); - half2 v1_45 = __halves2half2(v1_4, v1_5); - half2 v1_67 = __halves2half2(v1_6, v1_7); - - half2 v2_01 = __halves2half2(v2_0, v2_1); - half2 v2_23 = __halves2half2(v2_2, v2_3); - half2 v2_45 = __halves2half2(v2_4, v2_5); - half2 v2_67 = __halves2half2(v2_6, v2_7); - - v1_01 = __hmul2(v1_01, v1_scale_2); - v1_23 = __hmul2(v1_23, v1_scale_2); - v1_45 = __hmul2(v1_45, v1_scale_2); - v1_67 = __hmul2(v1_67, v1_scale_2); - - v2_01 = __hmul2(v2_01, v2_scale_2); - v2_23 = __hmul2(v2_23, v2_scale_2); - v2_45 = __hmul2(v2_45, v2_scale_2); - v2_67 = __hmul2(v2_67, v2_scale_2); - - half2 h_01 = *h_ptr++; - half2 h_23 = *h_ptr++; - half2 h_45 = *h_ptr++; - half2 h_67 = *h_ptr++; - - result1 = __hfma2(h_01, v1_01, result1); - result1 = __hfma2(h_23, v1_23, result1); - result1 = __hfma2(h_45, v1_45, result1); - result1 = __hfma2(h_67, v1_67, result1); - - result2 = __hfma2(h_01, v2_01, result2); - result2 = __hfma2(h_23, v2_23, result2); - result2 = __hfma2(h_45, v2_45, result2); - result2 = __hfma2(h_67, v2_67, result2); - } - - half result1_ = __hadd(result1.x, result1.y); - half result2_ = __hadd(result2.x, result2.y); - - return __hadd2(acc, __halves2half2(result1_, result2_)); -} - #endif \ No newline at end of file From c73d921950f5ffb93316fb665df3fa64f2783a65 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sat, 10 Jun 2023 20:02:29 +0200 Subject: [PATCH 04/15] Small matmul optimization --- exllama_ext/cuda_buffers.cu | 23 +++++++ exllama_ext/cuda_buffers.cuh | 4 ++ exllama_ext/matrix.cuh | 127 +++++++++++++++-------------------- 3 files changed, 82 insertions(+), 72 deletions(-) diff --git a/exllama_ext/cuda_buffers.cu b/exllama_ext/cuda_buffers.cu index 844209ec..a6e4ea9d 100644 --- a/exllama_ext/cuda_buffers.cu +++ b/exllama_ext/cuda_buffers.cu @@ -1,6 +1,10 @@ +#define _cuda_buffers_cu #include "cuda_buffers.cuh" CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; CudaBuffers::CudaBuffers ( @@ -64,4 +68,23 @@ void prepare_buffers_cuda ); g_buffers[_device] = buffers; + +// if (!q4_table_init) +// { +// for (uint v_zero = 0; v_zero < 16; v_zero++) +// { +// for (uint v_read = 0; v_read < 256; v_read++) +// { +// half v_0 = __float2half((float)((int)((v_read ) & 0x0f) - v_zero - 1)); +// half v_1 = __float2half((float)((int)((v_read >> 4) & 0x0f) - v_zero - 1)); +// half2 v_01 = {v_0, v_1}; +// q4_table_host[v_zero][v_read] = v_01; +// } +// } +// q4_table_init = true; +// } +// +// cudaSetDevice(_device); +// cudaMemcpyToSymbol(q4_table, q4_table_host, 16 * 256 * sizeof(half2)); +// cudaDeviceSynchronize(); } diff --git a/exllama_ext/cuda_buffers.cuh b/exllama_ext/cuda_buffers.cuh index 92dd8135..ce948a12 100644 --- a/exllama_ext/cuda_buffers.cuh +++ b/exllama_ext/cuda_buffers.cuh @@ -8,6 +8,10 @@ const int CUDA_MAX_DEVICES = 16; +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + class CudaBuffers { public: diff --git a/exllama_ext/matrix.cuh b/exllama_ext/matrix.cuh index 2ee2f729..1b9ad165 100644 --- a/exllama_ext/matrix.cuh +++ b/exllama_ext/matrix.cuh @@ -4,6 +4,8 @@ #include #include +//#include "cuda_buffers.cuh" + class MatrixView_half { public: @@ -11,14 +13,14 @@ public: const int height; const int width; - __device__ inline MatrixView_half(const half* data, const int height, const int width) + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) : data(data), height(height), width(width) { } - __device__ inline half item(int row, int column) const { return data[row * width + column]; } - __device__ inline half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ inline half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ inline const half* item_ptr(int row, int column) const { return &data[row * width + column]; } + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } }; class MatrixView_half_rw @@ -28,15 +30,15 @@ public: const int height; const int width; - __device__ inline MatrixView_half_rw(half* data, const int height, const int width) + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) : data(data), height(height), width(width) { } - __device__ inline half item(int row, int column) const { return data[row * width + column]; } - __device__ inline half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ inline half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ inline void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ inline void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } }; class MatrixView_q4_row @@ -46,11 +48,11 @@ public: const int height; const int width; - __device__ inline MatrixView_q4_row(const uint32_t* data, const int height, const int width) + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) : data(data), height(height), width(width) { } - __device__ inline int item(int row, int column) const + __device__ __forceinline__ int item(int row, int column) const { int shift = (column & 0x07) * 4; return (data[row * width / 8 + column / 8] >> shift) & 0x0f; @@ -64,25 +66,25 @@ public: const int height; const int width; - __device__ inline MatrixView_q4_column(const uint32_t* data, const int height, const int width) + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) : data(data), height(height), width(width) { } - __device__ inline int item(int row, int column) const + __device__ __forceinline__ int item(int row, int column) const { int shift = (row & 0x07) * 4; return (data[row / 8 * width + column] >> shift) & 0x0f; } - __device__ inline uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ inline const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } }; // TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu // Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale -__device__ inline half2 dot_product_8 +__device__ __forceinline__ half2 dot_product_8 ( const half2 acc, MatrixView_half& h_, @@ -118,21 +120,22 @@ __device__ inline half2 dot_product_8 half2 v_45 = __halves2half2(v_4, v_5); half2 v_67 = __halves2half2(v_6, v_7); - v_01 = __hmul2(v_01, v_scale_2); - v_23 = __hmul2(v_23, v_scale_2); - v_45 = __hmul2(v_45, v_scale_2); - v_67 = __hmul2(v_67, v_scale_2); +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; - result = __hfma2(*h_ptr++, v_01, result); - result = __hfma2(*h_ptr++, v_23, result); - result = __hfma2(*h_ptr++, v_45, result); - result = __hfma2(*h_ptr++, v_67, result); + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); } return result; } -__device__ inline half dot_product_8_h +__device__ __forceinline__ half dot_product_8_h ( const half acc, MatrixView_half& h_, @@ -163,23 +166,15 @@ __device__ inline half dot_product_8_h half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - v_0 = __hmul(v_0, v_scale); - v_1 = __hmul(v_1, v_scale); - v_2 = __hmul(v_2, v_scale); - v_3 = __hmul(v_3, v_scale); - v_4 = __hmul(v_4, v_scale); - v_5 = __hmul(v_5, v_scale); - v_6 = __hmul(v_6, v_scale); - v_7 = __hmul(v_7, v_scale); - - result = __hfma(*h_ptr++, v_0, result); - result = __hfma(*h_ptr++, v_1, result); - result = __hfma(*h_ptr++, v_2, result); - result = __hfma(*h_ptr++, v_3, result); - result = __hfma(*h_ptr++, v_4, result); - result = __hfma(*h_ptr++, v_5, result); - result = __hfma(*h_ptr++, v_6, result); - result = __hfma(*h_ptr++, v_7, result); + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); } return result; @@ -187,7 +182,7 @@ __device__ inline half dot_product_8_h // Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map -__device__ inline half2 dot_product_8_x_map +__device__ __forceinline__ half2 dot_product_8_x_map ( const half2 acc, MatrixView_half& h_, @@ -225,11 +220,6 @@ __device__ inline half2 dot_product_8_x_map half2 v_45 = __halves2half2(v_4, v_5); half2 v_67 = __halves2half2(v_6, v_7); - v_01 = __hmul2(v_01, v_scale_2); - v_23 = __hmul2(v_23, v_scale_2); - v_45 = __hmul2(v_45, v_scale_2); - v_67 = __hmul2(v_67, v_scale_2); - half h_0 = h_ptr[*x_map_ptr++]; half h_1 = h_ptr[*x_map_ptr++]; half h_2 = h_ptr[*x_map_ptr++]; @@ -244,16 +234,17 @@ __device__ inline half2 dot_product_8_x_map half2 h_45 = __halves2half2(h_4, h_5); half2 h_67 = __halves2half2(h_6, h_7); - result = __hfma2(h_01, v_01, result); - result = __hfma2(h_23, v_23, result); - result = __hfma2(h_45, v_45, result); - result = __hfma2(h_67, v_67, result); + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); } return result; } -__device__ inline half dot_product_8_x_map_h +__device__ __forceinline__ half dot_product_8_x_map_h ( const half acc, MatrixView_half& h_, @@ -286,23 +277,15 @@ __device__ inline half dot_product_8_x_map_h half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - v_0 = __hmul(v_0, v_scale); - v_1 = __hmul(v_1, v_scale); - v_2 = __hmul(v_2, v_scale); - v_3 = __hmul(v_3, v_scale); - v_4 = __hmul(v_4, v_scale); - v_5 = __hmul(v_5, v_scale); - v_6 = __hmul(v_6, v_scale); - v_7 = __hmul(v_7, v_scale); - - result = __hfma(h_ptr[*x_map_ptr++], v_0, result); - result = __hfma(h_ptr[*x_map_ptr++], v_1, result); - result = __hfma(h_ptr[*x_map_ptr++], v_2, result); - result = __hfma(h_ptr[*x_map_ptr++], v_3, result); - result = __hfma(h_ptr[*x_map_ptr++], v_4, result); - result = __hfma(h_ptr[*x_map_ptr++], v_5, result); - result = __hfma(h_ptr[*x_map_ptr++], v_6, result); - result = __hfma(h_ptr[*x_map_ptr++], v_7, result); + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); } return result; From 8fcd528c4a5ec20c88f7a2523fa68197cd39593a Mon Sep 17 00:00:00 2001 From: turboderp Date: Sat, 10 Jun 2023 23:08:43 +0200 Subject: [PATCH 05/15] -lineinfo for profiling --- cuda_ext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_ext.py b/cuda_ext.py index 5f2c9a3c..4ef29bd2 100644 --- a/cuda_ext.py +++ b/cuda_ext.py @@ -57,7 +57,7 @@ def find_msvc(): extra_include_paths = [os.path.join(library_dir, "exllama_ext")], verbose = verbose, extra_ldflags = ["cublas.lib"] if windows else [], - extra_cuda_cflags = ["-U__HIP_NO_HALF_CONVERSIONS__", "-O3"] if torch.version.hip else [], + extra_cuda_cflags = ["-lineinfo"] + (["-U__HIP_NO_HALF_CONVERSIONS__", "-O3"] if torch.version.hip else []), extra_cflags = ["-O3"] # extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"] ) @@ -68,7 +68,7 @@ def find_msvc(): from exllama_ext import q4_matmul from exllama_ext import half_matmul from exllama_ext import half_matmul_cublas -from exllama_ext import q4_mlp +# from exllama_ext import q4_mlp from exllama_ext import rms_norm from exllama_ext import rope_ from exllama_ext import rep_penalty From 630bad4880517825028aec016c3895994a493e24 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sat, 10 Jun 2023 23:12:24 +0200 Subject: [PATCH 06/15] Housekeeping --- .../test_benchmark_perf.sh | 0 .../test_benchmark_perf2.sh | 0 .../test_benchmark_ppl.sh | 0 test_chatbot.sh => sh/test_chatbot.sh | 0 test_compat.sh => sh/test_compat.sh | 0 test_masking.py | 36 ------------------- test_profile.py | 36 ------------------- 7 files changed, 72 deletions(-) rename test_benchmark_perf.sh => sh/test_benchmark_perf.sh (100%) rename test_benchmark_perf2.sh => sh/test_benchmark_perf2.sh (100%) rename test_benchmark_ppl.sh => sh/test_benchmark_ppl.sh (100%) rename test_chatbot.sh => sh/test_chatbot.sh (100%) rename test_compat.sh => sh/test_compat.sh (100%) delete mode 100644 test_masking.py delete mode 100644 test_profile.py diff --git a/test_benchmark_perf.sh b/sh/test_benchmark_perf.sh similarity index 100% rename from test_benchmark_perf.sh rename to sh/test_benchmark_perf.sh diff --git a/test_benchmark_perf2.sh b/sh/test_benchmark_perf2.sh similarity index 100% rename from test_benchmark_perf2.sh rename to sh/test_benchmark_perf2.sh diff --git a/test_benchmark_ppl.sh b/sh/test_benchmark_ppl.sh similarity index 100% rename from test_benchmark_ppl.sh rename to sh/test_benchmark_ppl.sh diff --git a/test_chatbot.sh b/sh/test_chatbot.sh similarity index 100% rename from test_chatbot.sh rename to sh/test_chatbot.sh diff --git a/test_compat.sh b/sh/test_compat.sh similarity index 100% rename from test_compat.sh rename to sh/test_compat.sh diff --git a/test_masking.py b/test_masking.py deleted file mode 100644 index 5db015a7..00000000 --- a/test_masking.py +++ /dev/null @@ -1,36 +0,0 @@ -from model import ExLlama, ExLlamaCache, ExLlamaConfig -from tokenizer import ExLlamaTokenizer -import torch - -# Quick test to confirm that caching is working as intended. The two first passes together should produce roughly the -# same logits between them as the third pass, unless causal masking is incorrectly applied for the cached tokens, -# which it seems to be when using the built-in causal modes of SDP and xformers attention. Explicitly supplying a -# correct mask at least works for SDP, although it probably leaves some performance on the table. -# TODO: Make it not be the way that it is but so that it works instead. - -tokenizer_model_path = "/mnt/str/models/llama-7b-4bit-128g/tokenizer.model" -model_config_path = "/mnt/str/models/llama-7b-4bit-128g/config.json" -model_path = "/mnt/str/models/llama-7b-4bit-128g/llama-7b-4bit-128g.safetensors" - -config = ExLlamaConfig(model_config_path) -config.model_path = model_path -config.attention_method = ExLlamaConfig.AttentionMethod.PYTORCH_SCALED_DP -model = ExLlama(config) -cache = ExLlamaCache(model) - -tokenizer = ExLlamaTokenizer(tokenizer_model_path) - -ids = tokenizer.encode("Hello!") - -with torch.no_grad(): - - logits = model.forward(ids, cache, last_id_only = False) - print(logits) - - logits = model.forward(ids, cache, last_id_only = False) - print(logits) - - cache.current_seq_len = 0 - ids = torch.cat((ids, ids), dim = 1) - logits = model.forward(ids, cache, last_id_only = False) - print(logits) diff --git a/test_profile.py b/test_profile.py deleted file mode 100644 index 222fa235..00000000 --- a/test_profile.py +++ /dev/null @@ -1,36 +0,0 @@ - -from model import ExLlama, ExLlamaCache, ExLlamaConfig -from tokenizer import ExLlamaTokenizer -import torch - -import cProfile, pstats, io -from pstats import SortKey - -tokenizer_model_path = "/mnt/str/models/llama-30b-4bit-128g/tokenizer.model" -model_config_path = "/mnt/str/models/llama-30b-4bit-128g/config.json" -model_path = "/mnt/str/models/llama-30b-4bit-128g/llama-30b-4bit-128g.safetensors" - -tokenizer = ExLlamaTokenizer(tokenizer_model_path) - -config = ExLlamaConfig(model_config_path) -config.model_path = model_path -model = ExLlama(config) -cache = ExLlamaCache(model) - -ids = torch.randint(0, 31999, (1, 1024)) - -pr = cProfile.Profile() -pr.enable() - -with torch.no_grad(): - for i in range(128): - model.forward(ids, cache) - ids = torch.randint(0, 31999, (1, 1)) - cache.current_seq_len = 0 - -pr.disable() -s = io.StringIO() -sortby = SortKey.CUMULATIVE -ps = pstats.Stats(pr, stream=s).sort_stats(sortby) -ps.print_stats() -print(s.getvalue()) From bdb985bd6d63049c42dcb8a1662868ecff20299b Mon Sep 17 00:00:00 2001 From: nikuya3 <54780682+nikuya3@users.noreply.github.com> Date: Sat, 10 Jun 2023 21:20:42 +0000 Subject: [PATCH 07/15] Add docker support (#43) * Add docker support * Pin pytorch and other packages, remove torchvision and torchaudio --- .dockerignore | 2 ++ .env | 4 ++++ Dockerfile | 24 +++++++++++++++++++++++ README.md | 48 +++++++++++++++++++++++++++++++++++++++++++--- docker-compose.yml | 28 +++++++++++++++++++++++++++ entrypoint.sh | 15 +++++++++++++++ requirements.txt | 7 ++++--- 7 files changed, 122 insertions(+), 6 deletions(-) create mode 100644 .dockerignore create mode 100644 .env create mode 100644 Dockerfile create mode 100644 docker-compose.yml create mode 100755 entrypoint.sh diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..b210a632 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,2 @@ +exllama_sessions +models diff --git a/.env b/.env new file mode 100644 index 00000000..0b6f65f3 --- /dev/null +++ b/.env @@ -0,0 +1,4 @@ +PORT=5000 +MODEL_PATH=models/LLaMA-7B-4bit-128g # replace with the actual model path on the host +CONTAINER_MODEL_PATH=/app/model +SESSIONS_PATH=./exllama_sessions diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..b7dd15ce --- /dev/null +++ b/Dockerfile @@ -0,0 +1,24 @@ +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 as build + +ENV RUN_UID=1000 + +RUN apt-get update && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y ninja-build python3 python3-pip && \ + rm -rf /var/lib/apt/lists/* + +# Setup user which will run the service +RUN useradd -m -u $RUN_UID user +USER user + +COPY --chown=user . /app + +WORKDIR /app + +RUN pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt \ + && pip install flask==2.3.2 + +USER root + +STOPSIGNAL SIGINT +ENTRYPOINT ["/bin/bash", "-c", "/app/entrypoint.sh $0 $@"] diff --git a/README.md b/README.md index 2b0cec28..6cafb4fd 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ incompatibilities with older cards. I have no way of testing that right now. This list might be incomplete: -* `torch` tested on 2.1.0 (nightly) with cu118 +* `torch` tested on 2.0.1 and 2.1.0 (nightly) with cu118 * `safetensors` 0.3.1 * `sentencepiece` * `ninja` @@ -28,7 +28,7 @@ This list might be incomplete: ## Linux/WSL prerequisites - pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 + pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 ## Windows prerequisites @@ -79,6 +79,48 @@ To run it: Note that sessions are stored in `~/exllama_sessions/`. +## Docker +For security benefits and easier deployment, it is also possible to run the web UI in an isolated docker container. Note: the docker image currently only supports NVIDIA GPUs. + +### Requirements +- [Docker](https://docs.docker.com/engine/install/) +- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) + +It is recommended to run docker in [rootless mode](https://docs.docker.com/engine/security/rootless/). + +### Build + +The easiest way to build the docker image is using docker compose. First, set the `MODEL_PATH` and `SESSIONS_PATH` variables in the `.env` file to the actual directories on the host. Then run: + +``` +docker compose build +``` + +It is also possible to manually build the image: + +``` +docker build -t exllama-web +``` + +### Run + +Using docker compose: + +``` +docker compose up +``` + +The web UI can now be accessed on the host at http://localhost:5000. + +The configuration can be viewed in `docker-compose.yml` and changed by creating a `docker-compose.override.yml` file. + +Run manually: + +``` +docker run --gpus all -p 5000:5000 -v :/app/model/ --rm -it exllama-web --host 0.0.0.0:5000 +``` + + ## Results so far ### New implementation @@ -171,4 +213,4 @@ on Windows. **2024-06-09**: Fused most of the self-attention step. More to come. Slight speedup already, but more importantly went from 69% actual CPU utilization to 37%. This should do a lot to address the bottleneck on CPUs with lower -single-threaded performance. \ No newline at end of file +single-threaded performance. diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..5e47c1c2 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,28 @@ +version: "3.9" +name: exllama +services: + web: + build: + context: . + command: | + --host 0.0.0.0:$PORT + env_file: + - .env + environment: + - CONTAINER_MODEL_PATH=$CONTAINER_MODEL_PATH + volumes: + - $MODEL_PATH:$CONTAINER_MODEL_PATH + - $SESSIONS_PATH:/home/user/exllama_sessions + ports: + - "$PORT:$PORT" + tmpfs: + - /tmp + stdin_open: true + tty: true + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 00000000..79864aa8 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +set -Eeuo pipefail + +# Ensure that the model path is set +if [ -z $CONTAINER_MODEL_PATH ]; then + echo "Must specify model path" + exit 1 +fi + +# Ensure that bind-mounted directories are owned by the user that runs the service +chown -R $RUN_UID:$RUN_UID $CONTAINER_MODEL_PATH +chown -R $RUN_UID:$RUN_UID /home/user/exllama_sessions + +# Run service as specified (non-root) user +exec runuser -u $(id -un $RUN_UID) -- python3 /app/webui/app.py -d $CONTAINER_MODEL_PATH $@ diff --git a/requirements.txt b/requirements.txt index 65fc81f8..a5860316 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -safetensors>=0.3.1 -sentencepiece -ninja +torch==2.0.1 +safetensors==0.3.1 +sentencepiece==0.1.99 +ninja==1.11.1 From fecebe28bd06bc63dd9a7fa0dece7f1855a07c15 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 11 Jun 2023 00:21:44 +0200 Subject: [PATCH 08/15] Removed debug mode, forced no_grad() in forward pass --- model.py | 260 ++++++------------------------------ test_benchmark_inference.py | 173 +++++++++++------------- webui/session.py | 196 ++++++++++++++------------- 3 files changed, 219 insertions(+), 410 deletions(-) diff --git a/model.py b/model.py index 83c5cf5c..468f380c 100644 --- a/model.py +++ b/model.py @@ -6,12 +6,7 @@ import json import math from enum import Enum -import struct -# Magic numbers - -optimal_switch_thd = 6 # Model mostly runs one token at a time, or many. So this probably doesn't matter too much. -attn_switch_thd = 16 class ParsedEnum(Enum): @@ -69,7 +64,6 @@ def __init__(self, model_config_path): self.max_seq_len = 2048 # Reduce to save memory. Can also be increased, but the pretrained models produce degenerate output after 2048 tokens in any case. Should be possible to finetune for longer sequence lengths. self.gpu_peer_fix = False # Apparently Torch can have problems transferring tensors directly one GPU to another sometimes. Enable this to move tensors via system RAM instead, where needed self.auto_map = None # List of ints with memory allocation in GB, per CUDA device, overrides device_map - self.debug = False # Tuning @@ -104,28 +98,11 @@ def set_auto_map(self, map_string): else: self.auto_map = [float(alloc) for alloc in map_string.split(",")] -def _dump_tensor(t, name): - - if t is None: - with open(name, "w"): - pass - with open(name + ".shape", "w"): - pass - else: - t.cpu().numpy().tofile(name) - t = t.view(-1, t.shape[-1]) - with open(name + ".shape", "wb") as file: - shape_struct = struct.pack(" 0 and _rows(hidden_states) <= self.config.fused_mlp_thd: - if self.config.debug: print(f" !! - method: fused") - cuda_ext.exllama_ext.q4_mlp(hidden_states.view(-1, hidden_states.shape[-1]), self.post_attention_layernorm.weight, self.config.rms_norm_eps, @@ -456,8 +379,6 @@ def forward(self, hidden_states, cache, buffer): else: - if self.config.debug: print(f" !! - method: normal") - residual = hidden_states hidden_states = self.post_attention_layernorm.forward(hidden_states, buffer) hidden_states = self.mlp.forward(hidden_states, buffer) @@ -538,11 +459,6 @@ def copy_states(self, target, from_column, from_columns, to_column, to_columns, target_view_v.copy_(source_view_v) - def debug(self, index): - - print(f" !! - cache device: {self.key_states[index].device}, seq_len: {self.current_seq_len}") - - # Device map for the model. class ExLlamaDeviceMap: @@ -575,23 +491,6 @@ def map(self, key, loading = False): raise ValueError("Unknown key: " + key) - def debug(self): - - print(f" !! Device map:") - print(f" !! - embed_tokens: {self.embed_tokens}") - - a = 0 - while a < len(self.layers): - b = min(a + 10, len(self.layers)) - print(f" !! - layers [{a}:{b}]:", end = "") - for i in range(a, b): print(" " + self.layers[i], end = "") - print("") - a = b - - print(f" !! - norm: {self.norm}") - print(f" !! - lm_head: {self.lm_head}") - - class ExLlamaBuffer: config: ExLlamaConfig @@ -613,11 +512,6 @@ def to(self, device): return new - def debug(self): - - print(f" !! - attn_mask: {_describe_tensor(self.attn_mask, True)}") - - def _device_to_int(device): return int(device[device.find(":") + 1:]) @@ -628,26 +522,9 @@ def _skip_key(key): if key.endswith(".rotary_emb.inv_freq"): return True return False -def _describe_tensor(tensor, stats = False): - - if tensor is None: return "None" - desc = f"device: {tensor.device}" - desc += f", shape: {str(list(tensor.shape))}" - desc += f", dtype: {str(tensor.dtype).replace('torch.', '')}" - if stats: - if tensor.dtype in (torch.float16, torch.float32, torch.float64): - desc += f", min: {tensor.min().item():.6f}" - desc += f", max: {tensor.max().item():.6f}" - desc += f", std: {tensor.std().item():.6f}" - else: - desc += f", min: {tensor.min().item()}" - desc += f", max: {tensor.max().item()}" - return desc - def _move_tensor(tensor, new_device, name, config): device = str(tensor.device) if device == new_device: return tensor - if config.debug: print(f" !! Moving {name} from {device} to {new_device}") if config.gpu_peer_fix: if device.startswith("cuda:") and new_device.startswith("cuda:"): tensor = tensor.to("cpu") @@ -655,28 +532,17 @@ def _move_tensor(tensor, new_device, name, config): class ExLlama: -# class ExLlama(nn.Module): def __init__(self, config): - # super().__init__() - # self.eval() self.config = config - if self.config.debug: - device_count = torch.cuda.device_count() - print(f" !! Available CUDA devices:") - for i in range(device_count): - print(f'" !! - cuda:{i}: {torch.cuda.get_device_name(i)}') - # Copy tuning parameters to C++ extension self.config.set_tuning_params() # Load model weights - if self.config.debug: print(f" !! Loading safetensors file: {self.config.model_path}") - tensors = {} with safe_open(self.config.model_path, framework="pt", device="cpu") as f: @@ -689,8 +555,6 @@ def __init__(self, config): if self.config.auto_map is not None: - if self.config.debug: print(f" !! Begin auto device map") - self.config.device_map.embed_tokens = "cpu" self.config.device_map.layers = ["cuda:0"] + ["?"] * (self.config.num_hidden_layers - 1) @@ -710,11 +574,6 @@ def __init__(self, config): tensor = f.get_tensor(key) head_size += tensor.numel() * tensor.element_size() - if self.config.debug: - print(f" !! Decoder size: {decoder_size:,} bytes") - print(f" !! Norm size: {norm_size:,} bytes") - print(f" !! Head size: {head_size:,} bytes") - # Assign layers automatically device_usage = 0 @@ -743,13 +602,10 @@ def __init__(self, config): device_usage += this_layer_size layer_index_device += 1 - if self.config.debug: self.config.device_map.debug() - # Load tensors, move to device(s) max_dq_buffer_size = 0 - if self.config.debug: print(f" !! Begin load tensors") for key in f.keys(): if _skip_key(key): continue @@ -757,10 +613,6 @@ def __init__(self, config): device = self.config.device_map.map(key, loading = True) tensor = f.get_tensor(key) - if self.config.debug: - if key.startswith("model.layers.0.") or not key.startswith("model.layers."): - print(f" !! - {key} read: {_describe_tensor(tensor)}") - if key.endswith(".scales"): tensor = tensor.half() if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half() if key == "model.norm.weight": tensor = tensor.half() @@ -772,10 +624,6 @@ def __init__(self, config): if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, tensor.numel() * 8) - if self.config.debug: - if key.startswith("model.layers.0.") or not key.startswith("model.layers."): - print(f" !! - {key} map: {_describe_tensor(tensor, device.startswith('cuda'))}") - tensors[key] = tensor # Head @@ -795,8 +643,6 @@ def __init__(self, config): # Prepare position embeddings for max seq length - if self.config.debug: print(f" !! Computing RoPE table for seq length: {self.config.max_seq_len}") - devs = self.config.device_map.get_layers_devs() self.sincos = {} @@ -811,7 +657,6 @@ def __init__(self, config): cos = emb.cos()[None, None, :, :].half() self.sincos[device] = (sin, cos) - if self.config.debug: print(f" !! - stored for device: {device}") # Layers @@ -828,7 +673,6 @@ def __init__(self, config): modules.append(layer) self.layers = modules - # self.layers = nn.ModuleList(modules) # Prepare CUDA buffers @@ -857,89 +701,73 @@ def __init__(self, config): def forward(self, input_ids, cache, last_id_only = True, preprocess_only = False): - if torch.is_grad_enabled(): - raise ValueError("Forward pass called with gradients enabled. Back propagation is not supported yet.") - - if self.config.debug: print(f" !! Begin forward pass") - - batch_size, seq_len = input_ids.shape - past_len = cache.current_seq_len - - buffer = ExLlamaBuffer(self.config) + # if torch.is_grad_enabled(): + # raise ValueError("Forward pass called with gradients enabled. Back propagation is not supported yet.") + with torch.no_grad(): - # Build attention mask on first device, copy to others if necessary - - devs = self.config.device_map.get_layers_devs() + batch_size, seq_len = input_ids.shape + past_len = cache.current_seq_len - if seq_len > 1: - - attn_mask = torch.zeros(batch_size, 1, seq_len, past_len + seq_len, dtype = torch.float16, device = devs[0]) - attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), torch.finfo(torch.float16).min)) - attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu - - else: + buffer = ExLlamaBuffer(self.config) - attn_mask = None - # attn_mask = torch.zeros(batch_size, 1, seq_len, seq_len + past_len, dtype = torch.float16, device = devs[0]) + # Build attention mask on first device, copy to others if necessary - buffer.attn_mask = attn_mask + devs = self.config.device_map.get_layers_devs() - # Embeddings - # TODO: Allow passing input embeddings instead of IDs + if seq_len > 1: - input_ids = _move_tensor(input_ids, "cpu", "input_ids", self.config) - hidden_states = self.embed_tokens(input_ids) + attn_mask = torch.zeros(batch_size, 1, seq_len, past_len + seq_len, dtype = torch.float16, device = devs[0]) + attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), torch.finfo(torch.float16).min)) + attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu - if self.config.debug: print(f" !! Built initial hidden state: {_describe_tensor(hidden_states, True)}") - - # Split buffers to devices - - buffers = {devs[0]: buffer} - for device in devs[1:]: - buffers[device] = buffer.to(device) + else: - if self.config.debug: - for device in devs: - print(f" !! Prepared buffer for device: {device}") - buffers[device].debug() + attn_mask = None + # attn_mask = torch.zeros(batch_size, 1, seq_len, seq_len + past_len, dtype = torch.float16, device = devs[0]) - # Decoder layers + buffer.attn_mask = attn_mask - for i, decoder_layer in enumerate(self.layers): + # Embeddings + # TODO: Allow passing input embeddings instead of IDs - device = self.config.device_map.layers[i] - hidden_states = _move_tensor(hidden_states, device, "hidden_states", self.config) + input_ids = _move_tensor(input_ids, "cpu", "input_ids", self.config) + hidden_states = self.embed_tokens(input_ids) - hidden_states = decoder_layer.forward(hidden_states, cache, buffers[device]) + # Split buffers to devices - cache.current_seq_len += seq_len + buffers = {devs[0]: buffer} + for device in devs[1:]: + buffers[device] = buffer.to(device) - # Early exit when we don't need logits + # Decoder layers - if preprocess_only: return None + for i, decoder_layer in enumerate(self.layers): - # Norm + device = self.config.device_map.layers[i] + hidden_states = _move_tensor(hidden_states, device, "hidden_states", self.config) - hidden_states = _move_tensor(hidden_states, self.config.device_map.norm, "hidden_states", self.config) + hidden_states = decoder_layer.forward(hidden_states, cache, buffers[device]) - if self.config.debug: print(f" !! pre norm, hidden_states: {_describe_tensor(hidden_states, True)}") + cache.current_seq_len += seq_len - hidden_states = self.norm.forward(hidden_states, buffer) + # Early exit when we don't need logits - # Head + if preprocess_only: return None - if last_id_only: hidden_states = hidden_states[:, -1:, :].contiguous() - if self.config.device_map.lm_head == "cpu": hidden_states = hidden_states.float() + # Norm - hidden_states = _move_tensor(hidden_states, self.config.device_map.lm_head, "hidden_states", self.config) + hidden_states = _move_tensor(hidden_states, self.config.device_map.norm, "hidden_states", self.config) + hidden_states = self.norm.forward(hidden_states, buffer) - if self.config.debug: print(f" !! pre lm_head, hidden_states: {_describe_tensor(hidden_states, True)}") + # Head - logits = self.lm_head(hidden_states) - # logits = cuda_ext.matmul_half(hidden_states, self.lm_head_data, cublas = False) + if last_id_only: hidden_states = hidden_states[:, -1:, :].contiguous() + if self.config.device_map.lm_head == "cpu": hidden_states = hidden_states.float() - if self.config.debug: print(f" !! logits: {_describe_tensor(logits, True)}") + hidden_states = _move_tensor(hidden_states, self.config.device_map.lm_head, "hidden_states", self.config) + logits = self.lm_head(hidden_states) + # logits = cuda_ext.matmul_half(hidden_states, self.lm_head_data, cublas = False) - logits = logits.float() - logits = _move_tensor(logits, self.config.device_map.embed_tokens, "logits", self.config) - return logits + logits = logits.float() + logits = _move_tensor(logits, self.config.device_map.embed_tokens, "logits", self.config) + return logits diff --git a/test_benchmark_inference.py b/test_benchmark_inference.py index 4e97493b..177ed21c 100644 --- a/test_benchmark_inference.py +++ b/test_benchmark_inference.py @@ -14,9 +14,8 @@ testdata_path = "testdata.jsonl" -torch.set_grad_enabled(False) torch.cuda._lazy_init() -torch.backends.cuda.matmul.allow_tf32 = True +# torch.backends.cuda.matmul.allow_tf32 = True # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True torch.set_printoptions(precision = 10) torch_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())] @@ -84,7 +83,6 @@ def mem(name, total = False): parser.add_argument("-p", "--perf", action = "store_true", help = "Benchmark speed and VRAM usage") parser.add_argument("-ppl", "--perplexity", action = "store_true", help = "Perplexity benchmark (slow)") parser.add_argument("-v", "--validate", action = "store_true", help = "Quick perplexity benchmark just to test if model is working at all, and short text completion") -parser.add_argument("-dbg", "--debug", action = "store_true", help = "Run debug pass") args = parser.parse_args() model_init.post_parse(args) @@ -96,14 +94,12 @@ def mem(name, total = False): if args.perf: print_opts.append("perf") if args.perplexity: print_opts.append("perplexity") if args.validate: print_opts.append("validate") -if args.debug: print_opts.append("debug") model_init.print_options(args, print_opts) # Instantiate model config = model_init.make_config(args) -config.debug = args.debug model = timer("Load model", lambda: ExLlama(config)) tokenizer = timer("Load tokenizer", lambda: ExLlamaTokenizer(args.tokenizer)) @@ -119,133 +115,120 @@ def mem(name, total = False): max_seq_len = args.length ids = torch.randint(0, 31999, (1, max_seq_len - gen_tokens)).cuda() -with torch.no_grad(): +# Benchmark memory and performance - if args.debug: +if args.perf: - print(" !! Inference, debug pass") + # Warming up apparently makes a huge difference + for i in range(1, 4): + print(f" -- Warmup pass {i}...") begin() - logits = timer("Inference", lambda: next_logits(ids)) + logits = timer("Warmup", lambda: next_logits(ids)) - model.config.debug = False + # Do the actual benchmark - # Benchmark memory and performance + begin() - if args.perf: + t = time.time() - # Warming up apparently makes a huge difference + print(" -- Inference, first pass.") + logits = timer("Inference", lambda: next_logits(ids)) - for i in range(1, 4): - print(f" -- Warmup pass {i}...") - begin() - logits = timer("Warmup", lambda: next_logits(ids)) + t = time.time() - t + print(f" ** Speed: {ids.shape[-1] / t:.2f} tokens/second") - # Do the actual benchmark - - begin() + for j in range(2): t = time.time() + print(f" -- Generating {gen_tokens} tokens, {ids.shape[-1]} token prompt...") + for i in range(gen_tokens): - print(" -- Inference, first pass.") - logits = timer("Inference", lambda: next_logits(ids)) + logits = logits[0, -1, :] + token = torch.argmax(logits) + next_id = token.unsqueeze(0).unsqueeze(0) + logits = next_logits(next_id) t = time.time() - t - print(f" ** Speed: {ids.shape[-1] / t:.2f} tokens/second") - - for j in range(2): - - t = time.time() - print(f" -- Generating {gen_tokens} tokens, {ids.shape[-1]} token prompt...") - for i in range(gen_tokens): + print(f" ** Speed: {gen_tokens / t:.2f} tokens/second") - logits = logits[0, -1, :] - token = torch.argmax(logits) - next_id = token.unsqueeze(0).unsqueeze(0) - logits = next_logits(next_id) + ids = ids[:, :4] + cache.current_seq_len = 4 - t = time.time() - t - print(f" ** Speed: {gen_tokens / t:.2f} tokens/second") + mem("Inference") + mem("Total", total = True) - ids = ids[:, :4] - cache.current_seq_len = 4 +# Benchmark perplexity - mem("Inference") - mem("Total", total = True) +if args.perplexity or args.validate: - # Benchmark perplexity + print(" -- Loading dataset...") - if args.perplexity or args.validate: + ds = [] + with open(testdata_path) as f: + for line in f: + example = json.loads(line)["text"] + if len(example) > 50: ds.append(example) - print(" -- Loading dataset...") + def _ppl_test(text, ex_count): - ds = [] - with open(testdata_path) as f: - for line in f: - example = json.loads(line)["text"] - if len(example) > 50: ds.append(example) + print(" -- Testing", end="") + sys.stdout.flush() - def _ppl_test(text, ex_count): + logprob_sum = 0.0 + logprob_count = 0 - print(" -- Testing", end="") - sys.stdout.flush() - - logprob_sum = 0.0 - logprob_count = 0 - - for ex in ds: - - begin() + for ex in ds: - ids = tokenize(ex) - ids = ids[:, :max_seq_len + 1] - input_ids = ids[:, :-1] - target_ids = ids[:, 1:] - - logits = next_logits(input_ids, last_id_only=False) + begin() - log_probs = F.log_softmax(logits, dim=-1) - token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) + ids = tokenize(ex) + ids = ids[:, :max_seq_len + 1] + input_ids = ids[:, :-1] + target_ids = ids[:, 1:] - logprob_sum += token_log_probs.sum().item() - logprob_count += target_ids.numel() + logits = next_logits(input_ids, last_id_only=False) - ex_count -= 1 - if ex_count % 10 == 0: - print(".", end = "") - sys.stdout.flush() - if ex_count == 0: break + log_probs = F.log_softmax(logits, dim=-1) + token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) - mean_log_prob = logprob_sum / logprob_count - perplexity = math.exp(-mean_log_prob) + logprob_sum += token_log_probs.sum().item() + logprob_count += target_ids.numel() - print("") - print(f" ** Perplexity{text}: {perplexity:.4f}") + ex_count -= 1 + if ex_count % 10 == 0: + print(".", end = "") + sys.stdout.flush() + if ex_count == 0: break - if args.perplexity: + mean_log_prob = logprob_sum / logprob_count + perplexity = math.exp(-mean_log_prob) - _ppl_test("", 100) + print("") + print(f" ** Perplexity{text}: {perplexity:.4f}") - if args.validate: + if args.perplexity: - # Short perplexity tests in switched and quant mode, should produce roughly equal results + _ppl_test("", 100) - model.config.matmul_recons_thd = 1 - _ppl_test(" (reconstruct)", 8) - model.config.matmul_recons_thd = 0 - _ppl_test(" (quant)", 8) - # model.config.fused_attn_thd = 1 - # _ppl_test(" (fused_attn)", 8) + if args.validate: - # Do a short, easy topk=1 completion to see if we're generating garbage. Should run in switched mode - # for the prompt and quant for individual tokens + # Short perplexity tests in switched and quant mode, should produce roughly equal results - model.config.matmul_recons_thd = 4 - generator = ExLlamaGenerator(model, tokenizer, cache) - generator.settings.top_k = 1 - text = generator.generate_simple("To be or not to be, that is the", max_new_tokens = 20) - # text = generator.generate_simple("To be or", max_new_tokens = 20) - text = text.replace("\n", "\\n") - print(f" ** Generation: {text}") + model.config.matmul_recons_thd = 1 + _ppl_test(" (reconstruct)", 8) + model.config.matmul_recons_thd = 0 + _ppl_test(" (quant)", 8) + # model.config.fused_attn_thd = 1 + # _ppl_test(" (fused_attn)", 8) + # Do a short, easy topk=1 completion to see if we're generating garbage. Should run in switched mode + # for the prompt and quant for individual tokens + model.config.matmul_recons_thd = 4 + generator = ExLlamaGenerator(model, tokenizer, cache) + generator.settings.top_k = 1 + text = generator.generate_simple("To be or not to be, that is the", max_new_tokens = 20) + # text = generator.generate_simple("To be or", max_new_tokens = 20) + text = text.replace("\n", "\\n") + print(f" ** Generation: {text}") diff --git a/webui/session.py b/webui/session.py index e744ad2b..42b45490 100644 --- a/webui/session.py +++ b/webui/session.py @@ -534,146 +534,144 @@ def respond(self, author, stop_conditions, total_tokens, res_line = "", num_res_ def respond_multi(self, user_input): global model, tokenizer, cache, generator - with torch.no_grad(): - - packet = {"cmd": "begin_stream"} - yield json.dumps(packet) + "\n" + packet = {"cmd": "begin_stream"} + yield json.dumps(packet) + "\n" - # Prepare stop conditions + # Prepare stop conditions - # stop_conditions = [ (torch.Tensor([[tokenizer.eos_token_id]]).long(), None) ] - stop_conditions = [] - newline_token = torch.Tensor([[tokenizer.newline_token_id]]).long() + # stop_conditions = [ (torch.Tensor([[tokenizer.eos_token_id]]).long(), None) ] + stop_conditions = [] + newline_token = torch.Tensor([[tokenizer.newline_token_id]]).long() - if self.break_on_newline: - stop_conditions.append((newline_token, "\n")) - else: - for part in self.participants: - txt = part + ":" - sc = tokenizer.encode(txt) - sc = torch.cat((newline_token, sc), dim=1) - stop_conditions.append((sc, "\n" + txt)) - stop_conditions.append((sc, "\n " + txt)) - - # Clean up the input a bit + if self.break_on_newline: + stop_conditions.append((newline_token, "\n")) + else: + for part in self.participants: + txt = part + ":" + sc = tokenizer.encode(txt) + sc = torch.cat((newline_token, sc), dim=1) + stop_conditions.append((sc, "\n" + txt)) + stop_conditions.append((sc, "\n " + txt)) - user_input = user_input.strip() + # Clean up the input a bit - if len(user_input) > 0: + user_input = user_input.strip() - # Append input to context + if len(user_input) > 0: - author = None - if len(self.participants) > 0: author = self.participants[0] - newNode = Node(user_input, author) - self.history.append(newNode) + # Append input to context - self.save() + author = None + if len(self.participants) > 0: author = self.participants[0] + newNode = Node(user_input, author) + self.history.append(newNode) - # Echo input back to client + self.save() - packet = {"cmd": "begin_block", - "init_text": user_input, - "uuid": newNode.uuid} - if author is not None: packet["author"] = author - yield json.dumps(packet) + "\n" + # Echo input back to client - # Prepare context for generator + packet = {"cmd": "begin_block", + "init_text": user_input, + "uuid": newNode.uuid} + if author is not None: packet["author"] = author + yield json.dumps(packet) + "\n" - self.set_context_window() - context, text_context = self.get_tokenized_context() + # Prepare context for generator - # Start generating, reusing cache for any part of the context that hasn't changed + self.set_context_window() + context, text_context = self.get_tokenized_context() - if context is None: - print("No initial context") - reused = generator.gen_begin_empty() - else: - begin_time = time.time() - reused = generator.gen_begin_reuse(context) - end_time = time.time() - elapsed = end_time - begin_time - new_tokens = context.shape[-1] - reused - print(f"Prompt processed in {elapsed:.2f} seconds, {new_tokens} new tokens, {(new_tokens / elapsed):.2f} tokens/second:") + # Start generating, reusing cache for any part of the context that hasn't changed + if context is None: + print("No initial context") + reused = generator.gen_begin_empty() + else: begin_time = time.time() - total_tokens = [0] + reused = generator.gen_begin_reuse(context) + end_time = time.time() + elapsed = end_time - begin_time + new_tokens = context.shape[-1] - reused + print(f"Prompt processed in {elapsed:.2f} seconds, {new_tokens} new tokens, {(new_tokens / elapsed):.2f} tokens/second:") - # No participants + begin_time = time.time() + total_tokens = [0] - if len(self.participants) == 0: + # No participants - yield from self.respond(None, stop_conditions, total_tokens) + if len(self.participants) == 0: - # Two participants + yield from self.respond(None, stop_conditions, total_tokens) - elif len(self.participants) == 2: + # Two participants - author = self.participants[1] - res_line = author + ":" - res_tokens = tokenizer.encode(res_line) - num_res_tokens = res_tokens.shape[-1] + elif len(self.participants) == 2: - generator.gen_feed_tokens(res_tokens) - yield from self.respond(self.participants[1], stop_conditions, total_tokens, res_line, num_res_tokens) + author = self.participants[1] + res_line = author + ":" + res_tokens = tokenizer.encode(res_line) + num_res_tokens = res_tokens.shape[-1] - # Multiple bots might answer + generator.gen_feed_tokens(res_tokens) + yield from self.respond(self.participants[1], stop_conditions, total_tokens, res_line, num_res_tokens) - elif len(self.participants) > 2: + # Multiple bots might answer - cpart = [p + ":" for p in self.participants] - upart = cpart.pop(0) - first_round = True + elif len(self.participants) > 2: - while True: + cpart = [p + ":" for p in self.participants] + upart = cpart.pop(0) + first_round = True - res_tokens = [] - npart = [p for p in cpart] - ncrange = [i for i in range(len(cpart))] - ntoken = [tokenizer.encode(np).squeeze(0).tolist() for np in npart] - winner = -1 + while True: - while True: + res_tokens = [] + npart = [p for p in cpart] + ncrange = [i for i in range(len(cpart))] + ntoken = [tokenizer.encode(np).squeeze(0).tolist() for np in npart] + winner = -1 - constraints = [t[len(res_tokens)] for t in ntoken] - next_t = generator.gen_single_token(constraints) + while True: - remove = [] - for i in range(len(ntoken)): - if ntoken[i][len(res_tokens)] != next_t: remove.append(i) + constraints = [t[len(res_tokens)] for t in ntoken] + next_t = generator.gen_single_token(constraints) - for i in reversed(remove): - npart.pop(i) - ntoken.pop(i) - ncrange.pop(i) + remove = [] + for i in range(len(ntoken)): + if ntoken[i][len(res_tokens)] != next_t: remove.append(i) - res_tokens.append(next_t) + for i in reversed(remove): + npart.pop(i) + ntoken.pop(i) + ncrange.pop(i) - for i in range(len(ntoken)): - if len(ntoken[i]) == len(res_tokens): winner = ncrange[i] + res_tokens.append(next_t) - if winner != -1: break + for i in range(len(ntoken)): + if len(ntoken[i]) == len(res_tokens): winner = ncrange[i] - author = cpart.pop(winner)[:-1] - res_line = author + ":" - num_res_tokens = len(res_tokens) + if winner != -1: break - if author == self.participants[0]: - generator.gen_rewind(num_res_tokens) - break + author = cpart.pop(winner)[:-1] + res_line = author + ":" + num_res_tokens = len(res_tokens) - # generator.gen_feed_tokens(res_tokens) - yield from self.respond(self.participants[1], stop_conditions, total_tokens, res_line, num_res_tokens) + if author == self.participants[0]: + generator.gen_rewind(num_res_tokens) + break - if first_round: - first_round = False - cpart.append(upart) + # generator.gen_feed_tokens(res_tokens) + yield from self.respond(self.participants[1], stop_conditions, total_tokens, res_line, num_res_tokens) - end_time = time.time() - elapsed = end_time - begin_time + if first_round: + first_round = False + cpart.append(upart) - print(f"Response generated in {elapsed:.2} seconds, {total_tokens[0]} tokens, {(total_tokens[0] / elapsed):.2f} tokens/second:") + end_time = time.time() + elapsed = end_time - begin_time - self.save() + print(f"Response generated in {elapsed:.2} seconds, {total_tokens[0]} tokens, {(total_tokens[0] / elapsed):.2f} tokens/second:") + + self.save() From ebdaae03efa70c9d8f8e461cc1a6845642f7014d Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 11 Jun 2023 00:22:20 +0200 Subject: [PATCH 09/15] Cleanup, updates --- README.md | 26 ++++++------------ TODO.md => doc/TODO.md | 0 _screenshot.jpg => doc/_screenshot.jpg | Bin .../model_compatibility.md | 0 prompt_assistant.txt | 2 -- prompt_bluemoon.txt | 1 - 6 files changed, 9 insertions(+), 20 deletions(-) rename TODO.md => doc/TODO.md (100%) rename _screenshot.jpg => doc/_screenshot.jpg (100%) rename model_compatibility.md => doc/model_compatibility.md (100%) delete mode 100644 prompt_assistant.txt delete mode 100644 prompt_bluemoon.txt diff --git a/README.md b/README.md index 6cafb4fd..17481bcb 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,15 @@ # ExLlama -A rewrite of the HF transformers implementation of Llama with the following goals, among others: +A standalone Python/C++/CUDA implementation of Llama for use with 4-bit GPTQ weights, designed to be fast and +memory-efficient on modern GPUs. -* Designed for use with quantized weights -* Fast and memory-efficient inference (not just attention) -* Mapping across multiple devices -* Built-in (multi) LoRA support -* Companion library of funky sampling functions - -Disclaimer: This is currently a preview of a work in progress. Or maybe a proof of concept. Either way any part of it -is subject to change. - -## Hardware/software requirements +## Hardware requirements I am developing on an RTX 4090 and an RTX 3090-Ti. Both cards support the CUDA kernel, but there might be -incompatibilities with older cards. I have no way of testing that right now. +incompatibilities with older cards. ## Dependencies -This list might be incomplete: - * `torch` tested on 2.0.1 and 2.1.0 (nightly) with cu118 * `safetensors` 0.3.1 * `sentencepiece` @@ -69,7 +59,7 @@ I made a simple web UI for it. Like the rest of the project, it's a work in prog it was mostly written by ChatGPT and it will haunt your dreams. But it sort of works, and it's kinda fun, especially multibot mode: -![_screenshot.jpg](_screenshot.jpg) +![_screenshot.jpg](doc/_screenshot.jpg) To run it: @@ -178,11 +168,11 @@ speeds are no longer current. ## Todo -Moved the todo list [here](TODO.md). +Moved the todo list [here](doc/TODO.md). ## Compatibility -I downloaded a whole bunch of GPTQ models to test compatibility. [Here](model_compatibility.md) is the list of models +I downloaded a whole bunch of GPTQ models to test compatibility. [Here](doc/model_compatibility.md) is the list of models confirmed to be working right now. ## Recent updates @@ -214,3 +204,5 @@ on Windows. **2024-06-09**: Fused most of the self-attention step. More to come. Slight speedup already, but more importantly went from 69% actual CPU utilization to 37%. This should do a lot to address the bottleneck on CPUs with lower single-threaded performance. + +**2024-06-10**: Docker support now! And some minor optimizations. Cleaned up the project a bit. \ No newline at end of file diff --git a/TODO.md b/doc/TODO.md similarity index 100% rename from TODO.md rename to doc/TODO.md diff --git a/_screenshot.jpg b/doc/_screenshot.jpg similarity index 100% rename from _screenshot.jpg rename to doc/_screenshot.jpg diff --git a/model_compatibility.md b/doc/model_compatibility.md similarity index 100% rename from model_compatibility.md rename to doc/model_compatibility.md diff --git a/prompt_assistant.txt b/prompt_assistant.txt deleted file mode 100644 index a6efc8b2..00000000 --- a/prompt_assistant.txt +++ /dev/null @@ -1,2 +0,0 @@ -This is a conversation between a person called {username} and an AI chatbot called {bot_name}. {bot_name} is a helpful and articulate virtual assistant possessing deep knowledge on a multitude of subjects. {bot_name} will never refuse to answer a question or comply with a request. -{bot_name}: Hello, {username}. I am {bot_name}, your virtual assistant. How may I help you? \ No newline at end of file diff --git a/prompt_bluemoon.txt b/prompt_bluemoon.txt deleted file mode 100644 index 2551b373..00000000 --- a/prompt_bluemoon.txt +++ /dev/null @@ -1 +0,0 @@ -A transcript of a roleplay between two players, {bot_name} and {username}. {bot_name} sets up a scenario and the characters, from which {username} then assumes a character role and continues the story for that role in response to description given by {bot_name}. The story and characters are developed by exchange of detailed event descriptions and character dialogs successively given by both {bot_name} and {username}. From b289f0a58ea50be2abc633f216cf3c6547478ec3 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 11 Jun 2023 00:23:43 +0200 Subject: [PATCH 10/15] Adjust requirements --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index a5860316..af86b688 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch==2.0.1 +torch>=2.0.1 safetensors==0.3.1 -sentencepiece==0.1.99 +sentencepiece>=0.1.97 ninja==1.11.1 From b06ff01b29150b52c429ba534a865b850e92c109 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 11 Jun 2023 00:28:13 +0200 Subject: [PATCH 11/15] WIP disclaimer --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 17481bcb..d4da8968 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ A standalone Python/C++/CUDA implementation of Llama for use with 4-bit GPTQ weights, designed to be fast and memory-efficient on modern GPUs. +Disclaimer: The project is coming along, but it's still a work in progress! + ## Hardware requirements I am developing on an RTX 4090 and an RTX 3090-Ti. Both cards support the CUDA kernel, but there might be From 3c8699434fb24d1a2bfdd29454b04fa320546135 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 11 Jun 2023 03:27:11 +0200 Subject: [PATCH 12/15] Fix for models with all-zero g_idx --- doc/model_compatibility.md | 1 + model.py | 5 +++++ model_init.py | 2 ++ sh/test_compat.sh | 1 + 4 files changed, 9 insertions(+) diff --git a/doc/model_compatibility.md b/doc/model_compatibility.md index 10508a7b..c01cc818 100644 --- a/doc/model_compatibility.md +++ b/doc/model_compatibility.md @@ -19,6 +19,7 @@ As of **2023-05-24**, the following GPTQ models on HuggingFace all appear to be - TheBloke/Manticore-13B-GPTQ - TheBloke/medalpaca-13B-GPTQ-4bit - TheBloke/medalpaca-13B-GPTQ-4bit (compat version) +- TheBloke/tulu-30B-GPTQ - TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g - TheBloke/VicUnlocked-30B-LoRA-GPTQ - TheBloke/wizard-mega-13B-GPTQ diff --git a/model.py b/model.py index 468f380c..0f31ce77 100644 --- a/model.py +++ b/model.py @@ -53,6 +53,7 @@ def __init__(self, model_config_path): self.groupsize = None # Autodetected self.act_order = False # Autodetected + self.empty_g_idx = False # Autodetected # Required settings @@ -115,6 +116,10 @@ def __init__(self, config, in_features, out_features, has_bias, tensors, key): self.g_idx = tensors[key + ".g_idx"].cpu() if key + ".g_idx" in tensors else None self.bias = tensors[key + ".bias"] if has_bias else None + if (self.g_idx == 0).all(): + self.config.empty_g_idx = True + self.g_idx = None + self.device = self.qweight.device self.device_index = self.device.index diff --git a/model_init.py b/model_init.py index ebc97f73..21e6e4b5 100644 --- a/model_init.py +++ b/model_init.py @@ -119,3 +119,5 @@ def print_stats(model): print(f" -- Groupsize (inferred): {model.config.groupsize if model.config.groupsize is not None else 'None'}") print(f" -- Act-order (inferred): {'yes' if model.config.act_order else 'no'}") + if model.config.empty_g_idx: + print(f" !! Model has empty group index (discarded)") diff --git a/sh/test_compat.sh b/sh/test_compat.sh index 5618e7bd..c3a893ed 100755 --- a/sh/test_compat.sh +++ b/sh/test_compat.sh @@ -13,6 +13,7 @@ echo "---------" && python test_benchmark_inference.py -v -l 1024 -d /mnt/str/mo echo "---------" && python test_benchmark_inference.py -v -l 1024 -d /mnt/str/models/_test_models/TheBloke_Manticore-13B-GPTQ echo "---------" && python test_benchmark_inference.py -v -l 1024 -d /mnt/str/models/_test_models/TheBloke_medalpaca-13B-GPTQ-4bit echo "---------" && python test_benchmark_inference.py -v -l 1024 -d /mnt/str/models/_test_models/TheBloke_medalpaca-13B-GPTQ-4bit_compat +echo "---------" && python test_benchmark_inference.py -v -l 1024 -d /mnt/str/models/_test_models/TheBloke_tulu-30B-GPTQ echo "---------" && python test_benchmark_inference.py -v -l 1024 -d /mnt/str/models/_test_models/TheBloke_vicuna-13B-1.1-GPTQ-4bit-128g echo "---------" && python test_benchmark_inference.py -v -l 1024 -d /mnt/str/models/_test_models/TheBloke_VicUnlocked-30B-LoRA-GPTQ echo "---------" && python test_benchmark_inference.py -v -l 1024 -d /mnt/str/models/_test_models/TheBloke_wizard-mega-13B-GPTQ From bb855c9d2366eb48bd3cae3437897f5a9c61b709 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 11 Jun 2023 09:49:52 +0200 Subject: [PATCH 13/15] Fix for previous fix --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 0f31ce77..55f1cb45 100644 --- a/model.py +++ b/model.py @@ -116,7 +116,7 @@ def __init__(self, config, in_features, out_features, has_bias, tensors, key): self.g_idx = tensors[key + ".g_idx"].cpu() if key + ".g_idx" in tensors else None self.bias = tensors[key + ".bias"] if has_bias else None - if (self.g_idx == 0).all(): + if self.g_idx is not None and (self.g_idx == 0).all(): self.config.empty_g_idx = True self.g_idx = None From b65d774c1bd4fbf23405e9f97e2e58da8109543b Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 11 Jun 2023 18:15:54 +0200 Subject: [PATCH 14/15] Concurrency (experimental) --- README.md | 12 ++-- doc/TODO.md | 3 +- exllama_ext/cuda_buffers.cu | 8 +++ exllama_ext/cuda_buffers.cuh | 7 +++ exllama_ext/cuda_func/q4_attn.cu | 89 +++++++++++++++++++---------- exllama_ext/cuda_func/q4_matmul.cu | 7 ++- exllama_ext/cuda_func/q4_matmul.cuh | 3 +- exllama_ext/cuda_func/q4_mlp.cu | 27 +++++++-- exllama_ext/cuda_func/rope.cu | 5 +- exllama_ext/cuda_func/rope.cuh | 3 +- exllama_ext/exllama_ext.cpp | 4 +- exllama_ext/tuning.h | 1 + model.py | 4 +- model_init.py | 3 + sh/test_benchmark_perf.sh | 4 +- 15 files changed, 131 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index d4da8968..0bea2dad 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Disclaimer: The project is coming along, but it's still a work in progress! ## Hardware requirements -I am developing on an RTX 4090 and an RTX 3090-Ti. Both cards support the CUDA kernel, but there might be +I am developing on an RTX 4090 and an RTX 3090-Ti. Both cards support the CUDA kernels, but there might be incompatibilities with older cards. ## Dependencies @@ -118,8 +118,8 @@ docker run --gpus all -p 5000:5000 -v :/app/model/ --rm -it ### New implementation | Model | Size | grpsz | act | Seq. len. | VRAM | Prompt | Best | Worst | Ppl | |----------|------|-------|-----------------|----------------------|-----------|------------|---------|---------|------| -| Llama | 7B | 128 | no | 2,048 t | 5,194 MB | 13,918 t/s | 168 t/s | 139 t/s | 6.45 | -| Llama | 13B | 128 | no | 2,048 t | 9,127 MB | 7,507 t/s | 99 t/s | 84 t/s | 5.60 | +| Llama | 7B | 128 | no | 2,048 t | 5,194 MB | 13,918 t/s | 173 t/s | 140 t/s | 6.45 | +| Llama | 13B | 128 | no | 2,048 t | 9,127 MB | 7,507 t/s | 102 t/s | 86 t/s | 5.60 | | Llama | 30B | 128 | no | 2,048 t | 20,795 MB | 2,959 t/s | 47 t/s | 40 t/s | 4.60 | | Llama | 30B | 128 | yes | 2,048 t | 20,795 MB | 2,784 t/s | 45 t/s | 37 t/s | 4.55 | | Llama | 30B | 32 | yes | 1,550 t 1 | 21,486 MB | 2,636 t/s | 41 t/s | 37 t/s | 4.52 | @@ -207,4 +207,8 @@ on Windows. from 69% actual CPU utilization to 37%. This should do a lot to address the bottleneck on CPUs with lower single-threaded performance. -**2024-06-10**: Docker support now! And some minor optimizations. Cleaned up the project a bit. \ No newline at end of file +**2024-06-10**: Docker support now! And some minor optimizations. Cleaned up the project a bit. + +**2024-06-11**: Added some concurrency a couple of places. It's only beneficial on the 4090, on small models where the +cores are somewhat underutilized and the L2 cache can keep up. For the 3090 it's detrimental to performance, so it's +disabled by default. YMMV. Use `-cs` to try it out. \ No newline at end of file diff --git a/doc/TODO.md b/doc/TODO.md index 23459a74..89a339ce 100644 --- a/doc/TODO.md +++ b/doc/TODO.md @@ -46,6 +46,7 @@ - [x] Examine if scaled_dot_product_attention is actually the best attention method for single tokens (it's not) - [ ] Implement attention in CUDA - [x] Rewrite at least the quantized matmul kernel. Should be a bunch of special cases to consider +- [x] Experiment with concurrent streams where possible (fused MLP and QKV proj.) ## Generation @@ -54,7 +55,7 @@ - [ ] Multi-token censoring/de-censoring - [ ] Multi-token repetition penalties - [ ] (Multi) LoRA support -- [ ] Guided generation (chat with multiple bots at once, etc.) +- [x] Guided generation (chat with multiple bots at once, etc.) - [ ] Multiple chat modes with prompt templates (instruct, etc.) ## Interface diff --git a/exllama_ext/cuda_buffers.cu b/exllama_ext/cuda_buffers.cu index a6e4ea9d..7e7bd75d 100644 --- a/exllama_ext/cuda_buffers.cu +++ b/exllama_ext/cuda_buffers.cu @@ -23,6 +23,14 @@ CudaBuffers::CudaBuffers max_zeros_float(_max_zeros_float), current_zeros_float(0) { + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); } CudaBuffers::~CudaBuffers() diff --git a/exllama_ext/cuda_buffers.cuh b/exllama_ext/cuda_buffers.cuh index ce948a12..86e8af82 100644 --- a/exllama_ext/cuda_buffers.cuh +++ b/exllama_ext/cuda_buffers.cuh @@ -25,6 +25,13 @@ public: int current_zeros_float; int max_zeros_float; + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + CudaBuffers ( int _device, diff --git a/exllama_ext/cuda_func/q4_attn.cu b/exllama_ext/cuda_func/q4_attn.cu index 5c17d37e..2f459904 100644 --- a/exllama_ext/cuda_func/q4_attn.cu +++ b/exllama_ext/cuda_func/q4_attn.cu @@ -97,47 +97,78 @@ void q4_attn_cuda const int device_index ) { + // Cache update grid + + dim3 threads(THREADS_X, THREADS_Y, THREADS_Z); + + dim3 blocks + ( + head_dim / THREADS_X / BLOCKSIZE_X, + q_len, + num_heads / THREADS_Z / BLOCKSIZE_Z + ); + + int _rows = q_len * num_heads; + CudaBuffers* buffers = get_buffers(device_index); + // Layernorm + half* temp_x = buffers->temp_state + q_len * dim; // TODO: .. rms_norm_cuda(tuningParams, x, rms_norm_weight, temp_x, epsilon, q_len, dim, device_index); - // Project q, k, v + if (!tuningParams->concurrent_streams) + { + // Project q, k, v - q4_matmul_cuda(tuningParams, temp_x, q_len, q_proj, query_states); - q4_matmul_cuda(tuningParams, temp_x, q_len, k_proj, key_states); - q4_matmul_cuda(tuningParams, temp_x, q_len, v_proj, value_states); + q4_matmul_cuda(tuningParams, temp_x, q_len, q_proj, query_states); + q4_matmul_cuda(tuningParams, temp_x, q_len, k_proj, key_states); + q4_matmul_cuda(tuningParams, temp_x, q_len, v_proj, value_states); - // Positional embeddings - // TODO: these can be fused to reduce launch overhead by about 1500 ns and kernel time by a little, too + // Positional embeddings q, k - int _rows = q_len * num_heads; - rope_cuda(tuningParams, query_states, sin, cos, _rows, head_dim, num_heads, past_len); - rope_cuda(tuningParams, key_states, sin, cos, _rows, head_dim, num_heads, past_len); + rope_cuda(tuningParams, query_states, sin, cos, _rows, head_dim, num_heads, past_len); + rope_cuda(tuningParams, key_states, sin, cos, _rows, head_dim, num_heads, past_len); - // Update cache tensors with projected k, v + // Update cache tensors with projected k, v - dim3 threads(THREADS_X, THREADS_Y, THREADS_Z); + update_cache_kernel<<>>(key_states, value_states, key_cache, value_cache, head_dim, num_heads, q_len, max_seq_len, past_len); + } + else + { + // Project q, k, v, add positional embeddings to q, k, update cache tensors with projected k, v - dim3 blocks - ( - head_dim / THREADS_X / BLOCKSIZE_X, - q_len, - num_heads / THREADS_Z / BLOCKSIZE_Z - ); + cudaStream_t str_1 = buffers->alt_stream_1; + cudaStream_t str_2 = buffers->alt_stream_2; + cudaStream_t str_3 = buffers->alt_stream_3; + cudaEvent_t sync_1 = buffers->alt_stream_1_done; + cudaEvent_t sync_2 = buffers->alt_stream_2_done; + cudaEvent_t sync_3 = buffers->alt_stream_3_done; - update_cache_kernel<<>> - ( - key_states, - value_states, - key_cache, - value_cache, - head_dim, - num_heads, - q_len, - max_seq_len, - past_len - ); + // str_1: project q, positions q, sync + + q4_matmul_cuda(tuningParams, temp_x, q_len, q_proj, query_states, false, str_1); + rope_cuda(tuningParams, query_states, sin, cos, _rows, head_dim, num_heads, past_len, str_1); + cudaEventRecord(sync_1, str_1); + + // str_2: project k, positions k, sync + + q4_matmul_cuda(tuningParams, temp_x, q_len, k_proj, key_states, false, str_2); + rope_cuda(tuningParams, key_states, sin, cos, _rows, head_dim, num_heads, past_len, str_2); + cudaEventRecord(sync_2, str_2); + + // str_3: project v, wait for str_2, copy (k,v) to cache, sync + + q4_matmul_cuda(tuningParams, temp_x, q_len, v_proj, value_states, false, buffers->alt_stream_3); + cudaStreamWaitEvent(str_3, sync_2, 0); + update_cache_kernel<<>>(key_states, value_states, key_cache, value_cache, head_dim, num_heads, q_len, max_seq_len, past_len); + cudaEventRecord(sync_3, str_3); + + // default: wait for str_1 and str_3 + + cudaStreamWaitEvent(NULL, sync_1, 0); + cudaStreamWaitEvent(NULL, sync_3, 0); + } } void q4_attn_2_cuda diff --git a/exllama_ext/cuda_func/q4_matmul.cu b/exllama_ext/cuda_func/q4_matmul.cu index 7d793928..875f9530 100644 --- a/exllama_ext/cuda_func/q4_matmul.cu +++ b/exllama_ext/cuda_func/q4_matmul.cu @@ -172,7 +172,8 @@ void q4_matmul_cuda const int x_height, const Q4Matrix* w, half* out, - bool no_zero + bool no_zero, + cudaStream_t alt_stream ) { int height = x_height; @@ -183,7 +184,7 @@ void q4_matmul_cuda uint32_t* x_map = w->cuda_x_map; const half* x_mapped = x; - if (x_map && !tuningParams->matmul_fused_remap) + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) { CudaBuffers* buffers = get_buffers(w->device); column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); @@ -212,7 +213,7 @@ void q4_matmul_cuda ); fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); - kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); } void q4_matmul_recons_cuda diff --git a/exllama_ext/cuda_func/q4_matmul.cuh b/exllama_ext/cuda_func/q4_matmul.cuh index 348ee92a..7a494dd8 100644 --- a/exllama_ext/cuda_func/q4_matmul.cuh +++ b/exllama_ext/cuda_func/q4_matmul.cuh @@ -23,7 +23,8 @@ void q4_matmul_cuda const int x_height, const Q4Matrix* w, half* out, - bool no_zero = false + bool no_zero = false, + cudaStream_t alt_stream = NULL ); void q4_matmul_recons_cuda diff --git a/exllama_ext/cuda_func/q4_mlp.cu b/exllama_ext/cuda_func/q4_mlp.cu index 77fa09a4..8ac47b32 100644 --- a/exllama_ext/cuda_func/q4_mlp.cu +++ b/exllama_ext/cuda_func/q4_mlp.cu @@ -117,11 +117,30 @@ void q4_mlp_cuda half* temp_x = buffers->temp_state + height * dim; // TOOD: .. rms_norm_cuda(tuningParams, x, rms_norm_weight, temp_x, epsilon, height, dim, device_index); - // temp_mlp[0] = temp_x @ gate - // temp_mlp[1] = temp_x @ up + if (!tuningParams->concurrent_streams) + { + // temp_mlp[0] = temp_x @ gate + // temp_mlp[1] = temp_x @ up + + q4_matmul_cuda(tuningParams, temp_x, height, gate, buffers->temp_mlp); + q4_matmul_cuda(tuningParams, temp_x, height, up, buffers->temp_mlp + height * up->width); + } + else + { + cudaStream_t str_1 = buffers->alt_stream_1; + cudaStream_t str_2 = buffers->alt_stream_2; + cudaEvent_t sync_1 = buffers->alt_stream_1_done; + cudaEvent_t sync_2 = buffers->alt_stream_2_done; + + q4_matmul_cuda(tuningParams, temp_x, height, gate, buffers->temp_mlp, false, str_1); + cudaEventRecord(sync_1, str_1); - q4_matmul_cuda(tuningParams, temp_x, height, gate, buffers->temp_mlp); - q4_matmul_cuda(tuningParams, temp_x, height, up, buffers->temp_mlp + height * up->width); + q4_matmul_cuda(tuningParams, temp_x, height, up, buffers->temp_mlp + height * up->width, false, str_2); + cudaEventRecord(sync_2, str_2); + + cudaStreamWaitEvent(NULL, sync_1, 0); + cudaStreamWaitEvent(NULL, sync_2, 0); + } // temp_mlp[0] = silu(temp_mlp[0]) * temp_mlp[1] diff --git a/exllama_ext/cuda_func/rope.cu b/exllama_ext/cuda_func/rope.cu index d8fdeaec..5179553c 100644 --- a/exllama_ext/cuda_func/rope.cu +++ b/exllama_ext/cuda_func/rope.cu @@ -103,7 +103,8 @@ void rope_cuda const int rows, const int head_dim, const int num_heads, - const int past_len + const int past_len, + cudaStream_t alt_stream ) { dim3 threads(THREADS_X, THREADS_Y, 1); @@ -116,5 +117,5 @@ void rope_cuda ); fp_rope_cuda_kernel kernel = rope_cuda_kernel_pick(tuningParams); - kernel<<>>(x, sin, cos, rows, head_dim, num_heads, past_len); + kernel<<>>(x, sin, cos, rows, head_dim, num_heads, past_len); } diff --git a/exllama_ext/cuda_func/rope.cuh b/exllama_ext/cuda_func/rope.cuh index 1c9fcae2..64d5b88b 100644 --- a/exllama_ext/cuda_func/rope.cuh +++ b/exllama_ext/cuda_func/rope.cuh @@ -16,7 +16,8 @@ void rope_cuda const int rows, const int head_dim, const int num_heads, - const int past_len + const int past_len, + cudaStream_t alt_stream = NULL ); #endif \ No newline at end of file diff --git a/exllama_ext/exllama_ext.cpp b/exllama_ext/exllama_ext.cpp index c03d4d92..394b4015 100644 --- a/exllama_ext/exllama_ext.cpp +++ b/exllama_ext/exllama_ext.cpp @@ -94,7 +94,8 @@ void set_tuning_params bool rmsnorm_no_half2, bool rope_no_half2, bool matmul_no_half2, - bool silu_no_half2 + bool silu_no_half2, + bool concurrent_streams ) { tuningParams.matmul_recons_thd = matmul_recons_thd; @@ -106,6 +107,7 @@ void set_tuning_params tuningParams.rope_no_half2 = rope_no_half2; tuningParams.matmul_no_half2 = matmul_no_half2; tuningParams.silu_no_half2 = silu_no_half2; + tuningParams.concurrent_streams = concurrent_streams; } // Prepare buffers for forward pass diff --git a/exllama_ext/tuning.h b/exllama_ext/tuning.h index 1f660b0e..50195e06 100644 --- a/exllama_ext/tuning.h +++ b/exllama_ext/tuning.h @@ -12,6 +12,7 @@ struct ExLlamaTuning bool rope_no_half2; bool matmul_no_half2; bool silu_no_half2; + bool concurrent_streams; }; #endif \ No newline at end of file diff --git a/model.py b/model.py index 55f1cb45..64e1c43a 100644 --- a/model.py +++ b/model.py @@ -77,6 +77,7 @@ def __init__(self, model_config_path): self.rope_no_half2 = False self.matmul_no_half2 = False self.silu_no_half2 = False + self.concurrent_streams = False # Copy tuning params to C++ extension @@ -89,7 +90,8 @@ def set_tuning_params(self): self.rmsnorm_no_half2, self.rope_no_half2, self.matmul_no_half2, - self.silu_no_half2) + self.silu_no_half2, + self.concurrent_streams) # Parse and set list of GPU VRAM allocations diff --git a/model_init.py b/model_init.py index 21e6e4b5..ce2048e2 100644 --- a/model_init.py +++ b/model_init.py @@ -26,6 +26,7 @@ def add_args(parser): parser.add_argument("-snh2", "--silu_no_half2", action = "store_true", help = "Don't use half2 in SiLU kernel") parser.add_argument("-nh2", "--no_half2", action = "store_true", help = "(All of the above) disable half2 in all kernela") parser.add_argument("-fh2", "--force_half2", action = "store_true", help = "Force enable half2 even if unsupported") + parser.add_argument("-cs", "--concurrent_streams", action = "store_true", help = "Use concurrent CUDA streams") def post_parse(args): @@ -84,6 +85,7 @@ def print_options(args, extra_options = None): if args.rope_no_half2: print(f" -- --rope_no_half2") if args.matmul_no_half2: print(f" -- --matmul_no_half2") if args.silu_no_half2: print(f" -- --silu_no_half2") + if args.concurrent_streams: print(f" -- ----concurrent_streams") print(f" -- Options: {print_opts}") @@ -109,6 +111,7 @@ def make_config(args): config.rope_no_half2 = args.rope_no_half2 config.matmul_no_half2 = args.matmul_no_half2 config.silu_no_half2 = args.silu_no_half2 + config.concurrent_streams = args.concurrent_streams return config diff --git a/sh/test_benchmark_perf.sh b/sh/test_benchmark_perf.sh index 39a484c3..5a6e8849 100755 --- a/sh/test_benchmark_perf.sh +++ b/sh/test_benchmark_perf.sh @@ -1,8 +1,8 @@ echo "-------------------------------------------------------------------------------------------------------------" -python test_benchmark_inference.py -p -d /mnt/str/models/llama-7b-4bit-128g +python test_benchmark_inference.py -p -d /mnt/str/models/llama-7b-4bit-128g -cs echo "-------------------------------------------------------------------------------------------------------------" -python test_benchmark_inference.py -p -d /mnt/str/models/llama-13b-4bit-128g +python test_benchmark_inference.py -p -d /mnt/str/models/llama-13b-4bit-128g -cs echo "-------------------------------------------------------------------------------------------------------------" python test_benchmark_inference.py -p -d /mnt/str/models/llama-30b-4bit-128g echo "-------------------------------------------------------------------------------------------------------------" From 896da5d3b59252cba40aea6818621b2fbc77fbf1 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 11 Jun 2023 19:44:41 +0200 Subject: [PATCH 15/15] Benchmark >2048 token sequence prompts in batches --- model.py | 4 ++++ test_benchmark_inference.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/model.py b/model.py index 64e1c43a..be15f5b4 100644 --- a/model.py +++ b/model.py @@ -705,6 +705,10 @@ def __init__(self, config): temp_zeros_float, temp_dq) + # Clear the cache + + torch.cuda.empty_cache() + def forward(self, input_ids, cache, last_id_only = True, preprocess_only = False): diff --git a/test_benchmark_inference.py b/test_benchmark_inference.py index 177ed21c..20b0ac66 100644 --- a/test_benchmark_inference.py +++ b/test_benchmark_inference.py @@ -33,7 +33,14 @@ def begin(): def next_logits(input_ids, last_id_only = True): global model, cache - return model.forward(input_ids, cache, last_id_only) + n_logits = None + a = 0 + while a < input_ids.shape[-1]: + b = min(input_ids.shape[-1], a + 2048) + n_logits = model.forward(input_ids[:, a:b], cache, last_id_only) + a = b + + return n_logits def tokenize(text): @@ -121,7 +128,7 @@ def mem(name, total = False): # Warming up apparently makes a huge difference - for i in range(1, 4): + for i in range(1, 3): print(f" -- Warmup pass {i}...") begin() logits = timer("Warmup", lambda: next_logits(ids))