diff --git a/.gitignore b/.gitignore
index 22f5a6cd6..aca1983d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,9 +22,11 @@ CMakeFiles/
 bitsandbytes.dir/
 Debug/
 Release/
+cmake-build-*/
 
 # IDE local files
 .vs/
+.idea/
 
 # Distribution / packaging
 .Python
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 03e3add4a..133f9e066 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -335,17 +335,17 @@ def forward(
             # Zero out the outliers in the int8 inputs
             CA[:, state.idx] = 0
-            CAt[:, state.idx] = 0
+            # CAt[:, state.idx] = 0
 
             # Extract the input outliers in original precision
             subA = A[:, state.idx]
 
             # Extract the corresponding weights
             if state.has_fp16_weights:
-                state.subB = B[:, state.idx].t().contiguous()
+                state.subB = B[:, state.idx].t()  # .contiguous()
             else:
-                outliers = state.CB[:, state.idx].clone()
-                state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
+                outliers = state.CB[:, state.idx]  # .clone()
+                state.subB = (7.874016e-3 * outliers * state.SCB.view(-1, 1)).t().to(A.dtype)
         else:
             subA = None
@@ -372,14 +372,14 @@ def forward(
             ctx.tensors = (CAt, subA, A)
             ctx.tensor_states = (SCAt, state.idx)
         else:
-            ctx.tensors = [None, None, A]
+            ctx.tensors = [None, None, None]  # A]
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
         output_shape = (*input_shape[:-1], state.CB.shape[0])
         if len(input_shape) == 3:
-            return output.view(output_shape).clone()
+            return output.reshape(output_shape)  # .clone()
         else:
             return output
@@ -417,10 +417,10 @@ def backward(ctx, grad_output):
 
         if req_gradA:
             # grad_output @ B.T
-            if state.CBt is not None:
-                gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
-                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
-            elif state.CB is not None:
+            # if state.CBt is not None:
+            #     gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
+            #     grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
+            if state.CB is not None:
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             else:
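Aside (not part of the patch): 7.874016e-3 above is 1/127 rounded to single precision, so the new expression is the old `outliers * state.SCB.view(-1, 1) / 127.0` with the division folded into a multiply and the `.clone()`/`.contiguous()` copies dropped. A quick self-contained check; the tensors here are made-up stand-ins for `state.CB[:, state.idx]` and `state.SCB`:

    import torch

    CB_outliers = torch.randint(-127, 128, (64, 3), dtype=torch.int8)  # stand-in for state.CB[:, state.idx]
    SCB = torch.rand(64) * 10.0                                        # stand-in for the per-row weight scales

    old = (CB_outliers.float() * SCB.view(-1, 1) / 127.0).t()
    new = (7.874016e-3 * CB_outliers.float() * SCB.view(-1, 1)).t()
    assert torch.allclose(old, new)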
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 8d7226b2c..6c8ffe3d1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -2400,7 +2400,6 @@ def get_colrow_absmax(
     outlier_mask = None
 
     if row_stats is None or col_stats is None:
-        # view as 2D
         absA = A.abs().view(-1, A.shape[-1])
 
         if threshold > 0.0:
@@ -2408,13 +2407,10 @@ def get_colrow_absmax(
             outlier_mask = absA >= threshold
             absA.masked_fill_(outlier_mask, 0.0)
 
-            # For parity with tests build nnz_block_ptr.
-            nnz_block_ptr = torch.zeros(absA.shape[0] + 1, dtype=torch.int64, device=A.device)
-            nnz_block_ptr[1:] = outlier_mask.sum(1).cumsum(0)
-
     if row_stats is None:
         # shape [rows]; unsqueeze(-1) gives [rows,1]
-        row_stats = absA.amax(dim=1, keepdim=False).float()
+        row_stats = get_row_absmax(A, threshold)
+
     if col_stats is None:
         # shape [cols]; unsqueeze(0) gives [1,cols]
         col_stats = absA.amax(dim=0, keepdim=False).float()
@@ -2422,42 +2418,20 @@ def get_colrow_absmax(
     return row_stats, col_stats, outlier_mask
 
 
-def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
+def get_row_absmax(A, threshold=0.0):
     assert A.dtype == torch.float16
-    device = A.device
+    rows = prod(A.shape[:-1])
     cols = A.shape[-1]
-    if len(A.shape) == 3:
-        rows = A.shape[0] * A.shape[1]
-    else:
-        rows = A.shape[0]
-    col_tiles = (cols + 255) // 256
-    tiled_rows = ((rows + 15) // 16) * 16
-    if row_stats is None:
-        row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
-    if col_stats is None:
-        col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
-
-    if nnz_block_ptr is None and threshold > 0.0:
-        nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device)
-
-    ptrA = get_ptr(A)
-    ptrRowStats = get_ptr(row_stats)
-    ptrColStats = get_ptr(col_stats)
-    ptrNnzrows = get_ptr(nnz_block_ptr)
-    rows = ct.c_int32(rows)
-    cols = ct.c_int32(cols)
+    row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device)
 
-    prev_device = pre_call(A.device)
-    is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
-    lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
-    post_call(prev_device)
+    is_on_gpu([A, row_stats])
 
-    if threshold > 0.0:
-        nnz_block_ptr.cumsum_(0)
+    with torch.cuda.device_of(A):
+        lib.cget_row_stats(get_ptr(A), get_ptr(row_stats), ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
 
-    return row_stats, col_stats, nnz_block_ptr
+    return row_stats
 
 
 class COOSparseTensor:
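For reference (not part of the diff): a pure-PyTorch sketch of the statistic the new `cget_row_stats` path is expected to produce, usable as a test oracle. `get_row_absmax_ref` is a hypothetical name and not part of the library:

    import torch

    def get_row_absmax_ref(A: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
        # Row-wise absmax over the flattened leading dimensions.
        absA = A.abs().view(-1, A.shape[-1]).float()
        if threshold > 0.0:
            # Under sparse decomposition, entries with |a| >= threshold are outliers;
            # they are carried separately in fp16 and excluded from the statistic.
            absA = absA.masked_fill(absA >= threshold, 0.0)
        return absA.amax(dim=1)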
@@ -2541,127 +2515,48 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
     return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
 
 
-# @torch.compile
+def extract_outliers_new(A: torch.Tensor, threshold: float):
+    outlier_mask = A.abs() >= threshold
+    return A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()
+
+
 def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
-    # TODO: Optimize/write CUDA kernel for this
+    assert A.dtype == torch.half
 
-    if row_stats is None or col_stats is None:
-        row_stats, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
+    rows = prod(A.shape[:-1])
+    cols = A.shape[-1]
+
+    row_stats = torch.empty((rows,), device=A.device, dtype=torch.float32)
+
+    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
 
     if threshold > 0.0:
         # Extract outliers to COO tensor:
         # 1. Zero out all of the non-outliers, convert to COO.
         # 2. Zero out the outliers in the dense tensor.
         # TODO we could improve perf of this
-        # is_outlier = A.abs() >= threshold
-        coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()
-        A = A.masked_fill(outlier_mask, 0.0)
+        # outlier_mask = A.abs() >= threshold
+        # coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()
+        # A = A.masked_fill(outlier_mask, 0.0)
+        coo_tensor = extract_outliers_new(A, threshold)
     else:
         coo_tensor = None
 
-    # Quantize
-    scaled_A = A.mul(C)
-    # quant_row = torch.round(A * (C / row_stats.unsqueeze(-1))).to(torch.int8)
-    # quant_col = torch.round(A * (C / col_stats.unsqueeze(0))).to(torch.int8)
-    quant_row = torch.round(scaled_A / row_stats.unsqueeze(-1)).to(torch.int8)
-    quant_col = torch.round(scaled_A / col_stats.unsqueeze(0)).to(torch.int8)
-
-    if out_row is not None:
-        quant_row = out_row.copy_(quant_row)
-    if out_col is not None:
-        quant_col = out_col.copy_(quant_col)
-
-    return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), coo_tensor
+    is_on_gpu([A, row_stats])
 
-
-def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
-    device = A.device
-    assert A.dtype == torch.half
-    assert device.type == "cuda"
-    prev_device = pre_call(A.device)
-
-    cols = A.shape[-1]
-    if len(A.shape) == 3:
-        rows = A.shape[0] * A.shape[1]
-    else:
-        rows = A.shape[0]
-
-    if row_stats is None or col_stats is None:
-        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
-
-    if out_col is None:
-        out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
-    if out_row is None:
-        out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
-
-    coo_tensor = None
-    ptrA = get_ptr(A)
-    ptrColStats = get_ptr(col_stats)
-    ptrRowStats = get_ptr(row_stats)
-    ptrOutCol = get_ptr(out_col)
-    ptrOutRow = get_ptr(out_row)
-
-    is_on_gpu([A, col_stats, row_stats, out_col, out_row])
-    if threshold > 0.0:
-        nnz = nnz_row_ptr[-1].item()
-        if nnz > 0:
-            coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
-            ptrRowIdx = get_ptr(coo_tensor.rowidx)
-            ptrColIdx = get_ptr(coo_tensor.colidx)
-            ptrVal = get_ptr(coo_tensor.values)
-            ptrRowPtr = get_ptr(nnz_row_ptr)
-
-            lib.cdouble_rowcol_quant(
-                ptrA,
-                ptrRowStats,
-                ptrColStats,
-                ptrOutCol,
-                ptrOutRow,
-                ptrRowIdx,
-                ptrColIdx,
-                ptrVal,
-                ptrRowPtr,
-                ct.c_float(threshold),
-                ct.c_int32(rows),
-                ct.c_int32(cols),
-            )
-            val, idx = torch.sort(coo_tensor.rowidx)
-            coo_tensor.rowidx = val
-            coo_tensor.colidx = coo_tensor.colidx[idx]
-            coo_tensor.values = coo_tensor.values[idx]
-        else:
-            lib.cdouble_rowcol_quant(
-                ptrA,
-                ptrRowStats,
-                ptrColStats,
-                ptrOutCol,
-                ptrOutRow,
-                None,
-                None,
-                None,
-                None,
-                ct.c_float(0.0),
-                ct.c_int32(rows),
-                ct.c_int32(cols),
-            )
-    else:
-        lib.cdouble_rowcol_quant(
-            ptrA,
-            ptrRowStats,
-            ptrColStats,
-            ptrOutCol,
-            ptrOutRow,
-            None,
-            None,
-            None,
-            None,
+    with torch.cuda.device_of(A):
+        lib.cint8_vector_quant(
+            get_ptr(A),
+            get_ptr(out_row),
+            get_ptr(row_stats),
             ct.c_float(threshold),
             ct.c_int32(rows),
             ct.c_int32(cols),
         )
-    post_call(prev_device)
-    return out_row, out_col, row_stats, col_stats, coo_tensor
+
+    # TODO: col_stats
+
+    return out_row, None, row_stats, None, coo_tensor  # coo_tensor #col_stats.flatten().float(), coo_tensor
 
 
 def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
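Usage note (not part of the diff): the rewritten `double_quant` now returns `(out_row, None, row_stats, None, coo_tensor)` — the column-wise quantized tensor and `col_stats` are no longer produced (see the `# TODO: col_stats` above), consistent with the `CAt[...] = 0` statements being commented out elsewhere in this diff. A sketch of the new contract, assuming a CUDA device and a build of this branch:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(16, 64, dtype=torch.float16, device="cuda")

    # Row-wise int8 quantization with sparse decomposition of outliers at |a| >= 6.0.
    CA, CAt, SCA, SCAt, coo = F.double_quant(A, threshold=6.0)

    assert CA.dtype == torch.int8 and CA.shape == A.shape
    assert SCA.shape == (16,)                  # one absmax per row
    assert CAt is None and SCAt is None        # column-wise outputs are gone
    if coo is not None:                        # sparse COO tensor holding the fp16 outliers
        outlier_cols = torch.unique(coo._indices()[1])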
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 1e5a334ee..fee15b000 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -1008,9 +1008,9 @@ def forward(self, x: torch.Tensor):
 
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
 
-        if not self.state.has_fp16_weights:
-            if self.state.CB is not None:
-                self.weight.data = self.state.CB
+        if not self.state.has_fp16_weights and self.state.CB is not None:
+            self.weight.data = self.state.CB
+
         return out
diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py
index 3e807d6e1..dd4de5df6 100644
--- a/bitsandbytes/research/autograd/_functions.py
+++ b/bitsandbytes/research/autograd/_functions.py
@@ -186,7 +186,9 @@ class SwitchBackBnb(torch.autograd.Function):
     @staticmethod
     # TODO: the B008 on the line below is a likely bug; the current implementation will
     # have each SwitchBackBnb instance share a single MatmulLtState instance!!!
-    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B008
+    def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
+        state = state or MatmulLtState()
+
         # default to pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
@@ -222,7 +224,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B008
                 # idx = torch.unique(coo_tensorA.colidx).long()
                 idx = torch.unique(coo_tensorA._indices()[1]).long()
                 CA[:, idx] = 0
-                CAt[:, idx] = 0
+                # CAt[:, idx] = 0
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
@@ -249,7 +251,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B008
                     state.CBt,
                     state.SCB,
                     state.SCBt,
-                    coo_tensorB,
+                    _,
                 ) = F.double_quant(B.to(torch.float16))
                 state.SB = (state.CB.shape, "row")
             else:
@@ -257,21 +259,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):  # noqa: B008
 
         if coo_tensorA is not None and not state.has_fp16_weights:
             # extract outliers
+            state.idx = torch.unique(coo_tensorA._indices()[1]).long()
 
-            # outlier_idx = torch.unique(coo_tensorA.colidx)
-            outlier_idx = torch.unique(coo_tensorA._indices()[1]).long()
-            state.idx = outlier_idx
-            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
-            #     # do not use pool for 2nd FFN layer
-            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-            # else:
-            #     state.idx = outlier_idx
             # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
             outliers = state.CB[:, state.idx.long()].clone()
             state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
             CA[:, state.idx.long()] = 0
-            CAt[:, state.idx.long()] = 0
+            # CAt[:, state.idx.long()] = 0
             subA = A[:, state.idx.long()]
 
         shapeB = state.SB[0]
@@ -318,6 +312,7 @@ def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
         req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA, A = ctx.tensors
         SCAt, idx = ctx.tensor_states
@@ -340,11 +335,10 @@ def backward(ctx, grad_output):
             grad_B = torch.matmul(grad_output.t(), A)
 
         if req_gradA:
-            if state.CBt is not None:
-                gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
-                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
-
-            elif state.CB is not None:
+            # if state.CBt is not None:
+            #     gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
+            #     grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
+            if state.CB is not None:
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             else:
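Aside (not part of the patch): the signature change above addresses the pitfall called out in the TODO — a default like `state=MatmulLtState()` is evaluated once, when the function is defined, so every call that relies on the default shares one state object. A minimal generic illustration; `State`, `forward_shared`, and `forward_fresh` are made-up names:

    class State:
        def __init__(self):
            self.idx = None

    def forward_shared(x, state=State()):   # default constructed once, at definition time
        state.idx = x
        return state

    def forward_fresh(x, state=None):
        state = state or State()            # default constructed per call
        state.idx = x
        return state

    assert forward_shared(1) is forward_shared(2)       # same object reused across calls
    assert forward_fresh(1) is not forward_fresh(2)     # each call gets its own state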
diff --git a/csrc/common.cuh b/csrc/common.cuh
new file mode 100644
index 000000000..8c85accfd
--- /dev/null
+++ b/csrc/common.cuh
@@ -0,0 +1,48 @@
+#pragma once
+
+// TODO: Let's make some of these constexpr and put in a namespace.
+
+#define BNB_CC_MAXWELL 500
+#define BNB_CC_MAXWELL2 520
+#define BNB_CC_MAXWELL2_X1 530
+#define BNB_CC_PASCAL 600
+#define BNB_CC_PASCAL_X2 620
+#define BNB_CC_VOLTA 700
+#define BNB_CC_VOLTA_XAVIER 720
+#define BNB_CC_TURING 750
+#define BNB_CC_AMPERE 800
+#define BNB_CC_AMPERE2 860
+#define BNB_CC_AMPERE2_ORIN 870
+#define BNB_CC_ADA 890
+#define BNB_CC_HOPPER 900
+#define BNB_CC_BLACKWELL 1000
+
+#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
+#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
+#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
+#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
+#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
+
+#define BNB_WARP_SIZE 32
+
+// The maximum number of resident threads per SM varies by arch.
+// For A100/H100 and all prior to Turing, it is 2048, which allows
+// for 2 full blocks of 1024 threads per SM.
+// Reference: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+#if __CUDA_ARCH__ == 750
+#define BNB_MAX_THREADS_PER_SM 1024
+#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
+#define BNB_MAX_THREADS_PER_SM 1536
+#else
+#define BNB_MAX_THREADS_PER_SM 2048
+#endif
+
+// Maximum resident warps per SM is always directly related to the number of threads.
+#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))
+
+// Maximum resident blocks per SM may vary.
+#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
+#define BNB_MAX_BLOCKS_PER_SM 16
+#else
+#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)
+#endif
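Worked out for the architectures the macros special-case (not part of the patch): sm_75 resolves to 1024 threads/SM, i.e. 1024/32 = 32 warps and 32/2 = 16 blocks; sm_86/sm_87 resolve to 1536 threads = 48 warps with the block limit pinned at 16; sm_89 to 1536 threads = 48 warps and 48/2 = 24 blocks; everything else (e.g. sm_80, sm_90) to 2048 threads = 64 warps and 32 blocks. The `__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)` annotation on the new kernels below therefore requests a minimum of 1 resident block per SM on sm_75 and sm_86–sm_89, and 2 elsewhere.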
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 34de9d5ca..d0ea0d270 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3,7 +3,8 @@
 // This source code is licensed under the MIT license found in the
 // LICENSE file in the root directory of this source tree.
 
-#include <kernels.cuh>
+#include "kernels.cuh"
+#include "common.cuh"
 #include
 #include
 #include
@@ -2129,6 +2130,106 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
     }
 }
 
+// Inputs:
+//   A        [rows, cols]
+// Outputs:
+//   rowStats [rows]
+//   out      [rows, cols]
+template <typename T, int THREADS, int SPARSE_DECOMP>
+__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
+__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
+    using BlockReduceT = cub::BlockReduce<float, THREADS>;
+
+    // One block per row.
+    // Threads load column values in a striped arrangement.
+    // e.g. t0 reads row[0], row[0+nthreads], ..
+    // and t1 reads row[1], row[1+nthreads], ..
+    // Each thread will determine its local absmax.
+    // We then do a blockwise reduction to determine the row's absmax.
+
+    __shared__ typename BlockReduceT::TempStorage temp_storage;
+    __shared__ float smem_row_absmax;
+
+    const int row_id = blockIdx.x;
+    const T* __restrict__ row_data = A + (row_id * cols);
+
+    // Threads will read the row values in a striped access pattern and find a local absmax.
+    float row_local_absmax = -FLT_MIN;
+    for (int i = threadIdx.x; i < cols; i += THREADS) {
+        const float absval = fabsf(__ldg(&(row_data[i])));
+
+        // For sparse decomposition, values outside of the threshold are not to be
+        // included when calculating the row's absmax.
+        if constexpr (SPARSE_DECOMP) {
+            row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
+        } else {
+            row_local_absmax = fmaxf(row_local_absmax, absval);
+        }
+    }
+
+    // Reduce thread-local absmax across the block.
+    // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
+    const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
+    if (threadIdx.x == 0) {
+        // Save our block's absmax to shared memory for the quantization step.
+        rowStats[row_id] = smem_row_absmax = row_absmax;
+    }
+    __syncthreads();
+
+    // Quantize row-wise.
+    const float scale = __fdividef(127.0f, smem_row_absmax);
+    for (int i = threadIdx.x; i < cols; i += THREADS) {
+        if constexpr (SPARSE_DECOMP) {
+            // For sparse decomposition, we do not want to quantize the outliers.
+            // Instead they're zeroed out.
+            float val = row_data[i];
+            out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0;
+        } else {
+            out[row_id * cols + i] = __float2int_rn(float(row_data[i]) * scale);
+        }
+    }
+}
+
+template <typename T, int THREADS, int SPARSE_DECOMP>
+__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
+__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
+    using BlockReduceT = cub::BlockReduce<float, THREADS>;
+
+    // One block per row.
+    // Threads load column values in a striped arrangement.
+    // e.g. t0 reads row[0], row[0+nthreads], ..
+    // and t1 reads row[1], row[1+nthreads], ..
+    // Each thread will determine its local absmax.
+    // We then do a blockwise reduction to determine the row's absmax.
+
+    __shared__ typename BlockReduceT::TempStorage temp_storage;
+
+    const int row_id = blockIdx.x;
+    const T* __restrict__ row_data = A + (row_id * cols);
+
+    // Threads will read the row values in a striped access pattern and find a local absmax.
+    float row_local_absmax = -FLT_MIN;
+    for (int i = threadIdx.x; i < cols; i += THREADS) {
+        const float absval = fabsf(row_data[i]);
+
+        // For sparse decomposition, values outside of the threshold are not to be
+        // included when calculating the row's absmax.
+        if constexpr (SPARSE_DECOMP) {
+            row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
+        } else {
+            row_local_absmax = fmaxf(row_local_absmax, absval);
+        }
+    }
+
+    // Reduce thread-local absmax across the block.
+    // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
+    const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
+    if (threadIdx.x == 0) {
+        // Save our block's absmax to shared memory for the quantization step.
+        rowStats[row_id] = row_absmax;
+    }
+}
+
 template
 __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
 {
   // 0. reset stats to -FLT_MAX
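Note (outside the patch): with `scale = 127 / row_absmax`, dequantizing as `q * row_absmax / 127` reproduces each in-range element to within half a quantization step, i.e. |a - q*delta| <= delta/2 with delta = row_absmax/127. In the SPARSE_DECOMP variant, elements with |a| >= threshold are excluded from the absmax reduction and written as 0 in the int8 output, so a handful of outliers (carried separately in fp16 through the COO path in functional.py) does not inflate delta for the rest of the row.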
@@ -2283,6 +2384,12 @@ template(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
 template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
 
+template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
+template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
+
+template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
+template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
+
 #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 1e094dbd2..f17bfe4de 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -117,6 +117,9 @@ template __global__ void kdequant_mm_int32_fp
                      half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
 
 template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
+template __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
+
 template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
 
 template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index e2eddc7ab..df5ec01da 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -505,64 +505,6 @@ template int igemmlt(
     return has_error;
 }
 
-template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
-{
-    int has_error = 0;
-    cublasLtMatmulDesc_t matmulDesc = NULL;
-    cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
-    cublasOperation_t opT = CUBLAS_OP_T;
-    cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
-    cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
-    cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
-    cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
-
-    has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
-    has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
-
-    has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
-    if(FORMATB == COL_TURING)
-        has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
-    else
-        has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
-
-    if(DTYPE_OUT == 32)
-    {
-        has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
-        has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
-        has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
-        has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
-        int alpha = 1, beta = 0;
-        has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
-    }
-    else
-    {
-        has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
-        has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
-        has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
-        has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
-        if(!SCALE_ROWS)
-        {
-            float alpha = 1.0f, beta = 0.0f;
-            has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
-        }
-        else
-        {
-            has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
-            has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
-        }
-    }
-
-
-    if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
-    if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
-    if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
-    if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
-    if(has_error == 1)
-        printf("error detected");
-
-    return has_error;
-}
-
 int fill_up_to_nearest_multiple(int value, int multiple)
 {
     return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
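For reference (not part of the change): the unchanged helper above rounds up to the next multiple, e.g. fill_up_to_nearest_multiple(37, 16) = 37 + (16 - 37 % 16) = 37 + 11 = 48, while an exact multiple passes through unchanged, e.g. fill_up_to_nearest_multiple(32, 16) = 32.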
@@ -580,6 +522,15 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
     CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
+void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols) {
+    if (threshold == 0.0) {
+        kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols);
+    } else {
+        kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols);
+    }
+    CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
 #define STATS_THREADS 64
 #define STATS_ITEMS 4
 #define STATS_ROWS 16
@@ -602,6 +553,14 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
 }
 
+void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols) {
+    if (threshold == 0.0)
+        kgetRowStats<<>>(A, rowStats, threshold, rows, cols);
+    else
+        kgetRowStats<<>>(A, rowStats, threshold, rows, cols);
+    CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
 void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
 {
     int threads = 64;
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 9ecb93bf2..558d93008 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -29,7 +29,6 @@
     exit(1); \
   } }
 
-#define THREADS_PER_BLOCKS (512)
 
 #define CHECK_CUSPARSE(value) { \
   cusparseStatus_t _m_cudaStat = value; \
@@ -40,9 +39,6 @@
 } }
 
-#define THREADS_PER_BLOCKS (512)
-
-
 inline void checkCudaStatus(cudaError_t status) {
   if (status != cudaSuccess) {
     printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
@@ -181,8 +177,10 @@ template void trans
 void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
 void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols);
 void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
+void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols);
 void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
+void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
 
 template void transformRowToFormat(char * A, char *out, int rows, int cols);
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index b03b0650c..0400d9b48 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -337,7 +337,12 @@ extern "C"
     { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols); }
     void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
     { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
-
+    void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols) {
+        getRowStats(A, rowStats, threshold, rows, cols);
+    }
+    void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols) {
+        int8VectorQuant(A, out, rowStats, threshold, rows, cols);
+    }
     void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
     { doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); }
diff --git a/tests/test_modules.py b/tests/test_modules.py
index c84ffa42a..51fb21178 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -530,7 +530,7 @@ def test_linear_kbit_fp32_bias(module):
 def test_kbit_backprop(module):
     b = 16
     dim1 = 36
-    dim2 = 56
+    dim2 = 84
     # dim1 = 37
     # dim2 = 83