
remove unused kernels; improved type annotations
matthewdouglas committed Nov 5, 2024
1 parent b1c4adc commit ed922b8
Showing 5 changed files with 4 additions and 319 deletions.
266 changes: 2 additions & 264 deletions csrc/kernels.cu
@@ -2233,160 +2233,6 @@ __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold
}
}

template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
{
// 0. reset stats to -FLT_MAX
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
// 2. compute col max (per thread); store in smem due to register pressure
// 3. compute row max (per block); store in smem to accumulate full global mem transaction
// 4. store data via atomicMax

// each block loads TILE_COLS columns and TILE_ROWS rows
// after reading a tile the row counter increases by TILE_ROWS
// the col counter resets after reading TILE_COLS elements
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
const int base_idx = (base_row*cols) + base_col;
const int items_per_load = ITEMS_PER_THREAD*THREADS;

typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
typedef cub::BlockReduce<int, THREADS> BlockRowSum;
typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;

__shared__ union {
typename BlockExchange::TempStorage exchange;
typename BlockRowReduce::TempStorage rowreduce;
typename BlockRowSum::TempStorage rowsum;
typename LoadT::TempStorage loadt;
} temp_storage;

__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
__shared__ int smem_row_nnz_values[TILE_ROWS];

half local_data[ITEMS_PER_THREAD];
float local_data_fp32[ITEMS_PER_THREAD];
float local_col_absmax_values[ITEMS_PER_THREAD];
int local_row_nnz_count = 0;
float row_absmax = -FLT_MAX;

// 0. reset stats to -FLT_MAX
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
// smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0;
}

#pragma unroll TILE_ROWS
for (int j = 0; j < TILE_ROWS; j++) {
smem_row_nnz_values[j] = 0;
}

#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_col_absmax_values[j] = -FLT_MAX;

__syncthreads();

int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
int i = base_idx;
// we load row after row from the base_position
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
for(int row = 0; row < TILE_ROWS; row++)
{
if(base_row+row >= rows){ break; }
local_row_nnz_count = 0;
i = base_idx + ((row)*cols);
// each thread gets data from the same column
__syncthreads();
LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f));

#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data[j] = fabsf(local_data[j]);


if(SPARSE_DECOMP)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
if((float)local_data[j] >= nnz_threshold)
{
local_row_nnz_count += 1;
local_data[j] = 0.0f;
}
}

// 2. compute col max (per thread); store in smem due to register pressure
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
// take the col max for this row
// we use shared memory because register pressure is too high if we do this locally
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j]));
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));

// 3. compute row max (per block); store in smem to accumulate full global mem transaction

// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data_fp32[j] = local_data[j];

__syncthreads();

row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
if(SPARSE_DECOMP)
{
__syncthreads();
local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count);
}
// we store the data temporarily in shared memory so we
// can execute a full atomic block transaction into global memory later
// we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores
if(threadIdx.x == 0)
{
smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax;
// each blockIdx.x processes 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block
smem_row_nnz_values[row] = local_row_nnz_count;
}

__syncthreads();

}

// 4. store data via atomicMax
// to store col data efficiently we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0
// into a striped arrangement: [0, 8, 16, 24, ..] for t0
__syncthreads();
BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);

#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
if(base_col+threadIdx.x+(j*THREADS) < cols)
{
float val = colStats[base_col+(threadIdx.x+(j*THREADS))];
if(val < local_col_absmax_values[j])
atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]);
}

for(int j = 0; j < ITEMS_PER_THREAD; j++)
if(base_row+threadIdx.x+(j*THREADS) < rows)
{
float val = rowStats[base_row+(threadIdx.x+(j*THREADS))];
if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)])
atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]);
}

if(SPARSE_DECOMP)
if(threadIdx.x < TILE_ROWS)
nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x];

}
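
For reference, BlockExchange::BlockedToStriped follows standard CUB index semantics: in the blocked arrangement thread t owns the consecutive items t*ITEMS_PER_THREAD + j, while in the striped arrangement it owns items t + j*THREADS, which is exactly the indexing used in the colStats store loop above. The in-code comment's [0, 8, 16, 24, ..] illustrates the striped pattern for an 8-thread block; with the 64-thread, 4-item configuration instantiated below the mapping looks like this (host-side sketch, not part of the commit):

// Illustrative only: CUB blocked vs. striped index mapping for THREADS=64, ITEMS_PER_THREAD=4.
constexpr int kThreads = 64;
constexpr int kItemsPerThread = 4;
int blocked_index(int t, int j) { return t * kItemsPerThread + j; } // thread 0 -> 0, 1, 2, 3
int striped_index(int t, int j) { return t + j * kThreads; }        // thread 0 -> 0, 64, 128, 192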

template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 0>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template __global__ void kgetRowStats<half, 1024, 0>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template __global__ void kgetRowStats<half, 1024, 1>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
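
The base_row/base_col arithmetic at the top of kgetColRowStats walks 16x256 tiles in row-major order over a grid padded to (tiledRows, tiledCols). A small host-side sketch of that mapping, with hypothetical input dimensions (not part of the diff):

// Illustrative only: how blockIdx.x maps to a (base_row, base_col) tile in kgetColRowStats.
#include <cstdio>
int main() {
    const int TILE_ROWS = 16, TILE_COLS = 64 * 4;   // matches the <half, 64, 4, 16, 64*4, ...> instantiations
    const int rows = 32, cols = 512;                // hypothetical input shape
    const int tiledCols = ((cols + TILE_COLS - 1) / TILE_COLS) * TILE_COLS;   // 512
    const int tiledRows = ((rows + TILE_ROWS - 1) / TILE_ROWS) * TILE_ROWS;   // 32
    const int num_blocks = (tiledRows / TILE_ROWS) * (tiledCols / TILE_COLS); // 2 * 2 = 4
    for (int b = 0; b < num_blocks; b++) {
        const int base_row = ((b * TILE_COLS) / tiledCols) * TILE_ROWS; // advances once per row of tiles
        const int base_col = (b * TILE_COLS) % tiledCols;               // wraps back to 0 after tiledCols
        printf("block %d -> rows %d..%d, cols %d..%d\n",
               b, base_row, base_row + TILE_ROWS - 1, base_col, base_col + TILE_COLS - 1);
    }
    return 0;
}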

@@ -2430,16 +2276,15 @@ __global__ void kdequant_mm_int32_fp16(
row_idx = (block_offset + thread_offset + j) / numCols;
col_idx = (block_offset + thread_offset + j) % numCols;

local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx];
local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx];
local_colStats[j] = col_idx >= numCols ? 0.0f : __ldg(&colStats[col_idx]);
local_rowStats[j] = row_idx >= numRows ? 0.0f : __ldg(&rowStats[row_idx]);
local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 0.0f : __half2float(bias[col_idx]);
}

// Each block loads THREADS * ITEMS_PER_THREAD values from A
int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out
? THREADS * ITEMS_PER_THREAD
: n_out - block_offset;
__syncthreads();
LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0);

#pragma unroll ITEMS_PER_THREAD
@@ -2458,110 +2303,6 @@ __global__ void kdequant_mm_int32_fp16(
}
}
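
For context on the __ldg() form of the rowStats/colStats loads in kdequant_mm_int32_fp16 above: __ldg() reads through the GPU's read-only data cache, which suits per-row/per-column statistics that are broadcast to many threads and never written by the kernel. A minimal sketch of the same pattern (illustrative, not taken from this diff):

// Illustrative only: loading broadcast, read-only scale factors via the read-only data cache.
__global__ void scale_rows_example(const float* __restrict__ rowScale, float* data, int rows, int cols) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < rows * cols)
        data[idx] *= __ldg(&rowScale[idx / cols]);  // every thread in a row reads the same cached value
}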

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
{
// assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD
// Each thread reads the same column but multiple rows
// Rows are loaded in shared memory and access is shared across the threadblock (broadcast)

// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
// 2. quantize data with row/col stats
// 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance)

// each block loads TILE_COLS columns and TILE_ROWS rows
// after reading a tile the row counter increases by TILE_ROWS
// the col counter resets after reading TILE_COLS elements
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
const int base_idx = (base_row*cols) + base_col;
const int items_per_load = ITEMS_PER_THREAD*THREADS;

typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
__shared__ typename LoadHalf::TempStorage loadhalf;
typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
__shared__ typename StoreInt8::TempStorage storeint8;

__shared__ float smem_row_stats[TILE_ROWS];
__shared__ unsigned int smem_nnz_row_idx[TILE_ROWS];

half local_data[ITEMS_PER_THREAD];
float local_col_stats[ITEMS_PER_THREAD];
char local_quantized_data[ITEMS_PER_THREAD];

// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols)
local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]);

for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x)
{
if(base_row + i < rows)
smem_row_stats[i] = rowStats[base_row+i];

if(SPARSE_DECOMP)
smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i];
}
__syncthreads();

// we load row after row from the base_position
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
for(int row = 0; row < TILE_ROWS; row++)
{
if(base_row + row >= rows){ break; }
int i = base_idx + (row*cols);
int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;


LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f);
float row_stat = __fdividef(127.0f, smem_row_stats[row]);

// 2. quantize data with row/col stats
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
// we already pre-normalized the col/row stat:
// what this does is float/absmax*127 = int8
if(SPARSE_DECOMP)
{
if(fabsf((float)local_data[j]) >= threshold)
{
local_quantized_data[j] = 0;

int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX);

rowidx[old_idx] = base_row+row;
colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
val[old_idx] = local_data[j];
}
else
{
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
}
}
else
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
}

StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items);

// 2. quantize data with row/col stats
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
// we already pre-normalized the col/row stat:
// what this does is float/absmax*127 = int8
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j]));
}

__syncthreads();
StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items);

}
}
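
The quantization in kDoubleRowColQuant is plain symmetric absmax scaling: each value is multiplied by 127 divided by the absmax of its row (or column) and rounded to the nearest integer, while in the SPARSE_DECOMP case values whose magnitude reaches the threshold are emitted as (rowidx, colidx, val) outlier entries and written as 0 in the int8 output. A scalar sketch of that arithmetic (hypothetical helper, not part of the codebase):

// Illustrative only: scalar version of the row-wise absmax quantization above.
#include <cmath>
signed char quantize_absmax(float x, float absmax, float threshold, bool sparse_decomp, bool* is_outlier) {
    *is_outlier = sparse_decomp && fabsf(x) >= threshold;
    if (*is_outlier)
        return 0;                                  // the original value is kept in fp16 via the COO lists
    return (signed char)rintf(x * (127.0f / absmax));
}
// e.g. quantize_absmax(0.5f, 2.0f, 6.0f, true, &o) -> rintf(0.5f * 63.5f) = 32, with o = false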

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
{

@@ -3864,9 +3605,6 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(

template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);

template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);

template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);

3 changes: 0 additions & 3 deletions csrc/kernels.cuh
@@ -116,12 +116,9 @@ template <int ITEMS_PER_THREAD, int THREADS>__global__ void kdequant_mm_int32_fp
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);

template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);

template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
47 changes: 2 additions & 45 deletions csrc/ops.cu
@@ -465,6 +465,8 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
NULL, NULL, 0, stream
));
} else {
// This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows.

if (!SCALE_ROWS) {
float alpha = 1.0f, beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmul(
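
For context on that comment: a single int8 x int8 product can already reach 127 x 127 = 16129, far outside the int8 range of [-128, 127], and a GEMM sums many such products per output element, so results are normally kept in 32-bit integers and only rescaled to narrower types afterwards.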
@@ -532,28 +534,6 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

#define STATS_THREADS 64
#define STATS_ITEMS 4
#define STATS_ROWS 16
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{
int tile_cols = STATS_THREADS*STATS_ITEMS;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
int row_tiles = (tiledRows/STATS_ROWS);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;

if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());

}
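
Worked numbers for the launch configuration above: tile_cols = 64 x 4 = 256, so a hypothetical 100 x 1000 input pads to tiledRows = 112 and tiledCols = 1024, giving num_blocks = (112 / 16) x (1024 / 256) = 7 x 4 = 28 blocks of 64 threads.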

void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
if (threshold == 0.0)
kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
@@ -562,29 +542,6 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
{
int threads = 64;
int items_per_thread = 4;
int tile_cols = threads*items_per_thread;
int tile_rows = 16;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int row_tiles = (tiledRows/tile_rows);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;


if(threshold > 0.0f)
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
else
kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);

CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
{
int threads = 256;
3 changes: 0 additions & 3 deletions csrc/ops.cuh
@@ -176,10 +176,7 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle,
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);

template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
