From 75d9277ffc37b900ab8f8ba4b9b7de5b8de81c90 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 5 Sep 2023 11:02:56 -0700 Subject: [PATCH 01/66] init kernel headers --- csrc/kernels.dp.cpp | 0 csrc/kernels.dp.hpp | 212 +++++++++++++++++++++++++++++++++++++++++++ csrc/ops.dp.cpp | 0 csrc/ops.dp.hpp | 215 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 427 insertions(+) create mode 100644 csrc/kernels.dp.cpp create mode 100644 csrc/kernels.dp.hpp create mode 100644 csrc/ops.dp.cpp create mode 100644 csrc/ops.dp.hpp diff --git a/csrc/kernels.dp.cpp b/csrc/kernels.dp.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/csrc/kernels.dp.hpp b/csrc/kernels.dp.hpp new file mode 100644 index 000000000..4b646b2de --- /dev/null +++ b/csrc/kernels.dp.hpp @@ -0,0 +1,212 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include "ops.dp.hpp" + +#ifndef kernels +#define kernels + +#pragma once + +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, + const sycl::nd_item<3> &item_ct1, + uint8_t *temp_storage_ct1); + +extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_code); +extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_code); + +template extern SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, + const sycl::nd_item<3> &item_ct1, + typename LoadT::TempStorage &loadt, + typename LoadFloat::TempStorage &loadf, + typename StoreChar::TempStorage &storec, + typename BlockReduce::TempStorage &reduce, + float *smem_code, + float *smem_absmax_value); +template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, + const sycl::nd_item<3> &item_ct1, + typename LoadChar::TempStorage &loadchar, + typename StoreT::TempStorage &storet); + +template +extern SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + +template +extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + +template +extern SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> 
&item_ct1, uint8_t *temp_storage_ct1); + +template +extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + +template +extern SYCL_EXTERNAL void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + float *smem_quantiles1); + + +template +extern SYCL_EXTERNAL void +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, + uint8_t *temp_storage_ct1); + + + +template +extern SYCL_EXTERNAL void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + float *smem_quantiles1, float *smem_quantiles2); + + +template +extern SYCL_EXTERNAL void +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, + float *smem_quantiles2, uint8_t *temp_storage_ct1); + +template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + sycl::local_accessor smem_quantiles2, + typename BlockReduce1::TempStorage &reduce1, + typename BlockReduce2::TempStorage &reduce2, + float *smem_exchange1, float *smem_exchange2, + uint8_t *temp_storage_ct1); + +template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ 
const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + typename BlockReduce1::TempStorage &reduce1, + float *smem_exchange1, uint8_t *temp_storage_ct1); + + +template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n, + const sycl::nd_item<3> &item_ct1, + typename BlockReduce::TempStorage &reduce, + typename LoadT::TempStorage &loadT); + + +void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, + const sycl::nd_item<3> &item_ct1); + +template +extern SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, + int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, T *B, sycl::half *out, + float *__restrict__ const dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); + +template +extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, + float *__restrict__ const colStats, sycl::half *out, float *newRowStats, + float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, + const int numCols, const int tileCols, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_rowStats, + typename LoadInt32::TempStorage &loadint32, + typename ExchangeInt32::TempStorage &exchangeint32); + +template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, + const sycl::nd_item<3> &item_ct1, + uint8_t *temp_storage_ct1, + float *smem_row_absmax_values, + int *smem_row_nnz_values); +template +extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, + float *__restrict__ const rowStats, + float *__restrict__ const colStats, + char *out_col_normed, char *out_row_normed, int *rowidx, + int *colidx, sycl::half *val, + int *__restrict__ nnz_block_ptr, float threshold, + int rows, int cols, int tiledCols, + const sycl::nd_item<3> &item_ct1, + typename LoadHalf::TempStorage &loadhalf, + typename StoreInt8::TempStorage &storeint8, + float *smem_row_stats, unsigned int *smem_nnz_row_idx); + + +template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, + const sycl::nd_item<3> &item_ct1, + char *smem_data); + +template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, + const sycl::nd_item<3> &item_ct1); + +template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + T *quant_map); + +template extern SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, + const sycl::nd_item<3> &item_ct1); + +#endif diff --git a/csrc/ops.dp.cpp 
b/csrc/ops.dp.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/csrc/ops.dp.hpp b/csrc/ops.dp.hpp new file mode 100644 index 000000000..84850263b --- /dev/null +++ b/csrc/ops.dp.hpp @@ -0,0 +1,215 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + + +#ifndef ops_H +#define ops_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + + +#define THREADS_PER_BLOCKS (512) + +inline void checkCudaStatus(int status) { + /* + DPCT1000:93: Error handling if-stmt was detected but could not be rewritten. + */ + if (status != 0) { + /* + DPCT1009:94: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + printf( + "cuda API failed with status %d: %s\n", status, + "cudaGetErrorString is not supported" /*cudaGetErrorString(status)*/); + /* + DPCT1001:92: The statement could not be removed. + */ + throw std::logic_error("cuda API failed"); + } +} + +inline int checkCublasStatus(int status) { + if (status != 0) { + printf("cuBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API failed"); + return 1; + } + return 0; +} + +typedef enum Operations_t +{ + ksmul = 0, +} Operations_t; + +typedef enum Optimizer_t +{ + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, +} Optimizer_t; + +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + +class Context +{ + public: + dpct::queue_ptr m_handle; + + Context() + { + dpct::queue_ptr handle; + handle = &dpct::get_default_queue(); + m_handle = handle; + } + +}; + +class ContextLt +{ + public: + dpct::queue_ptr m_handle; + + ContextLt() + { + dpct::queue_ptr handle; + handle = &dpct::get_default_queue(); + m_handle = handle; + } + +}; + +class ContextCusparse +{ + public: + sycl::queue *m_handle; + + ContextCusparse() + { + sycl::queue *handle; + handle = &dpct::get_default_queue(); + m_handle = handle; + } + +}; + + +template void estimateQuantiles(T *A, float *code, float offset, int n); + +void quantize(float *code, float *A, unsigned char *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, float weight_decay, + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); + +template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n); + +template void 
optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + bool skip_zeros, int n); + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); + +extern SYCL_EXTERNAL void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); + +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +template int igemmlt(dpct::queue_ptr ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transform(dpct::queue_ptr ltHandle, T *A, T *out, int dim1, int dim2); +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, + sycl::half *out, float *newRowStats, + float *newcolStats, sycl::half *bias, int numRows, + int numCols); +void getColRowStats(sycl::half *A, float *rowStats, float *colStats, + int *nnz_count_row, float nnz_threshold, int rows, + int cols); +void doubleRowColQuant(sycl::half *A, float *rowStats, float *colStats, + char *out_col_normed, char *out_row_normed, int *rowidx, + int *colidx, sycl::half *val, int *nnz_block_ptr, + float threshold, int rows, int cols); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +void spmm_coo(sycl::queue *handle, int *A_rowidx, int *A_colidx, + sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, + int ldb, sycl::half *B, int ldc, sycl::half *C, + bool transposed_B); + +template +void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, + int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, T *B, sycl::half *out, + float *dequant_stats, int nnz_rows, int nnz, + int rowsA, int rowsB, int colsB); + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +void matmul4bite(sycl::half *A, unsigned char *B, sycl::half *out, int lda, + int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); + +template void func(T *A, T *B, T value, long n); + +#endif From 279fce44dc013c5e132257aff53c2b47d68eb198 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 20 Sep 2023 01:43:33 -0700 Subject: [PATCH 02/66] modify kernel header --- csrc/kernels.dp.hpp | 102 ++++++++++---------------------------------- 1 file changed, 23 insertions(+), 79 deletions(-) diff --git a/csrc/kernels.dp.hpp b/csrc/kernels.dp.hpp index 4b646b2de..f0111be37 100644 --- a/csrc/kernels.dp.hpp +++ b/csrc/kernels.dp.hpp @@ -24,46 +24,32 @@ extern SYCL_EXTERNAL void kQuantize(float * 
code, float * __restrict__ const A, extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code); -template extern SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, - const sycl::nd_item<3> &item_ct1, - typename LoadT::TempStorage &loadt, - typename LoadFloat::TempStorage &loadf, - typename StoreChar::TempStorage &storec, - typename BlockReduce::TempStorage &reduce, - float *smem_code, - float *smem_absmax_value); -template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, - typename LoadChar::TempStorage &loadchar, - typename StoreT::TempStorage &storet); +template extern SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const int step, const float lr, const float gnorm_scale, const int n); template extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const int step, const float lr, const float gnorm_scale, const int n); template extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template extern SYCL_EXTERNAL void @@ -74,9 +60,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c float* __restrict__ const quantiles1, float* max1, float* new_max1, const float weight_decay, - const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, - float *smem_quantiles1); + const float gnorm_scale, const int n); template @@ -87,9 +71,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float eps, 
const int step, const float lr, float* __restrict__ const quantiles1, float* max1, float* new_max1, - float weight_decay, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, - uint8_t *temp_storage_ct1); + float weight_decay, const float gnorm_scale, const int n); @@ -101,9 +83,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c const float eps, const int step, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, - const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, - float *smem_quantiles1, float *smem_quantiles2); + const float gnorm_scale, const int n); template @@ -114,22 +94,13 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, - float *smem_quantiles2, uint8_t *temp_storage_ct1); + float weight_decay, const float gnorm_scale, const int n); template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, - sycl::local_accessor smem_quantiles1, - sycl::local_accessor smem_quantiles2, - typename BlockReduce1::TempStorage &reduce1, - typename BlockReduce2::TempStorage &reduce2, - float *smem_exchange1, float *smem_exchange2, - uint8_t *temp_storage_ct1); + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, @@ -138,46 +109,29 @@ template extern SYCL_EX float* __restrict__ const quantiles1, float* absmax1, float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, - sycl::local_accessor smem_quantiles1, - typename BlockReduce1::TempStorage &reduce1, - float *smem_exchange1, uint8_t *temp_storage_ct1); + const float gnorm_scale, const bool skip_zeros, const int n); -template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1, - typename BlockReduce::TempStorage &reduce, - typename LoadT::TempStorage &loadT); +template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); -void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, - const sycl::nd_item<3> &item_ct1); +void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); template extern SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float *__restrict__ const dequant_stats, - int nnz, int 
rowsA, int rowsB, int colsB, - const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); + int nnz, int rowsA, int rowsB, int colsB); template extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float *newRowStats, float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, - const int numCols, const int tileCols, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_rowStats, - typename LoadInt32::TempStorage &loadint32, - typename ExchangeInt32::TempStorage &exchangeint32); - -template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, - const sycl::nd_item<3> &item_ct1, - uint8_t *temp_storage_ct1, - float *smem_row_absmax_values, - int *smem_row_nnz_values); + const int numCols, const int tileCols, const int n); + +template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, @@ -186,27 +140,17 @@ extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int *__restrict__ nnz_block_ptr, float threshold, - int rows, int cols, int tiledCols, - const sycl::nd_item<3> &item_ct1, - typename LoadHalf::TempStorage &loadhalf, - typename StoreInt8::TempStorage &storeint8, - float *smem_row_stats, unsigned int *smem_nnz_row_idx); + int rows, int cols, int tiledCols); -template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, - const sycl::nd_item<3> &item_ct1, - char *smem_data); +template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, - const sycl::nd_item<3> &item_ct1); +template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, - const sycl::nd_item<3> &item_ct1, - T *quant_map); +template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); -template extern SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, - const sycl::nd_item<3> &item_ct1); +template extern SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n); #endif From 
d9dcee915fbcad4f3747d886cf99fb22fc9f7f33 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 21 Sep 2023 07:28:58 -0700 Subject: [PATCH 03/66] modify headers --- csrc/kernels.dp.hpp | 349 ++++++++++++++++++++++++-------------------- 1 file changed, 193 insertions(+), 156 deletions(-) diff --git a/csrc/kernels.dp.hpp b/csrc/kernels.dp.hpp index f0111be37..499493f23 100644 --- a/csrc/kernels.dp.hpp +++ b/csrc/kernels.dp.hpp @@ -1,156 +1,193 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include "ops.dp.hpp" - -#ifndef kernels -#define kernels - -#pragma once - -//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); - -template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1, - uint8_t *temp_storage_ct1); - -extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_code); -extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_code); - -template extern SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); - -template -extern SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); - -template -extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -template -extern SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); - -template -extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -template -extern SYCL_EXTERNAL void -kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - const float weight_decay, - const float gnorm_scale, const int n); - - -template -extern SYCL_EXTERNAL void -kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, - const float *unorm, const float max_unorm, const 
float param_norm, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - float weight_decay, const float gnorm_scale, const int n); - - - -template -extern SYCL_EXTERNAL void -kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - const float gnorm_scale, const int n); - - -template -extern SYCL_EXTERNAL void -kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, - const float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, const float gnorm_scale, const int n); - -template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, - const float beta1, const float beta2, const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); - -template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* absmax1, - float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n); - - -template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); - - -void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); - -template -extern SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, - int *offset_rowidx, int *rowidx, int *colidx, - sycl::half *values, T *B, sycl::half *out, - float *__restrict__ const dequant_stats, - int nnz, int rowsA, int rowsB, int colsB); - -template -extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( - int *__restrict__ const A, float *__restrict__ const rowStats, - float *__restrict__ const colStats, sycl::half *out, float *newRowStats, - float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, - const int numCols, const int tileCols, const int n); - -template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); -template -extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, - float *__restrict__ const rowStats, - float *__restrict__ const colStats, - char *out_col_normed, char *out_row_normed, int *rowidx, - int *colidx, sycl::half *val, - int *__restrict__ nnz_block_ptr, float threshold, - int rows, int cols, int tiledCols); - - -template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int 
outRows, int outCols); - -template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); - -template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); -template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); - -template extern SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n); - -#endif +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include "ops.dp.hpp" + +#ifndef kernels +#define kernels + +#pragma once + +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, + const sycl::nd_item<3> &item_ct1, + uint8_t *temp_storage_ct1); + +extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, + const sycl::nd_item<3> &item_ct1); +extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, + const sycl::nd_item<3> &item_ct1); + +template extern SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, + const sycl::nd_item<3> &item_ct1, + float *smem_code, + float *smem_absmax_value); +template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, + const sycl::nd_item<3> &item_ct1); + +template +extern SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + +template +extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + +template +extern SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + +template +extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const 
int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + +template +extern SYCL_EXTERNAL void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + float *smem_quantiles1); + + +template +extern SYCL_EXTERNAL void +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, + uint8_t *temp_storage_ct1); + + + +template +extern SYCL_EXTERNAL void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + float *smem_quantiles1, float *smem_quantiles2); + + +template +extern SYCL_EXTERNAL void +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, + float *smem_quantiles2, uint8_t *temp_storage_ct1); + +template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + sycl::local_accessor smem_quantiles2, + float *smem_exchange1, float *smem_exchange2, + uint8_t *temp_storage_ct1); + +template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + float *smem_exchange1, uint8_t *temp_storage_ct1); + + +template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const 
sycl::nd_item<3> &item_ct1); + + +void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, + const sycl::nd_item<3> &item_ct1); + +template +extern SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, + int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, T *B, sycl::half *out, + float *__restrict__ const dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); + +template +extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, + float *__restrict__ const colStats, sycl::half *out, float *newRowStats, + float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, + const int numCols, const int tileCols, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_rowStats); + +template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, + const sycl::nd_item<3> &item_ct1,uint8_t *temp_storage_ct1,float *smem_row_absmax_values,int *smem_row_nnz_values); +template +extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, + float *__restrict__ const rowStats, + float *__restrict__ const colStats, + char *out_col_normed, char *out_row_normed, int *rowidx, + int *colidx, sycl::half *val, + int *__restrict__ nnz_block_ptr, float threshold, + int rows, int cols, int tiledCols, + const sycl::nd_item<3> &item_ct1, + float *smem_row_stats, unsigned int *smem_nnz_row_idx); + + +template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, + const sycl::nd_item<3> &item_ct1, + char *smem_data); + +template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, + const sycl::nd_item<3> &item_ct1); + +template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + T *quant_map); + +template extern SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, + const sycl::nd_item<3> &item_ct1); + +#endif From ddae8abcf169b429115cf2caad66d64b1729ff73 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 22 Sep 2023 02:59:52 -0700 Subject: [PATCH 04/66] update kernels --- csrc/kernels.dp.cpp | 5286 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 5286 insertions(+) diff --git a/csrc/kernels.dp.cpp b/csrc/kernels.dp.cpp index e69de29bb..51b1b7abf 100644 --- a/csrc/kernels.dp.cpp +++ b/csrc/kernels.dp.cpp @@ -0,0 +1,5286 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
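+
+// Invocation sketch (illustrative only; the queue name `q`, the block count
+// `blocks`, and the 256-entry code table are assumptions, not part of this
+// patch). DPCT rewrites a CUDA launch such as
+// `kQuantize<<<blocks, 1024>>>(code, A, out, n)` into a SYCL submission that
+// passes the nd_item and local memory explicitly, roughly:
+//
+//   dpct::queue_ptr q = &dpct::get_default_queue();
+//   q->submit([&](sycl::handler &cgh) {
+//       // local memory replacing the kernel's __shared__ float smem_code[256]
+//       sycl::local_accessor<float, 1> smem_code(sycl::range<1>(256), cgh);
+//       cgh.parallel_for(
+//           sycl::nd_range<3>(sycl::range<3>(1, 1, blocks) *
+//                                 sycl::range<3>(1, 1, 1024),
+//                             sycl::range<3>(1, 1, 1024)),
+//           [=](sycl::nd_item<3> item_ct1) {
+//               kQuantize(code, A, out, n, item_ct1,
+//                         smem_code.get_pointer());
+//           });
+//   });
+//
+// This is why the kernels in this file take `const sycl::nd_item<3> &item_ct1`
+// and, where needed, raw pointers to local memory as trailing parameters.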
+
+#include <sycl/sycl.hpp>
+#include <dpct/dpct.hpp>
+#include
+#include
+#include
+#include
+#include
+#include
+#include "kernels.dp.hpp"
+#include
+
+#define HLF_MAX 65504
+#define TH 1024
+#define NUM 4
+#define NUM_BLOCK 4096
+
+
+// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
+float atomicMax(float* address, float val) {
+  int* address_as_i = reinterpret_cast<int*>(address);
+  int old = *address_as_i, assumed;
+  do {
+    assumed = old;
+    old = dpct::atomic_compare_exchange_strong<
+        sycl::access::address_space::generic_space>(
+        reinterpret_cast<int*>(address), assumed,
+        sycl::bit_cast<int>(sycl::fmax(val, sycl::bit_cast<float>(assumed))));
+  } while (assumed != old);
+  return sycl::bit_cast<float>(old);
+}
+
+float atomicMin(float* address, float val) {
+  int* address_as_i = reinterpret_cast<int*>(address);
+  int old = *address_as_i, assumed;
+  do {
+    assumed = old;
+    old = dpct::atomic_compare_exchange_strong<
+        sycl::access::address_space::generic_space>(
+        reinterpret_cast<int*>(address), assumed,
+        sycl::bit_cast<int>(sycl::fmin(val, sycl::bit_cast<float>(assumed))));
+  } while (assumed != old);
+  return sycl::bit_cast<float>(old);
+}
+
+float dDequantizeFP4(unsigned char val, float absmax)
+{
+  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
+  if((val & 0b0110) == 0)
+  {
+    // subnormal
+    if((val & 0b0001) == 0)
+      return 0.0f;
+    else
+      return sign*0.0625f*absmax;
+  }
+  else
+  {
+    // normal
+    float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f);
+    float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f;
+
+    return sign*exponent*fraction*absmax;
+  }
+}
+
+float d2DequantizeFP4(unsigned char val)
+{
+  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
+  if((val & 0b0110) == 0)
+  {
+    // subnormal
+    if((val & 0b0001) == 0)
+      return 0.0f;
+    else
+      return sign*0.0625f;
+  }
+  else
+  {
+    // normal
+    float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f);
+    float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f;
+
+    return sign*exponent*fraction;
+  }
+}
+
+float dDequantizeFP4Tree(unsigned char val, float absmax)
+{
+  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
+  if((val & 0b0100) == 4) // 0
+    if((val & 0b0010) == 2) //01
+      if((val & 0b0001) == 1) // 111
+        return 0.25000000f*absmax*sign; // 1111
+      else
+        return 0.16666667f*absmax*sign; // 1110
+    else
+      if((val & 0b0001) == 1) // 110
+        return 0.50000000f*absmax*sign; // 1101
+      else
+        return 0.33333333f*absmax*sign; // 1100
+  else
+    if((val & 0b0010) == 2) //10
+      if((val & 0b0001) == 1) // 101
+        return 1.00000000f*absmax*sign; // 1011
+      else
+        return 0.66666667f*absmax*sign; // 1010
+    else
+      if((val & 0b0001) == 1) // 100
+        return 5.208333333e-03f*absmax*sign; // 1001
+      else
+        return 0.00000000f*absmax*sign; // 1000
+}
+
+unsigned char dQuantizeFP4(float x)
+{
+  // FP4 with bias of 3
+  // first bit is a sign
+  // subnormals
+  // 0b000 = 0
+  // 0b001 = 0.0625
+  // 0b110 = 2
+  // 0b111 = 3
+  // 0b100 = 4
+  // 0b101 = 6
+  // 0b010 = 8
+  // 0b011 = 12
+
+
+  // we do a binary search
+  // the pivots are divided by 12 (the FP4 absmax)
+  // since we assume the input data is in [-1.0, 1.0]
+
+  // !be careful here, it's easy to make a mistake
+  // that is difficult to notice if you add an extra
+  // zero somewhere!
+
+  int sign = x < 0 ?
0b1000 : 0b0000; + x = sycl::fabs(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + +sycl::half dhDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +float dDequantizeNF4(unsigned char val) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) 
// 011
+        return 0b0111;
+      else
+        return 0b0110;
+    else
+      if(x > -0.23460740596055984f) // 010
+        return 0b0101;
+      else
+        return 0b0100;
+  else
+    if(x > -0.6106329262256622f) // 00
+      if(x > -0.4599952697753906f) // 001
+        return 0b0011;
+      else
+        return 0b0010;
+    else
+      if(x > -0.8480964004993439f) // 000
+        return 0b0001;
+      else
+        return 0b0000;
+}
+// sign function for lion
+// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
+
+template <typename T> int sgn(T val)
+{
+  return (T(0) < val) - (val < T(0));
+}
+
+template <int STOCHASTIC>
+unsigned char dQuantize(float* smem_code, const float rand, float x)
+{
+  int pivot = 127;
+  int upper_pivot = 255;
+  int lower_pivot = 0;
+
+  float lower = -1.0f;
+  float upper = 1.0f;
+
+  float val = smem_code[pivot];
+  // i>>=1 = {32, 16, 8, 4, 2, 1}
+  for(int i = 64; i > 0; i>>=1)
+  {
+    if(x > val)
+    {
+      lower_pivot = pivot;
+      lower = val;
+      pivot+=i;
+    }
+    else
+    {
+      upper_pivot = pivot;
+      upper = val;
+      pivot-=i;
+    }
+    val = smem_code[pivot];
+  }
+
+  if(upper_pivot == 255)
+    upper = smem_code[upper_pivot];
+  if(lower_pivot == 0)
+    lower = smem_code[lower_pivot];
+
+  if(!STOCHASTIC)
+  {
+    if(x > val)
+    {
+      float midpoint = (upper+val)*0.5f;
+      if(x > midpoint)
+      {
+        return upper_pivot;
+      }
+      else
+        return pivot;
+    }
+    else
+    {
+      float midpoint = (lower+val)*0.5f;
+      if(x < midpoint)
+        return lower_pivot;
+      else
+        return pivot;
+    }
+  }
+  else
+  {
+    if(x > val)
+    {
+      float dist_to_upper = sycl::fabs(upper - x);
+      float dist_full = upper-val;
+      if(rand >= dist_to_upper/dist_full) return upper_pivot;
+      else return pivot;
+    }
+    else
+    {
+      float dist_to_lower = sycl::fabs(lower - x);
+      float dist_full = val-lower;
+      if(rand >= dist_to_lower/dist_full) return lower_pivot;
+      else return pivot;
+    }
+  }
+}
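+
+// Worked example for dQuantize (illustrative only; the linear code table and
+// the input 0.3f are assumptions, not part of this patch). With a sorted
+// 256-entry code spanning [-1, 1]:
+//
+//   float code[256];
+//   for (int i = 0; i < 256; i++) code[i] = -1.0f + 2.0f*i/255.0f;
+//   unsigned char q = dQuantize<0>(code, 0.0f, 0.3f); // deterministic path
+//
+// The binary search narrows pivot 127 -> 191 -> 159 -> 175 -> 167 -> 163
+// -> 165 -> 166; the midpoint test then picks code[166] (~0.302) over
+// code[165] (~0.294) as the entry nearest 0.3, so q == 166. With
+// STOCHASTIC != 0, one of the two bracketing codes is chosen at random with
+// probability proportional to its closeness to x, using `rand` in [0, 1).
+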
+template <int SIGNED>
+__dpct_inline__ unsigned char quantize_2D(float *__restrict__ quadrants,
+                                          float *__restrict__ const smem_code,
+                                          float x)
+{
+  int pivot = 127;
+  int upper_pivot = 255;
+  int lower_pivot = 0;
+
+  float lower = SIGNED ? -1.0f : 0.0f;
+  float upper = 1.0f;
+  float midpoint;
+  float val = quadrants[1];
+  int local_pivot = 1;
+  int offset = 1;
+
+  // i>>=1 = {32, 16, 8, 4, 2, 1}
+  for(int i = 64; i > 0; i>>=1)
+  {
+    if(x > val)
+    {
+      lower_pivot = pivot;
+      lower = val;
+      pivot+=i;
+      //val = i == 64 ? quadrants[2] : smem_code[pivot];
+      local_pivot += offset;
+    }
+    else
+    {
+      upper_pivot = pivot;
+      upper = val;
+      pivot-=i;
+      //val = i == 64 ? quadrants[0] : smem_code[pivot];
+      local_pivot -= offset;
+    }
+    val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot];
+    offset -= 1;
+  }
+
+  if(x > val)
+  {
+    midpoint = (upper+val)*0.5f;
+    if(x > midpoint)
+      return upper_pivot;
+    else
+      return pivot;
+  }
+  else
+  {
+    midpoint = (lower+val)*0.5f;
+    if(x < midpoint)
+      return lower_pivot;
+    else
+      return pivot;
+  }
+}
+
+template <int SIGNED>
+__dpct_inline__ unsigned char
+quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x,
+                  float lower, float midpoint, float upper)
+{
+  int lower_pivot = QUADRANT*16-1 - 0;
+  int pivot = QUADRANT*16-1 + 16;
+  int upper_pivot = QUADRANT*16-1 + 31;
+
+  float val = midpoint;
+
+  // i>>=1 = {8, 4, 2, 1}
+  for(int i = 16; i > 0; i>>=1)
+  {
+    if(x > val)
+    {
+      lower_pivot = pivot;
+      lower = val;
+      pivot+=i;
+    }
+    else
+    {
+      upper_pivot = pivot;
+      upper = val;
+      pivot-=i;
+    }
+    val = smem_code[pivot];
+  }
+
+  if(x > val)
+  {
+    midpoint = (upper+val)*0.5f;
+    if(x > midpoint)
+      return upper_pivot;
+    else
+      return pivot;
+  }
+  else
+  {
+    midpoint = (lower+val)*0.5f;
+    if(x < midpoint)
+      return lower_pivot;
+    else
+      return pivot;
+  }
+}
+
+void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n,
+                            const sycl::nd_item<3> &item_ct1)
+{
+  const int tid = item_ct1.get_local_id(2) +
+                  (item_ct1.get_local_range(2) * item_ct1.get_group(2));
+  const int numThreads =
+      item_ct1.get_local_range(2) * item_ct1.get_group_range(2);
+
+  for(int i = tid; i < n; i+=numThreads)
+  {
+    int idx = (index1[i]*maxidx1) + index2[i];
+    dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
+        &histogram[idx], src[i]);
+  }
+}
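+
+// Indexing note for kHistogramScatterAdd2D (the numbers are hypothetical):
+// the histogram is a flattened row-major [rows][maxidx1] array, so with
+// maxidx1 = 256 the pair (index1[i] = 3, index2[i] = 17) accumulates src[i]
+// into histogram[3*256 + 17]. Threads stride over i in steps of numThreads,
+// and the atomic add keeps concurrent updates to the same bucket race-free.
+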
BLOCK_SIZE + : n - (item_ct1.get_group(2) * BLOCK_SIZE); + + // BLOCK_SIZE/32 == number of warps + //sycl::local_accessor smem_max_indices + //__shared__ int smem_max_indices[8*BLOCK_SIZE/32]; + //__shared__ float smem_max_values[8*BLOCK_SIZE/32]; + int smem_max_indices[8*BLOCK_SIZE/32]; + float smem_max_values[8*BLOCK_SIZE/32]; + + T values[8]; + T max1 = -64000.0f; + T max2 = -64000.0f; + int max_idx1 = -1; + int max_idx2 = -1; + int sign1 = -1; + int sign2 = -1; + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + + LoadT(loadt).Load(&(A[(item_ct1.get_group(2) * BLOCK_SIZE)]), values, + valid_items, (T)0.0f); +#pragma unroll 8 + for(int i = 0; i < 8; i++) + { + T absval = fabsf(values[i]); + if(absval > max1) + { + max1 = values[i]; + sign1 = signbit(values[i]); + max_idx1 = 8 * item_ct1.get_local_id(2) + i; + } + else if(absval > max2) + { + max2 = values[i]; + sign2 = signbit(values[i]); + max_idx2 = 8 * item_ct1.get_local_id(2) + i; + } + } + + float warp_max; + for(int i = 0; i < 8; i++) + { + // 3. do warp reduction + broadcast back + //warp_max = WarpReduce(temp_storage).Reduce(max1, sycl::maximum<>()); + warp_max = dpct::group::reduce(item_ct1, max1, sycl::maximum<>()); + warp_max = item_ct1.get_sub_group().shuffle(warp_max, 0); + + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + if(warp_max == max1) + { + smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1; + smem_max_indices[warp_idx*8 + i] = max_idx1; + + sign1 = sign2; + max1 = max2; + max_idx1 = max_idx2; + + max2 = -64000.0f; + } + sycl::group_barrier(item_ct1.get_sub_group()); + } + + if (item_ct1.get_local_id(2) % 32 < 8) + { + // offset: 8 values per 256 input values + // + int offset = BLOCK_SIZE * item_ct1.get_group(2) * BLOCK_SIZE / 32 * 8; + } + +} + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 + +template + +void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, + const sycl::nd_item<3> &item_ct1, + uint8_t *temp_storage_ct1) +{ + const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); + int valid_items = (item_ct1.get_group(2) + 1 == item_ct1.get_group_range(2)) + ? n - (item_ct1.get_group(2) * BLOCK_ESTIMATE) + : BLOCK_ESTIMATE; + const int base_idx = (item_ct1.get_group(2) * BLOCK_ESTIMATE); + const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); + + T vals[NUM_ESTIMATE]; + + typedef cub::BlockRadixSort BlockRadixSort; + typedef cub::BlockLoad LoadFloat; + + union type_ct1 { + typename LoadFloat::TempStorage loadf; + typename BlockRadixSort::TempStorage sort; + int smem_qidx[BLOCK_ESTIMATE]; + }; + type_ct1 &temp_storage = *(type_ct1 *)temp_storage_ct1; + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * BLOCK_ESTIMATE) + { + valid_items = n - i > BLOCK_ESTIMATE ? 
BLOCK_ESTIMATE : n - i; + + // do not process half-blocks + if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = max_val; + + /* + DPCT1065:0: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:96: Migration of cub::BlockLoad.Load is not supported. + */ + LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); + +#pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = ((float)vals[j]) * reciprocal_num_blocks; + + /* + DPCT1065:1: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + // sort into striped pattern to mitigate bank conflicts + // striped pattern index for thread 0 [0, 1024, 2048, 3096] + // striped pattern index for thread 1 [1, 1025, 2049, 3097] + /* + DPCT1007:97: Migration of cub::BlockRadixSort.SortBlockedToStriped is not + supported. + */ + BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); + + /* + DPCT1065:2: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + for (int j = item_ct1.get_local_id(2); j < BLOCK_ESTIMATE; + j += item_ct1.get_local_range(2)) + temp_storage.smem_qidx[j] = -1; + + if (item_ct1.get_local_id(2) < 256) + { + float q_interval = (1.0f-(2.0f*offset))/255.0f; + int local_idx = + sycl::round(((offset + (item_ct1.get_local_id(2) * q_interval)) * + (valid_items - 1))); + temp_storage.smem_qidx[local_idx] = item_ct1.get_local_id(2); + } + + /* + DPCT1065:3: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + for (int i = item_ct1.get_local_id(2); i < BLOCK_ESTIMATE; + i += item_ct1.get_local_range(2)) + { + if(temp_storage.smem_qidx[i] != -1) + dpct::atomic_fetch_add< + sycl::access::address_space::generic_space>( + &code[temp_storage.smem_qidx[i]], vals[i / THREADS_ESTIMATE]); + } + } +} + + +void kQuantize(float *code, float *__restrict__ const A, unsigned char *out, + const int n, const sycl::nd_item<3> &item_ct1, float *smem_code) +{ + const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (item_ct1.get_group(2) + 1 == item_ct1.get_group_range(2)) + ? 
n - (item_ct1.get_group(2) * NUM_BLOCK)
+                        : NUM_BLOCK;
+  const int base_idx = (item_ct1.get_group(2) * NUM_BLOCK);
+
+  float vals[NUM];
+  unsigned char qvals[NUM];
+  //const int lane_id = threadIdx.x % 2;
+
+  typedef cub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+  typedef cub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+  typename LoadFloat::TempStorage loadf;
+  typename StoreChar::TempStorage storec;
+
+  //__shared__ float smem_code[2][257];
+
+  if (item_ct1.get_local_id(2) < 256)
+  {
+    smem_code[item_ct1.get_local_id(2)] = code[item_ct1.get_local_id(2)];
+    //smem_code[0][threadIdx.x] = code[threadIdx.x];
+    //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x];
+  }
+
+  for (unsigned int i = base_idx; i < n_full;
+       i += item_ct1.get_group_range(2) * NUM_BLOCK)
+  {
+      // number of values already processed in blocks +
+      // number of values already processed in this block +
+      // rand_offset % mod value
+      valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i;
+
+      /*
+      DPCT1065:89: Consider replacing sycl::nd_item::barrier() with
+      sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+      performance if there is no access to global memory.
+      */
+      item_ct1.barrier();
+      /*
+      DPCT1007:160: Migration of cub::BlockLoad.Load is not supported.
+      */
+      LoadFloat(loadf).Load(&(A[i]), vals, valid_items);
+
+#pragma unroll 4
+      for(int j = 0; j < NUM; j++)
+          qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]);
+
+      /*
+      DPCT1065:90: Consider replacing sycl::nd_item::barrier() with
+      sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+      performance if there is no access to global memory.
+      */
+      item_ct1.barrier();
+      /*
+      DPCT1007:161: Migration of cub::BlockStore.Store is not supported.
+      */
+      StoreChar(storec).Store(&(out[i]), qvals, valid_items);
+  }
+}
+
+template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
+//__launch_bounds__(TH, 4)
+void kQuantizeBlockwise(float *code, T *__restrict__ const A, float *absmax,
+                        unsigned char *out, float *__restrict__ const rand,
+                        const int rand_offset, const int n,
+                        const sycl::nd_item<3> &item_ct1,
+                        float *smem_code, float *smem_absmax_value)
+{
+  const int n_full = item_ct1.get_group_range(2) * BLOCK_SIZE;
+  int valid_items = 0;
+  const int base_idx = (item_ct1.get_group(2) * BLOCK_SIZE);
+
+  T vals[NUM_PER_TH];
+  float rand_vals[NUM_PER_TH];
+  unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
+  //float local_abs_max = -FLT_MAX;
+  float local_abs_max = 0.0f;
+  int local_rand_idx = 0;
+
+  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
+  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+  typename LoadT::TempStorage loadt;
+  typename StoreChar::TempStorage storec;
+  typename BlockReduce::TempStorage reduce;
+
+  if(DATA_TYPE == General8bit)
+    for (int i = item_ct1.get_local_id(2); i < 256;
+         i += item_ct1.get_local_range(2))
+      smem_code[i] = code[i];
+
+  for (unsigned int i = base_idx; i < n_full;
+       i += item_ct1.get_group_range(2) * BLOCK_SIZE)
+  {
+      valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
+      local_abs_max = -FLT_MAX;
+
+      /*
+      DPCT1065:4: Consider replacing sycl::nd_item::barrier() with
+      sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+      performance if there is no access to global memory.
+      */
+      item_ct1.barrier();
+      /*
+      DPCT1007:98: Migration of cub::BlockLoad.Load is not supported.
+      */
+      LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);
+
+      // 1. compute local max
+      // 2. broadcast local max
+      // 3. normalize inputs and quantize
+
+      #pragma unroll NUM_PER_TH
+      for(int j = 0; j < NUM_PER_TH; j++)
+         local_abs_max = sycl::fmax(local_abs_max, sycl::fabs((float)vals[j]));
+
+      /*
+      DPCT1007:7: Migration of cub::Reduce is not supported. 
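+      A possible hand-rolled replacement (a sketch, not the migrated API):
+        local_abs_max = sycl::reduce_over_group(item_ct1.get_group(),
+                                                local_abs_max, sycl::maximum<>());
+      The valid_items argument can be dropped for this particular reduction
+      because out-of-range elements were loaded as (T)0.0f above and cannot
+      win a maximum over absolute values.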
+ */ + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, sycl::maximum<>(), + valid_items); + //local_abs_max = dpct::group::reduce(item_ct1, local_abs_max, sycl::maximum<>() + + if (item_ct1.get_local_id(2) == 0) + smem_absmax_value[0] = local_abs_max; + + /* + DPCT1065:5: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + if (item_ct1.get_local_id(2) == 0) + absmax[i/BLOCK_SIZE] = local_abs_max; + else + local_abs_max = smem_absmax_value[0]; + + sycl::group_barrier(item_ct1.get_sub_group()); + + local_abs_max = 1.0f/local_abs_max; + + if(STOCHASTIC) + { + local_rand_idx = ((item_ct1.get_group(2) * NUM_BLOCK) + + (item_ct1.get_local_id(2) * NUM) + rand_offset) % + (1024 - 4); + /* + DPCT1007:99: Migration of cub::BlockLoad.Load is not supported. + */ + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + } + + /* + DPCT1065:6: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:100: Migration of cub::BlockStore.Store is not supported. + */ + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i / 2 : i]), qvals, + (DATA_TYPE > 0) ? (valid_items + 1) / 2 + : valid_items); + } +} + +template +void kDequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, + const int blocksize, const int n, + const sycl::nd_item<3> &item_ct1) +{ + + const int n_load = (item_ct1.get_group_range(2) * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (item_ct1.get_group(2) * TILE_SIZE); + + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + for (unsigned int i = base_idx; i < n_load; + i += item_ct1.get_group_range(2) * TILE_SIZE) + { + if(DATA_TYPE > 0) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i; + } + /* + DPCT1026:101: The call to __ldg was removed because there is no + corresponding API in SYCL. 
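+    __ldg was only a read-only-cache load hint in CUDA; the plain global
+    load of absmax below is functionally identical, so nothing is lost.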
+ */ + local_abs_max = + absmax[(i + item_ct1.get_local_id(2) * NUM_PER_TH) / (blocksize)]; + + /* + DPCT1065:8: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:102: Migration of cub::BlockLoad.Load is not supported. + */ + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); + + switch(DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + /* + DPCT1026:103: The call to __ldg was removed because there is no + corresponding API in SYCL. + */ + vals[j] = code[qvals[j]] * local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } + + /* + DPCT1065:9: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:104: Migration of cub::BlockStore.Store is not supported. + */ + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i * 2 : i]), vals, + valid_items_store); + } +} + +void kDequantize(float *code, unsigned char *A, float *out, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_code) +{ + const unsigned int numThreads = + item_ct1.get_local_range(2) * item_ct1.get_group_range(2); + const int idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + + item_ct1.get_local_id(2); + + if (item_ct1.get_local_id(2) < 256) + { + smem_code[item_ct1.get_local_id(2)] = code[item_ct1.get_local_id(2)]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (int i = idx;i < n; i += numThreads) + { + out[i] = smem_code[A[i]]; + } +} + + + +template + +void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = + (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f / (1.0f - sycl::pow(beta1, step)); + const float correction2 = 1.0f / (1.0f - sycl::pow(beta2, step)); + + typedef cub::BlockLoad Load; + typedef cub::BlockLoad LoadFloat; + typedef sycl::group<3> BlockReduce; + + union type_ct2 { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + }; + type_ct2 &temp_storage = *(type_ct2 *)temp_storage_ct1; + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + /* + DPCT1065:10: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:105: Migration of cub::BlockLoad.Load is not supported. + */ + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + /* + DPCT1065:11: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:106: Migration of cub::BlockLoad.Load is not supported. + */ + LoadFloat(temp_storage.loadf) + .Load(&(state1[i]), s1_vals, valid_items, 0.0f); + /* + DPCT1065:12: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:107: Migration of cub::BlockLoad.Load is not supported. + */ + LoadFloat(temp_storage.loadf) + .Load(&(state2[i]), s2_vals, valid_items, 0.0f); + +# pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case ADAM: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = + s1_vals[j] / (sycl::sqrt(s2_vals[j]) + eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + + # pragma unroll NUM_VALS-1 + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + /* + DPCT1065:13: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], + sycl::plus<>()); + + if (item_ct1.get_local_id(2) == 0) + dpct::atomic_fetch_add( + &unorm[0], s1_vals[0]); + + sycl::group_barrier(item_ct1.get_sub_group()); + } +} + + + +#define NUM_PER_THREAD 4 + +template + +void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = + (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - sycl::pow(beta1, step); + const float correction2 = sycl::sqrt(1.0f - sycl::pow(beta2, step)); + const float step_size = -lr*correction2/correction1; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? 
sycl::sqrt(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + typedef cub::BlockLoad Load; + typedef cub::BlockStore Store; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreFloat; + + union type_ct3 { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + }; + type_ct3 &temp_storage = *(type_ct3 *)temp_storage_ct1; + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * TH * NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + /* + DPCT1065:14: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:108: Migration of cub::BlockLoad.Load is not supported. + */ + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + /* + DPCT1065:15: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:109: Migration of cub::BlockLoad.Load is not supported. + */ + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + /* + DPCT1065:16: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:110: Migration of cub::BlockLoad.Load is not supported. + */ + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + /* + DPCT1065:17: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:111: Migration of cub::BlockLoad.Load is not supported. + */ + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + +# pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + switch(OPTIMIZER) + { + case ADAM: + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = + ((float)p_vals + [j]) + + (update_scale * + step_size * + (s1_vals + [j] / + (sycl::sqrt( + s2_vals + [j]) + + (eps * + correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + break; + } + } + + /* + DPCT1065:18: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:112: Migration of cub::BlockStore.Store is not supported. + */ + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + /* + DPCT1065:19: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. 
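+    The barrier is also required for correctness here: temp_storage is a
+    union, so the Store above must drain before the StoreFloat below reuses
+    the same shared storage.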
+ */ + item_ct1.barrier(); + /* + DPCT1007:113: Migration of cub::BlockStore.Store is not supported. + */ + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + /* + DPCT1065:20: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:114: Migration of cub::BlockStore.Store is not supported. + */ + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + } +} + +template + +void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = + (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef cub::BlockLoad Load; + typedef cub::BlockLoad LoadFloat; + typedef sycl::group<3> BlockReduce; + + union type_ct4 { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + }; + type_ct4 &temp_storage = *(type_ct4 *)temp_storage_ct1; + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; + + /* + DPCT1065:21: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:115: Migration of cub::BlockLoad.Load is not supported. + */ + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + /* + DPCT1065:22: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:116: Migration of cub::BlockLoad.Load is not supported. 
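+    This preconditioning pass only accumulates the squared update norm into
+    unorm[0]; the parameter update itself is applied by kOptimizer32bit1State.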
+ */ + LoadFloat(temp_storage.loadf) + .Load(&(state1[i]), s1_vals, valid_items, 0.0f); + +# pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update + s1_vals[j] = (float)g_vals[j] / + (sycl::sqrt(s1_vals[j]) + eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = (float)g_vals[j] / + (sycl::sqrt(s1_vals[j]) + eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + } + } + + # pragma unroll + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + /* + DPCT1065:23: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:24: Migration of cub::Sum is not supported. + */ + s1_vals[0] = + BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if (item_ct1.get_local_id(2) == 0) + dpct::atomic_fetch_add( + &unorm[0], s1_vals[0]); + + sycl::group_barrier(item_ct1.get_sub_group()); + } +} + +template + +void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = + (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sycl::sqrt(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef cub::BlockLoad Load; + typedef cub::BlockStore Store; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreFloat; + + union type_ct5 { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + }; + type_ct5 &temp_storage = *(type_ct5 *)temp_storage_ct1; + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * TH * NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + /* + DPCT1065:25: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. 
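+    Update rules applied in the switch below (scaled by update_scale where
+    max_unorm > 0):
+      MOMENTUM: s = g (step 1), else s = beta1*s + g;  p -= lr*s
+      LION:     p -= lr*sgn(beta1*s + (1-beta1)*g);    s = beta2*s + (1-beta2)*g
+      RMSPROP:  s = beta1*s + (1-beta1)*g*g;           p -= lr*g/(sqrt(s)+eps)
+      ADAGRAD:  s += g*g;                              p -= lr*g/(sqrt(s)+eps)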
+ */ + item_ct1.barrier(); + /* + DPCT1007:117: Migration of cub::BlockLoad.Load is not supported. + */ + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + /* + DPCT1065:26: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:118: Migration of cub::BlockLoad.Load is not supported. + */ + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + /* + DPCT1065:27: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:119: Migration of cub::BlockLoad.Load is not supported. + */ + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + +# pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + g_vals[j] = gnorm_scale*((float)g_vals[j]); + if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = + ((float)p_vals + [j]) - + update_scale * + (lr * + (float)g_vals + [j] / + (sycl::sqrt( + (float)s1_vals + [j]) + + eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = + ((float)p_vals + [j]) - + lr * + (float)g_vals + [j] / + (sycl::sqrt( + (float)s1_vals + [j]) + + eps); + break; + } + } + } + + /* + DPCT1065:28: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:120: Migration of cub::BlockStore.Store is not supported. + */ + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + /* + DPCT1065:29: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:121: Migration of cub::BlockStore.Store is not supported. 
+ */ + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} + + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template +void + +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + float *smem_quantiles1, float *smem_quantiles2) +{ + const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; + const int base_idx = + (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = + n - (item_ct1.get_group(2) * NUM_PER_BLOCK) > NUM_PER_BLOCK + ? NUM_PER_BLOCK + : n - (item_ct1.get_group(2) * NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadUInt8; + typedef sycl::group<3> BlockReduce; + + union type_ct6 { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + }; + type_ct6 &temp_storage = *(type_ct6 *)temp_storage_ct1; + + if (item_ct1.get_local_id(2) < 256) + { + smem_quantiles1[item_ct1.get_local_id(2)] = + quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles2[item_ct1.get_local_id(2)] = + quantiles2[item_ct1.get_local_id(2)]; + } + + /* + DPCT1065:42: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + for (unsigned int i = base_idx; i < n_full; + i += NUM_THREADS * item_ct1.get_group_range(2) * NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + /* + DPCT1007:129: Migration of cub::BlockLoad.Load is not supported. + */ + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + /* + DPCT1065:45: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:130: Migration of cub::BlockLoad.Load is not supported. + */ + LoadUInt8(temp_storage.loadc) + .Load(&(state1[i]), m_c1, valid_items, 128); + /* + DPCT1065:46: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:131: Migration of cub::BlockLoad.Load is not supported. + */ + LoadUInt8(temp_storage.loadc) + .Load(&(state2[i]), r_c2, valid_items, 128); + /* + DPCT1065:47: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. 
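+    8-bit optimizer states are stored as byte indices into a 256-entry
+    quantile table; the true value is smem_quantiles[c]*max[0]. The loops
+    below decode the states, apply the moment updates, and track the new
+    absolute maxima so the follow-up kernel can re-quantize against
+    new_max1/new_max2.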
+ */ + item_ct1.barrier(); + +#pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; + s1_vals[j] += (1.0f-beta1)*g_val; + local_max_s1 = sycl::fmax(local_max_s1, sycl::fabs(s1_vals[j])); + } + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; + s2_vals[j] += (1.0f-beta2)*g_val*g_val; + local_max_s2 = sycl::fmax(local_max_s2, sycl::fabs(s2_vals[j])); + } + + if(unorm != NULL) + { + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + float correction1 = 1.0f / (1.0f - sycl::pow(beta1, step)); + float correction2 = 1.0f / (1.0f - sycl::pow(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = + s1_vals[j] / (sycl::sqrt(s2_vals[j]) + eps); // update + local_unorm += update_val*update_val; + } + } + } + + /* + DPCT1065:43: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:48: Migration of cub::Reduce is not supported. + */ + local_max_s1 = BlockReduce(temp_storage.reduce) + .Reduce(local_max_s1, sycl::maximum<>(), valid_items); + /* + DPCT1065:44: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:49: Migration of cub::Reduce is not supported. + */ + local_max_s2 = BlockReduce(temp_storage.reduce) + .Reduce(local_max_s2, sycl::maximum<>(), valid_items); + if(unorm != NULL) + { + /* + DPCT1065:50: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:51: Migration of cub::Reduce is not supported. 
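+    sycl::reduce_over_group is not a drop-in replacement for this one:
+    cub's Reduce only consumes the first valid_items inputs, whereas
+    reduce_over_group would fold in every work-item of the group.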
+ */ + local_unorm = BlockReduce(temp_storage.reduce) + .Reduce(local_unorm, sycl::plus<>(), valid_items); + } + + if (item_ct1.get_local_id(2) == 0) + { + atomicMax(&new_max1[0], local_max_s1); + atomicMax(&new_max2[0], local_max_s2); + if (unorm != NULL) { + dpct::atomic_fetch_add( + &unorm[0], local_unorm); + } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + +template +void + +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, + float *smem_quantiles2, uint8_t *temp_storage_ct1) +{ + + const int n_full = + (item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) * + NUM_PER_THREAD2; + const int base_idx = + (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - sycl::pow(beta1, step); + const float correction2 = sycl::sqrt(1.0f - sycl::pow(beta2, step)); + const float step_size = -lr*correction2/correction1; + //const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val2 = 1.0f/new_max2[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sycl::sqrt((float)(unorm[0])) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + union type_ct7 { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + }; + type_ct7 &temp_storage = *(type_ct7 *)temp_storage_ct1; + + if (item_ct1.get_local_id(2) < 512) + { + if (item_ct1.get_local_id(2) < 256) + smem_quantiles1[item_ct1.get_local_id(2)] = + quantiles1[item_ct1.get_local_id(2)]; + else + smem_quantiles2[item_ct1.get_local_id(2) - 256] = + quantiles2[item_ct1.get_local_id(2) - 256]; + } + + /* + DPCT1065:52: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * NUM_THREADS2 * NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + /* + DPCT1007:132: Migration of cub::BlockLoad.Load is not supported. + */ + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + /* + DPCT1065:53: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. 
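+    The loop below decodes c1s/c2s through the quantile tables, applies the
+    Adam moment updates, then re-encodes with dQuantize using the reciprocal
+    scales new_max_val1/new_max_val2 produced by the preconditioning pass.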
+ */ + item_ct1.barrier(); + /* + DPCT1007:133: Migration of cub::BlockLoad.Load is not supported. + */ + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + /* + DPCT1065:54: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:134: Migration of cub::BlockLoad.Load is not supported. + */ + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + /* + DPCT1065:55: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:135: Migration of cub::BlockLoad.Load is not supported. + */ + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if ((i + (item_ct1.get_local_id(2) * NUM_PER_THREAD2) + + NUM_PER_THREAD2) > n) { + continue; + } + +# pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j]*max1[0]; + + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if (sycl::signbit(smem_quantiles1[c1s[j]]) != + sycl::signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j]*max2[0]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + p_vals[j] = (T)(((float)p_vals[j]) + + ((update_scale * step_size * + (s1_vals[j] / (sycl::sqrt(s2_vals[j]) + + (correction2 * eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + + /* + DPCT1007:136: Migration of cub::BlockStore.Store is not supported. + */ + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + /* + DPCT1065:56: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:137: Migration of cub::BlockStore.Store is not supported. + */ + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + /* + DPCT1065:57: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:138: Migration of cub::BlockStore.Store is not supported. + */ + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + /* + DPCT1065:58: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. 
+ */ + item_ct1.barrier(); + } +} + + +template +void + +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + float *smem_quantiles1) +{ + const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; + const int base_idx = + (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = + n - (item_ct1.get_group(2) * NUM_PER_BLOCK) > NUM_PER_BLOCK + ? NUM_PER_BLOCK + : n - (item_ct1.get_group(2) * NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadUInt8; + typedef sycl::group<3> BlockReduce; + + union type_ct8 { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + }; + type_ct8 &temp_storage = *(type_ct8 *)temp_storage_ct1; + + if (item_ct1.get_local_id(2) < 256) + smem_quantiles1[item_ct1.get_local_id(2)] = + quantiles1[item_ct1.get_local_id(2)]; + + /* + DPCT1065:30: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * NUM_THREADS * NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + /* + DPCT1065:32: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:122: Migration of cub::BlockLoad.Load is not supported. + */ + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + /* + DPCT1065:33: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:123: Migration of cub::BlockLoad.Load is not supported. + */ + LoadUInt8(temp_storage.loadc) + .Load(&(state1[i]), m_c1, valid_items, 128); + +#pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + if(unorm != NULL) + local_unorm += s1_vals[j]*s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + + local_max_s1 = sycl::fmax(local_max_s1, sycl::fabs(s1_vals[j])); + } + } + + /* + DPCT1065:31: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:34: Migration of cub::Reduce is not supported. 
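+    The block-wide maximum computed here is published via
+    atomicMax(&new_max1[0], ...) so that all blocks agree on the scale used
+    when the states are re-quantized.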
+ */ + local_max_s1 = BlockReduce(temp_storage.reduce) + .Reduce(local_max_s1, sycl::maximum<>(), valid_items); + if (item_ct1.get_local_id(2) == 0) { + atomicMax(&new_max1[0], local_max_s1); + } + if(unorm != NULL) + { + /* + DPCT1065:35: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:36: Migration of cub::Reduce is not supported. + */ + local_unorm = BlockReduce(temp_storage.reduce) + .Reduce(local_unorm, sycl::plus<>(), valid_items); + if (item_ct1.get_local_id(2) == 0) { + dpct::atomic_fetch_add( + &unorm[0], local_unorm); + } + } + +} + +template +void + +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, + const float gnorm_scale, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, + uint8_t *temp_storage_ct1) +{ + + const int n_full = + (item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) * + NUM_PER_THREAD2; + const int base_idx = + (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f/new_max1[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sycl::sqrt((float)(unorm[0])) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + union type_ct9 { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + }; + type_ct9 &temp_storage = *(type_ct9 *)temp_storage_ct1; + + if (item_ct1.get_local_id(2) < 256) + smem_quantiles1[item_ct1.get_local_id(2)] = + quantiles1[item_ct1.get_local_id(2)]; + + /* + DPCT1065:37: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * NUM_THREADS2 * NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + /* + DPCT1007:124: Migration of cub::BlockLoad.Load is not supported. + */ + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + /* + DPCT1065:38: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:125: Migration of cub::BlockLoad.Load is not supported. 
+ */ + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + /* + DPCT1065:39: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:126: Migration of cub::BlockLoad.Load is not supported. + */ + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if ((i + (item_ct1.get_local_id(2) * NUM_PER_THREAD2) + + NUM_PER_THREAD2) > n) { + continue; + } + +# pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + p_vals[j] = ((float)p_vals[j]) - + (lr * g_val / (sycl::sqrt(s1_vals[j]) + eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + if (sycl::signbit(smem_quantiles1[c1s[j]]) != + sycl::signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + /* + DPCT1007:127: Migration of cub::BlockStore.Store is not supported. + */ + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + /* + DPCT1065:40: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:128: Migration of cub::BlockStore.Store is not supported. + */ + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + /* + DPCT1065:41: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + } +} + +template +void kPercentileClipping(T *__restrict__ g, float *gnorm_vec, int step, + const int n, const sycl::nd_item<3> &item_ct1) +{ + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef cub::BlockLoad LoadT; + typename BlockReduce::TempStorage &reduce; + + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (item_ct1.get_group(2) * BLOCK_SIZE); i < n_full; + i += item_ct1.get_group_range(2) * BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_sum = 0.0f; + + /* + DPCT1065:75: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:151: Migration of cub::BlockLoad.Load is not supported. 
+ */ + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + +#pragma unroll NUM_VALS + for(int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j])*((float)vals[j]); + + /* + DPCT1007:76: Migration of cub::Sum is not supported. + */ + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if (item_ct1.get_local_id(2) == 0) + { + if(step == 1) + { + // initialize with the same norm for all positions + //#pragma unroll 10 + for(int j = 0; j < 100; j++) + dpct::atomic_fetch_add( + &gnorm_vec[j], local_sum); + } + else + dpct::atomic_fetch_add( + &gnorm_vec[step % 100], local_sum); + } + + } +} + + +#define LANES 2 +#define QUAD 3 +template + +void kOptimizerStatic8bit2StateBlockwise( + T *p, T *__restrict__ const g, unsigned char *state1, unsigned char *state2, + const float beta1, const float beta2, const float eps, const int step, + const float lr, float *__restrict__ const quantiles1, + float *__restrict__ const quantiles2, float *absmax1, float *absmax2, + float weight_decay, const float gnorm_scale, const bool skip_zeros, + const int n, const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + sycl::local_accessor smem_quantiles2, + float *smem_exchange1, + float *smem_exchange2, uint8_t *temp_storage_ct1) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = item_ct1.get_group_range(2) * BLOCK_SIZE; + const int base_idx = (item_ct1.get_group(2) * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + // 2-5% + const float correction1 = 1.0f - sycl::pow(beta1, step); + const float correction2 = sycl::sqrt(1.0f - sycl::pow(beta2, step)); + const float step_size = (-lr * correction2) / correction1; + const int lane_id = item_ct1.get_local_id(2) % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + + //__shared__ float smem_quantiles1[LANES][257]; + //__shared__ float smem_quantiles2[LANES][257]; + typedef cub::BlockReduce BlockReduce1; + typedef cub::BlockReduce BlockReduce2; + typename BlockReduce1::TempStorage reduce1; + typename BlockReduce2::TempStorage reduce2; + + + union type_ct10 { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + }; + type_ct10 &temp_storage = *(type_ct10 *)temp_storage_ct1; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][item_ct1.get_local_id(2)] = + quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles2[0][item_ct1.get_local_id(2)] = + quantiles2[item_ct1.get_local_id(2)]; +# pragma unroll + for(unsigned int j = 1; j < LANES; j++) + { + smem_quantiles1[j][item_ct1.get_local_id(2)] = + smem_quantiles1[0][item_ct1.get_local_id(2)]; + smem_quantiles2[j][item_ct1.get_local_id(2)] = + smem_quantiles2[0][item_ct1.get_local_id(2)]; + } + + /* + DPCT1065:59: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. 
+ */ + item_ct1.barrier(); + +#pragma unroll + for(int k = 0; k < QUAD; k++) + { + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + } + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + /* + DPCT1065:60: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:139: Migration of cub::BlockLoad.Load is not supported. + */ + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + /* + DPCT1065:61: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:140: Migration of cub::BlockLoad.Load is not supported. + */ + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + /* + DPCT1065:62: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:141: Migration of cub::BlockLoad.Load is not supported. + */ + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if (!sycl::isnan((float)g_vals[j]) && + !sycl::isinf((float)g_vals[j])) + { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + } + + new_local_abs_max1 = + sycl::fmax(new_local_abs_max1, sycl::fabs(s1_vals[j])); + new_local_abs_max2 = + sycl::fmax(new_local_abs_max2, sycl::fabs(s2_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = sycl::reduce_over_group( + item_ct1.get_group(), new_local_abs_max1, sycl::maximum<>()); + new_local_abs_max2 = sycl::reduce_over_group( + item_ct1.get_group(), new_local_abs_max2, sycl::maximum<>()); + + if (item_ct1.get_local_id(2) == 0) + { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + } + + /* + DPCT1065:63: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + + if (item_ct1.get_local_id(2) == 0) + { + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + } + else + { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + } + + /* + DPCT1065:64: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. 
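+        Migration note (an observation, not dpct output): unlike
+        cub::BlockReduce, sycl::reduce_over_group returns the reduced value
+        to every work-item, so the smem_exchange1/smem_exchange2 broadcast
+        above appears redundant in SYCL and could likely be dropped together
+        with this barrier.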
+ */
+        item_ct1.barrier();
+        /*
+        DPCT1007:142: Migration of cub::BlockLoad.Load is not supported.
+        */
+        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
+    // reduce: 2.67/1.69 -> 2.67/1.70
+ # pragma unroll N_PER_TH
+    for(unsigned int j = 0; j < N_PER_TH; j++)
+    {
+        //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
+        if (!sycl::isnan((float)g_vals[j]) &&
+            !sycl::isinf((float)g_vals[j]))
+        {
+          p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(s1_vals[j]/(sycl::sqrt(s2_vals[j])+(correction2*eps))))));
+          if(weight_decay > 0.0f)
+              p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
+        }
+    }
+
+    // store: 0.85/1.44 -> 2.48/1.57
+        /*
+        DPCT1065:65: Consider replacing sycl::nd_item::barrier() with
+        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+        better performance if there is no access to global memory.
+        */
+        item_ct1.barrier();
+        /*
+        DPCT1007:143: Migration of cub::BlockStore.Store is not supported.
+        */
+        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+
+    // quantization: 2.67/1.70 -> 3.4/3.3
+ # pragma unroll N_PER_TH
+    for(unsigned int j = 0; j < N_PER_TH; j++)
+    {
+        c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id],
+                                s1_vals[j] / new_local_abs_max1);
+        c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id],
+                                s2_vals[j] / new_local_abs_max2);
+
+        // make sure the state1 term still has the same sign after quantization
+        // (not needed for the state2 term, which has only positive values)
+        if (sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) !=
+            sycl::signbit(s1_vals[j]))
+        {
+          if(s1_vals[j] > 0.0f)
+              c1s[j] += 1;
+          else
+              c1s[j] -= 1;
+        }
+    }
+
+        /*
+        DPCT1065:66: Consider replacing sycl::nd_item::barrier() with
+        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+        better performance if there is no access to global memory.
+        */
+        item_ct1.barrier();
+        /*
+        DPCT1007:144: Migration of cub::BlockStore.Store is not supported.
+        */
+        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
+        /*
+        DPCT1065:67: Consider replacing sycl::nd_item::barrier() with
+        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+        better performance if there is no access to global memory.
+        */
+        item_ct1.barrier();
+        /*
+        DPCT1007:145: Migration of cub::BlockStore.Store is not supported.
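+        Until the Store call below is migrated, an equivalent blocked store
+        can be written by hand (a sketch, assuming the usual blocked external
+        arrangement of cub::BlockStore):
+
+            for (int j = 0; j < N_PER_TH; j++) {
+              int k = (item_ct1.get_local_id(2) * N_PER_TH) + j;
+              if (k < valid_items)
+                state2[i + k] = c2s[j];   // guarded, like the cub call
+            }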
+ */ + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + } +} + + +#define LANES 2 +#define QUAD 3 +template + +void kOptimizerStatic8bit1StateBlockwise( + T *p, T *__restrict__ const g, unsigned char *state1, const float beta1, + const float beta2, const float eps, const int step, const float lr, + float *__restrict__ const quantiles1, float *absmax1, float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + float *smem_exchange1, + uint8_t *temp_storage_ct1) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = item_ct1.get_group_range(2) * BLOCK_SIZE; + const int base_idx = (item_ct1.get_group(2) * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = item_ct1.get_local_id(2) % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + //__shared__ float smem_quantiles1[LANES][257]; + typedef cub::BlockReduce BlockReduce1; + __shared__ typename BlockReduce1::TempStorage reduce1; + //__shared__ float smem_exchange1[1]; + + union type_ct11 { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + }; + type_ct11 &temp_storage = *(type_ct11 *)temp_storage_ct1; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][item_ct1.get_local_id(2)] = + quantiles1[item_ct1.get_local_id(2)]; +# pragma unroll + for(unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][item_ct1.get_local_id(2)] = + smem_quantiles1[0][item_ct1.get_local_id(2)]; + + /* + DPCT1065:68: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + +#pragma unroll + for(int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + + for (unsigned int i = base_idx; i < n_full; + i += item_ct1.get_group_range(2) * BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + /* + DPCT1065:69: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:146: Migration of cub::BlockLoad.Load is not supported. + */ + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + /* + DPCT1065:70: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:147: Migration of cub::BlockLoad.Load is not supported. + */ + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + /* + DPCT1065:71: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:148: Migration of cub::BlockLoad.Load is not supported. 
+ */
+        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
+
+    new_local_abs_max1 = -FLT_MAX;
+
+    // update: 2.48/1.57 -> 2.51/1.60
+ # pragma unroll N_PER_TH
+    for(unsigned int j = 0; j < N_PER_TH; j++)
+    {
+      g_val = float(g_vals[j]);
+      g_val *= gnorm_scale;
+      if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
+      {
+        if(weight_decay > 0.0f) {
+          switch(OPTIMIZER) {
+            case MOMENTUM:
+            case ADAGRAD:
+            case RMSPROP:
+              g_val += ((float)p_vals[j])*weight_decay;
+              break;
+            case LION:
+              p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
+              break;
+          }
+        }
+
+        s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
+
+        switch(OPTIMIZER)
+        {
+            case MOMENTUM:
+              if(step == 1)
+                s1_vals[j] = g_val;
+              else
+                s1_vals[j] = (s1_vals[j]*beta1) + g_val;
+              break;
+            case LION:
+              // here, g_vals[j] is reused to hold the beta1-smoothed gradient for the
+              // parameter update below, before the momentum is updated with beta2
+              g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
+              s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+              break;
+            case RMSPROP:
+              s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
+              break;
+            case ADAGRAD:
+              s1_vals[j] = s1_vals[j] + (g_val*g_val);
+              break;
+        }
+      }
+
+      new_local_abs_max1 =
+          sycl::fmax(new_local_abs_max1, sycl::fabs(s1_vals[j]));
+    }
+
+
+    //  reduce: 2.51/1.60 -> 2.67/1.69
+        new_local_abs_max1 = sycl::reduce_over_group(
+            item_ct1.get_group(), new_local_abs_max1, sycl::maximum<>());
+
+        if (item_ct1.get_local_id(2) == 0)
+            smem_exchange1[0] = new_local_abs_max1;
+
+        /*
+        DPCT1065:72: Consider replacing sycl::nd_item::barrier() with
+        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+        better performance if there is no access to global memory.
+        */
+        item_ct1.barrier();
+
+        if (item_ct1.get_local_id(2) == 0)
+            absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
+        else
+            new_local_abs_max1 = smem_exchange1[0];
+
+    //  reduce: 2.67/1.69 -> 2.67/1.70
+ # pragma unroll N_PER_TH
+    for(unsigned int j = 0; j < N_PER_TH; j++)
+    {
+        if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
+        {
+          switch(OPTIMIZER)
+          {
+              case MOMENTUM:
+                p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
+                break;
+              case LION:
+                p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
+                break;
+              case RMSPROP:
+                g_val = g_vals[j];
+                p_vals[j] = ((float)p_vals[j]) - lr*(g_val/(sycl::sqrt(s1_vals[j])+eps));
+                break;
+              case ADAGRAD:
+                g_val = g_vals[j];
+                p_vals[j] = ((float)p_vals[j]) - lr*(g_val/(sycl::sqrt(s1_vals[j])+eps));
+                break;
+          }
+        }
+    }
+
+    // store: 0.85/1.44 -> 2.48/1.57
+        /*
+        DPCT1065:73: Consider replacing sycl::nd_item::barrier() with
+        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+        better performance if there is no access to global memory.
+        */
+        item_ct1.barrier();
+        /*
+        DPCT1007:149: Migration of cub::BlockStore.Store is not supported.
+ */ + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], + s1_vals[j] / new_local_abs_max1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if (sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) != + sycl::signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + /* + DPCT1065:74: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:150: Migration of cub::BlockStore.Store is not supported. + */ + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + } +} + +template void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, + const sycl::nd_item<3> &item_ct1, + uint8_t *temp_storage_ct1, + float *smem_row_absmax_values, + int *smem_row_nnz_values) +{ + // 0. reset stats to -FLT_MAX + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + // 2. compute col max (per thread); store in smem due to register pressure + // 3. compute row max (per block); store in smem to accumulate full global mem transation + // 4. store data via atomicMax + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = + ((item_ct1.get_group(2) * TILE_COLS) / tiledCols) * TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (item_ct1.get_group(2) * TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad LoadT; + typedef sycl::group<3> BlockRowReduce; + typedef sycl::group<3> BlockRowSum; + typedef cub::BlockExchange BlockExchange; + + union type_ct12 { + typename BlockExchange::TempStorage exchange; + typename BlockRowReduce::TempStorage rowreduce; + typename BlockRowSum::TempStorage rowsum; + typename LoadT::TempStorage loadt; + }; + type_ct12 &temp_storage = *(type_ct12 *)temp_storage_ct1; + + sycl::half local_data[ITEMS_PER_THREAD]; + float local_data_fp32[ITEMS_PER_THREAD]; + float local_col_absmax_values[ITEMS_PER_THREAD]; + int local_row_nnz_count = 0; + float row_absmax = -FLT_MAX; + + // 0. reset stats to -FLT_MAX + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_absmax_values[item_ct1.get_local_id(2) + (j * THREADS)] = -FLT_MAX; + smem_row_nnz_values[item_ct1.get_local_id(2) + (j * THREADS)] = 0; + } + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_col_absmax_values[j] = -FLT_MAX; + + /* + DPCT1065:79: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + int valid_items = cols - base_col > items_per_load ? 
items_per_load : cols - base_col; + int i = base_idx; + // we load row after row from the base_position + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row+row >= rows){ break; } + local_row_nnz_count = 0; + i = base_idx + ((row)*cols); + // each thread gets data from the same column + /* + DPCT1065:81: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:154: Migration of cub::BlockLoad.Load is not supported. + */ + LoadT(temp_storage.loadt) + .Load(&(A[i]), local_data, valid_items, + sycl::vec{0.0f} + .convert()[0]); + +#pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = sycl::fabs(local_data[j]); + + if(SPARSE_DECOMP) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + if((float)local_data[j] >= nnz_threshold) + { + local_row_nnz_count += 1; + local_data[j] = 0.0f; + } + } + + // 2. compute col max (per thread); store in smem due to register pressure + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + // take the col max for this row + // we use shared memory because register pressure is too high if we do this locally + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j])); + local_col_absmax_values[j] = + sycl::fmax(local_col_absmax_values[j], + sycl::vec{local_data[j]} + .convert()[0]); + + // 3. compute row max (per block); store in smem to accumulate full global mem transation + + // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data_fp32[j] = local_data[j]; + + /* + DPCT1065:82: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + row_absmax = (float)dpct::group::reduce(item_ct1, local_data_fp32, + sycl::maximum<>()); + if(SPARSE_DECOMP) + { + /* + DPCT1065:84: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + local_row_nnz_count = sycl::reduce_over_group( + item_ct1.get_group(), local_row_nnz_count, sycl::plus<>()); + } + // we store the data temporarily in shared memory so we + // can execute a full atomic block transaction into global memory later + // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores + if (item_ct1.get_local_id(2) == 0) + { + smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; + // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block + smem_row_nnz_values[row] = local_row_nnz_count; + } + + /* + DPCT1065:83: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + } + + // 4. 
store data via atomicMax + // to store col data efficienctly we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0 + // into a striped arangement: [0, 8, 16, 24, ..] for t0 + /* + DPCT1065:80: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:155: Migration of cub::BlockExchange.BlockedToStriped is not + supported. + */ + BlockExchange(temp_storage.exchange) + .BlockedToStriped(local_col_absmax_values); + +#pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if (base_col + item_ct1.get_local_id(2) + (j * THREADS) < cols) + { + float val = + colStats[base_col + (item_ct1.get_local_id(2) + (j * THREADS))]; + if(val < local_col_absmax_values[j]) + atomicMax( + &colStats[base_col + (item_ct1.get_local_id(2) + (j * THREADS))], + local_col_absmax_values[j]); + } + + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if (base_row + item_ct1.get_local_id(2) + (j * THREADS) < rows) + { + float val = + rowStats[base_row + (item_ct1.get_local_id(2) + (j * THREADS))]; + if (val < + smem_row_absmax_values[item_ct1.get_local_id(2) + (j * THREADS)]) + atomicMax( + &rowStats[base_row + (item_ct1.get_local_id(2) + (j * THREADS))], + smem_row_absmax_values[item_ct1.get_local_id(2) + (j * THREADS)]); + } + + if(SPARSE_DECOMP) + if (item_ct1.get_local_id(2) < TILE_ROWS) + nnz_count_row[item_ct1.get_group(2) * TILE_ROWS + + item_ct1.get_local_id(2) + 1] = + smem_row_nnz_values[item_ct1.get_local_id(2)]; +} + +template void kgetColRowStats( + sycl::half *__restrict__ A, float *rowStats, float *colStats, + int *nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, + int tiledCols, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + float *smem_row_absmax_values, int *smem_row_nnz_values); +template void kgetColRowStats( + sycl::half *__restrict__ A, float *rowStats, float *colStats, + int *nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, + int tiledCols, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + float *smem_row_absmax_values, int *smem_row_nnz_values); + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template +void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, + float *__restrict__ const colStats, sycl::half *out, float *newRowStats, + float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, + const int numCols, const int tileCols, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_rowStats) +{ + + // Strategy: To dequantize we need to load col/row statistics. This can be very expensive + // since different row/col stats need to be loaded with each thread. + // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure + // and would lead to low global load utilization. + // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads + // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. + // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. + // This allows for efficient row/col loading from shared memory within the tile. 
+ // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has + // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts + // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the + // shared memory loads. + + // data is in 32 column-tile major with tile width 32 columns and numRows rows + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) + // C2. Compute normalization values and store col values in register + // S1. Store C1 into 16-bit output + // S2. Store col/row statistics of new buffer in shared memory + + // We allow for sub-tiles to span multiple col32 tiles. This is okay + // since the items per thread only rely on a single column statistic. + + + const int n_out = numRows*numCols; + + int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); + // we have tiles of size numRows*32, thus col only increases every numRows + // num_row_tiles is the tiles after which the column increases by 32 + // blockIdx.x is the index of the current tile + int col = ((item_ct1.get_local_id(2) % 32) + + ((item_ct1.get_group(2) / num_row_tiles) * 32)); + // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached + int base_row = + (item_ct1.get_group(2) * SUBTILE_ROWS) % (num_row_tiles * SUBTILE_ROWS); + + // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS + // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD + // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. + // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have + // 1024*1024/(128*32) = 256 tiles + // 256 tiles are 256*128*32/4 = 256*1024 threads + + // 1. Figure out how index relates to the start of the sub-tile + // 2. Each thread < SUBTILE_ROWS calculates row index + // 3. Load striped and store in shared memory + + int local_values[ITEMS_PER_THREAD]; + sycl::half local_output[ITEMS_PER_THREAD]; + float local_rowStats[ITEMS_PER_THREAD]; + + typedef cub::BlockLoad LoadInt32; + typedef cub::BlockExchange ExchangeInt32; + + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + float colStat = col >= numCols ? 0.0f : colStats[col]; + float local_biasValue = + ((bias == NULL) || (col >= numCols)) + ? 0.0f + : sycl::vec{bias[col]} + .convert()[0]; + // no block loads for rows for now -- keep it simple + for (int j = item_ct1.get_local_id(2); j < SUBTILE_ROWS; + j += item_ct1.get_local_range(2)) + { + // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? + int row = (base_row+j) % numRows; // wrap around + // each warp accesses the same element, for four consequitive elements + // todo: update description about striped shared memory, it is not needed + // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements + smem_rowStats[j] = rowStats[row]; + } + /* + DPCT1065:78: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. 
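+  For reference, the dequantization performed after this barrier computes
+  out = A_int32 * rowStat * colStat / (127*127) + bias; e.g. an int32
+  accumulator of 16129 with rowStat = colStat = 1.0 maps back to
+  16129 * 6.200012e-05 ~= 1.0 before the bias is added.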
+ */ + item_ct1.barrier(); + + // each block processes SUBTILE_ROWS*32 elements + const int items_per_load = THREADS*ITEMS_PER_THREAD; + const int rows_per_load = items_per_load/32; + + int subtile_base_row = + (item_ct1.get_local_id(2) / 32) * ITEMS_PER_THREAD; // row within the tile + int row_offset = 0; + // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed + int subtile_start = (item_ct1.get_group(2) / num_row_tiles) * (numRows * 32) + + (base_row * 32); + for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) + { + int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); + int valid_items = valid_rows*32; + if(valid_items <= 0) // the sub-tile might have more elements than the tile itself + break; + + // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + /* + DPCT1007:152: Migration of cub::BlockLoad.Load is not supported. + */ + LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + /* + DPCT1007:153: Migration of cub::BlockExchange.BlockedToWarpStriped is not + supported. + */ + ExchangeInt32(exchangeint32) + .BlockedToWarpStriped(local_values, local_values); + +#pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = + sycl::vec{((local_values[j] * MM_DEQUANT_CONST * + local_rowStats[j] * colStat) + + local_biasValue)} + .convert()[0]; + //absmax_col = fmax(fabsf(local_output[j]), absmax_col); + + // we store data in row major + // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] + // so that each thread holds ITEMS_PER_THREAD consecutive items for each row + // this way throughput into storage is increased by a factor of ~2x + // for now we use a simple store + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); + if(outIdx< n_out && col < numCols) + out[outIdx] = local_output[j]; + } + + row_offset += rows_per_load; + } +} + +template +void kDoubleRowColQuant(sycl::half *__restrict__ const A, + float *__restrict__ const rowStats, + float *__restrict__ const colStats, + char *out_col_normed, char *out_row_normed, int *rowidx, + int *colidx, sycl::half *val, + int *__restrict__ nnz_block_ptr, float threshold, + int rows, int cols, int tiledCols, + const sycl::nd_item<3> &item_ct1, + float *smem_row_stats, unsigned int *smem_nnz_row_idx) +{ + // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD + // Each thread reads the same column but multiple rows + // Rows are loaded in shared memory and access is shared across the threadblock (broadcast) + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + // 2. quantize data with row/col stats + // 3. 
Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance) + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = + ((item_ct1.get_group(2) * TILE_COLS) / tiledCols) * TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (item_ct1.get_group(2) * TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad + LoadHalf; + + typedef cub::BlockStore StoreInt8; + + sycl::half local_data[ITEMS_PER_THREAD]; + float local_col_stats[ITEMS_PER_THREAD]; + char local_quantized_data[ITEMS_PER_THREAD]; + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if (base_col + (item_ct1.get_local_id(2) * ITEMS_PER_THREAD) + j < cols) + local_col_stats[j] = + 127.0f / colStats[base_col + + (item_ct1.get_local_id(2) * ITEMS_PER_THREAD) + j]; + + for (int i = item_ct1.get_local_id(2); i < TILE_ROWS; + i += item_ct1.get_local_range(2)) + { + if(base_row + i < rows) + smem_row_stats[i] = rowStats[base_row+i]; + + if(SPARSE_DECOMP) + smem_nnz_row_idx[i] = + nnz_block_ptr[(TILE_ROWS * item_ct1.get_group(2)) + i]; + } + /* + DPCT1065:85: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + // we load row after row from the base_position + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row + row >= rows){ break; } + int i = base_idx + (row*cols); + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + + /* + DPCT1007:156: Migration of cub::BlockLoad.Load is not supported. + */ + LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); + float row_stat = 127.0f / smem_row_stats[row]; + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + if(SPARSE_DECOMP) + { + if (sycl::fabs((float)local_data[j]) >= threshold) + { + local_quantized_data[j] = 0; + + int old_idx = + dpct::atomic_fetch_compare_inc< + sycl::access::address_space:: + generic_space>( + &smem_nnz_row_idx[row], + UINT_MAX); + + rowidx[old_idx] = base_row+row; + colidx[old_idx] = + base_col + (item_ct1.get_local_id(2) * ITEMS_PER_THREAD) + j; + val[old_idx] = local_data[j]; + } + else + { + local_quantized_data[j] = + (char)(sycl::rint( + sycl::vec{ + local_data[j]} + .convert< + float, + sycl::rounding_mode:: + automatic>()[0] * + row_stat)); + } + } + else + local_quantized_data[j] = (char)(sycl::rint( + sycl::vec{local_data[j]} + .convert()[0] * + row_stat)); + } + + /* + DPCT1007:157: Migration of cub::BlockStore.Store is not supported. + */ + StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, + valid_items); + + // 2. 
quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + local_quantized_data[j] = (char)(sycl::rint( + sycl::vec{local_data[j]} + .convert()[0] * + local_col_stats[j])); + } + + /* + DPCT1065:86: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:158: Migration of cub::BlockStore.Store is not supported. + */ + StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, + valid_items); + } +} + +template void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1,char *smem_data) +{ + + // 0. Load data into 32*32 shared memory tiles + // 1. transpose / reorder in shared memory + // 2. store + + // COL32 FORMAT: + // rows*32 tiles + + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... 
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + + + // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values + // As such we need: + // at least 32*4 shared memory tiles for col32; preferably 32*32 + // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 + // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 + // for efficient loading of row major we need to load 128 elements and repeat this 32 items + // this would imply a 32x128 shared memory tile -> 4kb + // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb + // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy + // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough + // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM + // + // to make the shared memory work with that occupancy we might need to union the block loads/stores + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = + ((item_ct1.get_group(2) * TILE_COLS) / tiledCols) * TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (item_ct1.get_group(2) * TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + + // we load 128 bytes per warp with + // 32 rows for transposes that fill col32 types + // so that we can have contiguous stores + + char local_data[ITEMS_PER_THREAD]; + typedef cub::BlockExchange BlockExchange; + + // we load row after row from the base_position + // Load data row by row + int warps = item_ct1.get_local_range(2) / 32; + int warp_id = item_ct1.get_local_id(2) / 32; + int warp_lane = item_ct1.get_local_id(2) % 32; + int offset = 0; + + int smem_row = 0; + // each warp loads one row of 128 bytes + for(int row = warp_id; row < TILE_ROWS; row+=warps) + { + int i = base_idx + (row*cols); + // we load up to 128 bytes/items per load + int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; + + // 0. Load data into 32*32 shared memory tiles + if(base_row + row < rows) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int col_idx = warp_lane+(j*32); + if(col_idx < valid_items) + local_data[j] = A[i+col_idx]; + else + local_data[j] = 0; + } + } + else + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = 0; + } + + if(TRANSPOSE) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int local_col = (32*j)+warp_lane; + //int local_row = row; + // store as 256x32 + smem_data[(local_col*33) + row] = local_data[j]; + } + } + else + { + // treat smem as 32x256, that is 32 rows and 256 columns + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; + } + + + + smem_row += warps; + + // 1. transpose / reorder in shared memory + if(smem_row % 32 == 0) + { + smem_row = 0; + /* + DPCT1065:87: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. 
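+  This barrier only has to make the freshly filled smem_data tile visible to
+  all warps before the reorder pass below reads it, so the local-space
+  variant suggested above would likely suffice here.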
+ */ + item_ct1.barrier(); + + for(int subrow = warp_id; subrow < 32; subrow+=warps) + { + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + + switch(FORMAT) + { + case COL32: + if(TRANSPOSE) + { + // data lies in shared memory in the following way: + // row0 [col0 col1 ... col31] + // row1 [col0 col1 ... col31] + // ... + // + // As such we read consequtive entries with 256 threads (8rows x 32 columns) + // as j increase, the row increase by a factor of 8 + // We load 8 rows per subrow loop, and subrow increase by 8 per loop + // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size outRows*32 and base_row is done in increments of 32 + offset = base_row*outRows; + out[offset + (base_col + jrow + subrow_loop_row) * 32 + + item_ct1.get_local_id(2)] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + offset = (base_col/32)*(32*rows); + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; + } + } + break; + case COL_TURING: + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // + // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] 
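+              //
+              // worked example (illustrative, not from the original source):
+              // subtile row 3 (odd), subtile column 5 lands at local offset
+              // 128 + (5/4)*16 + (5%4) + ((3-1)*2) = 128 + 16 + 1 + 4 = 149,
+              // i.e. odd rows occupy the upper 128-element half of the 8*32 tile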
+ if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 8*32 = 256 elements offset + // for each row offset of 8 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 256*outRows/8*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + // since we process even number of rows with each j (8) and with each subrow (8j) we can determine + // odd or even rows with the warp_id (each warp processes one row) + // the col is warp_lane (max 32 columns per row) and the row warp_id + if(warp_id % 2 == 1) + // odd + offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); + else + // even + offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); + + out[offset] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + // set offset designates the tile offset among the 8*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 8*32=256 every 8 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) + // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd + // each of these has 32 values in total for 32*4 = 128 as offset if odd + // every set of 4 columns increases the total offset by 16 + // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 + // this happends every 8 rows anew (subrow % 8) + // one writes 4 columns at once that is (col % 4) for the particular index in the subtile + int subcol = warp_lane; + + // add local offset (4x4 sub-tile) + if(subrow % 2 == 1) + // odd + offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); + else + // even + offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); + + out[offset] = data; + } + } + break; + case COL_AMPERE: + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... 
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 32*32 = 1024 elements offset + // for each row offset of 32 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 1024*outRows/32*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + + // same as in the non-transpose case (see below) + // the difference is that now rows = cols + // in this case warp_id = subrow + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset + int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane + out[offset + (ampere_row*32) + warp_lane] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + + // set offset designates the tile offset among the 32*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 32*32=1024 every 32 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile) + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] 
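+              // worked example (illustrative): subrow 2 maps to row
+              // ((2%8)/2)*8 + (2/8)*2 + (2%2) = 8 and subrow 3 to 9,
+              // reproducing the [0 1 8 9 16 17 24 25] interleaving above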
+ int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx + out[offset + (local_row*32) + warp_lane] = data; + } + } + break; + } + } + } + } + } +} + +#define DENORM 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +template +void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, + int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, T *B, sycl::half *out, + float *__restrict__ const dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[item_ct1.get_group(2)]; + const int local_max_idx = max_idx[item_ct1.get_group(2)]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = item_ct1.get_local_id(2) / 32; + const int warp_idx = item_ct1.get_local_id(2) % 32; + const int warp_offset = (warp_id*32)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + sycl::half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + sycl::half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + sycl::half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = + j < count + ? values[offset + j] + : sycl::vec{0.0f} + .convert()[0]; + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for (int i = item_ct1.get_local_id(2); i < SMEM_SIZE; + i += item_ct1.get_local_range(2)) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + + /* + DPCT1065:77: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = + reinterpret_cast( + B)[(row_offset + idx) / num_items]; + else + reinterpret_cast(local_valsB)[0] = + reinterpret_cast( + B)[(row_offset + idx) / num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = + reinterpret_cast(out)[idx_val / num_items]; + +#pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val / num_items] = + reinterpret_cast( + local_valC)[j / num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += item_ct1.get_local_range(2) * SPMM_ITEMS; + local_idx_col_B_offset += item_ct1.get_local_range(2) * SPMM_ITEMS; + } +} + +template void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, + const sycl::nd_item<3> &item_ct1) +{ + int local_colidx = idx[item_ct1.get_group(2)]; + + if(FORMAT==COL_TURING) + { + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // columns are grouped in increments of 4, meaning that one has the following rows and columns + // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] 
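+    // note: ((rowsA+7)/8) below is a ceiling division that pads the row count
+    // to the next multiple of 8, so partially filled 8*32 tiles keep a
+    // consistent per-column-tile stride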
+ + // each thread reads 1 element = 1 row + for (int row = item_ct1.get_local_id(2); row < rowsA; + row += item_ct1.get_local_range(2)) + { + int offset_per_col_tile = ((rowsA+7)/8)*32*8; + int tile_offset_rows = (row/8)*32*8; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int offset = 0; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 8; + if(row % 2 == 1) + offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); + else + // even + offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); + + offset += tile_offset_rows + tile_offset_cols; + + char val = A[offset]; + + int out_idx = (row * idx_size) + item_ct1.get_group(2); + out[out_idx] = val; + } + } + else if(FORMAT == COL_AMPERE) + { + + for (int row = item_ct1.get_local_id(2); row < rowsA; + row += item_ct1.get_local_range(2)) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile. + int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = A[offset]; + int out_idx = (row * idx_size) + item_ct1.get_group(2); + out[out_idx] = val; + } + } +} + + +//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +//// element-wise kernel +//// 1. Load batch x k into registers +//// 2. Load k x k into registers +//// 3. dequantize and store in second pair of k x k +//// 4. matmul +//// 5. sum with cub +//// 6. store outputs +//// TC kernel +//// use k warps per thread block +//// 1. threadblock use read-only cache to read in register tile for A into shared memory +//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +//// 3. each warp reads a segment of values 16x32 from B +//// 4. do dequantization from register of B into second pair of registers +//// 5. store (4) into fragment +//// 6. matmul aggregate into fragment C +//// 7. aggreecate files of C into shared memroy block C +//// 8. sum (7) +//// 9. 
write outputs to matmul output matrix +//} + +template inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) +{ + if(limit_base + ITEMS <= limit) + reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; + else + { + for(int k = 0; k < ITEMS; k++) + { + if(limit_base + k < limit) + local[k] = buffer[idx+k]; + else + local[k] = (T)zero_value; + } + } +} + +#define WARPS 3 +template void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + +#if DPCT_COMPATIBILITY_TEMP >= 750 + //using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 
1 : 0;
+  for(int k = 0; k < batch_size_warps; k++)
+  {
+    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+  }
+
+  // 129 mu
+  if(warp_id == (WARPS-1))
+    wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
+
+  if(col_offset + warp_lane < M)
+    out[col_offset + warp_lane] = smem_A[warp_lane];
+#endif
+}
+
+
+template <typename T> void printnonzero(T *A, int num_values, const char * strval,
+                                        const sycl::stream &stream_ct1)
+{
+  for(int i = 0; i < num_values; i++)
+    if((float)A[i] != 0.0)
+      stream_ct1 << "Strval " << strval << " index " << i << " " << (float)A[i] << sycl::endl;
+}
+
+template void printnonzero<float>(float *A, int num_values, const char*strval,
+                                  const sycl::stream &stream_ct1);
+template void printnonzero<sycl::half>(sycl::half *A, int num_values,
+                                       const char *strval,
+                                       const sycl::stream &stream_ct1);
+
+static dpct::global_memory<float, 1> nf4_data(
+    sycl::range<1>(16),
+    {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
+     -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
+     0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
+     0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
+     0.7229568362236023, 1.0});
+template <typename T, int THREADS> void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+
+#if DPCT_COMPATIBILITY_TEMP >= 750
+  using namespace nvcuda;
+  int col_offset = blockIdx.x *32;
+  const int warp_id = threadIdx.x / 32;
+  const int warp_idx = threadIdx.x % 32;
+  const int half_warp_id = threadIdx.x / 16;
+  const int half_warp_lane = threadIdx.x % 16;
+  const int batch_size_warps = (WARPS-1)*2;
+
+  T quant_map[16];
+
+  #pragma unroll 16
+  for(int i = 0; i < 16; i++)
+    quant_map[i] = nf4_data[i];
+  //__shared__ T quant_map[16*160];
+
+  T local_A[2];
+  T local_B[64];
+  unsigned char local_B_4bit[32];
+
+
+  const int a_tile_offset = 16;
+  const int b_tile_offset = (16*32 + 16);
+
+  __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))];
+  __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
+  __shared__ T smem_C[8*32];
+
+  wmma::fragment a_frag;
+  wmma::fragment b_frag;
+  wmma::fragment c_frag;
+  wmma::fill_fragment(c_frag, 0.0f);
+
+  for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
+    smem_C[i] = 0.0f;
+
+  __syncthreads();
+
+  int ticktock = 0;
+  int idx = 0 + threadIdx.x;
+  int loaded_values = 0;
+  // prefetch
+  if(idx < K && warp_id < (WARPS-1))
+  {
+    if(loaded_values == 0)
+    {
+      local_A[0] = A[idx];
+      local_A[1] = A[idx+blockDim.x-32];
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        local_B_4bit[col] = B[(col_offset+col)*ldb+idx];
+
+      loaded_values = 1;
+    }
+    else
+    {
+      local_A[0] = local_A[1];
+      loaded_values--;
+
+      #pragma unroll 64
+      for(int col = 0; col < 64; col+=2)
+      {
+        //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
+        //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
+        //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
+        //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
+        //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
+        //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
+
+        //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
+        //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
+        local_B[col] = 
quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); + + //__syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + } + //printnonzero(local_B, 128, ""); + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); + + //printnonzero(smem_C, 32, ""); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_C[warp_lane]; +#endif +} + +#define num_values_4bit 32 +template +void kgemm_4bit_inference_naive(int M, int N, int K, T *__restrict__ const A, + unsigned char *B, float *absmax, + const float *datatype, T *out, int lda, int ldb, + int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, T *quant_map) +{ + + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + + const int warp_idx = item_ct1.get_local_id(2) / 32; + const int warp_lane = item_ct1.get_local_id(2) % 32; + const int row_B = (THREADS / 32) * item_ct1.get_group(2) + warp_idx; + const int num_values_8bit = num_values_4bit/2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit]; + T local_A[num_values_4bit]; + + T local_absmax = T(0.0f); + + for (int i = item_ct1.get_local_id(2); i < 16; i++) + quant_map[i] = T(datatype[i]); + /* + DPCT1065:88: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier(); + + // A: [1, K] + // B: [N, K] + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) + { + int inner_idx_halved = inner_idx/2; + int offset_B = ldb*row_B; + int absidx = ((2*offset_B)+inner_idx)/blocksize; + /* + DPCT1026:159: The call to __ldg was removed because there is no + corresponding API in SYCL. 

+  */
+    local_absmax = absmax[absidx];
+
+    if(row_B < M)
+    {
+      if((inner_idx_halved + num_values_8bit) < (K/2))
+      {
+        // this is the most important for performance considerations
+        reinterpret_cast(local_B_4bit)[0] =
+            reinterpret_cast(
+                B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
+      }
+      else
+      {
+        #pragma unroll
+        for(int j = 0; j < (num_values_8bit); j++)
+          if((inner_idx_halved) + j < (K/2))
+            local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
+          else
+            local_B_4bit[j] = 0b01110111;
+      }
+    }
+    else
+    {
+      #pragma unroll
+      for(int j = 0; j < (num_values_8bit); j++)
+        local_B_4bit[j] = 0b01110111;
+    }
+
+    #pragma unroll
+    for(int k = 0; k < num_values_8bit; k++)
+    {
+#if DPCT_COMPATIBILITY_TEMP >= 800
+      local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
+      local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+      #else
+      // bf16 multiplication not supported
+      local_B[k*2] = T((float)quant_map[local_B_4bit[k] >> 4]*(float)local_absmax);
+      local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[k] & 0x0F]*(float)local_absmax);
+      #endif
+    }
+
+    if(inner_idx+num_values_4bit < K)
+    {
+      // this is also relatively important for performance
+      if(BITS==16)
+      {
+        reinterpret_cast(local_A)[0] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 4) + 0];
+        reinterpret_cast(local_A)[1] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 4) + 1];
+        reinterpret_cast(local_A)[2] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 4) + 2];
+        reinterpret_cast(local_A)[3] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 4) + 3];
+      }
+      else
+      {
+        reinterpret_cast(local_A)[0] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 8) + 0];
+        reinterpret_cast(local_A)[1] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 8) + 1];
+        reinterpret_cast(local_A)[2] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 8) + 2];
+        reinterpret_cast(local_A)[3] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 8) + 3];
+        reinterpret_cast(local_A)[4] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 8) + 4];
+        reinterpret_cast(local_A)[5] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 8) + 5];
+        reinterpret_cast(local_A)[6] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 8) + 6];
+        reinterpret_cast(local_A)[7] =
+            reinterpret_cast(
+                A)[inner_idx / (num_values_4bit / 8) + 7];
+      }
+
+    }
+    else
+      #pragma unroll
+      for(int k = 0; k < num_values_4bit; k++)
+        if(inner_idx + k < K)
+          local_A[k] = A[inner_idx + k];
+        else
+          local_A[k] = T(0.0f);
+
+
+    // accumulate in float; small performance hit for Ampere, but lower error for outputs
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit; k++)
+    {
+#if DPCT_COMPATIBILITY_TEMP >= 800
+      local_C += (float)(local_A[k]*local_B[k]);
+      #else
+      // bf16 multiplication not supported
+      local_C += ((float)local_A[k]*(float)local_B[k]);
+      #endif
+    }
+  }
+
+  local_C = sycl::reduce_over_group(item_ct1.get_sub_group(), local_C,
+                                    sycl::plus<>());
+
+  if(row_B < M && warp_lane == 0)
+    out[row_B] = T(local_C);
+
+}
+
+
+//#define ROWS 2
+//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
+//{
+//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
+//// 1. Load dataB into register
+//// 2. Dequantize B
+//// 3. 
Fetch data from A and multiply +// +// typedef cub::BlockLoad LoadA; +// //__shared__ typename LoadA::TempStorage loada; +// typedef cub::BlockLoad LoadB; +// //__shared__ typename LoadB::TempStorage loadb; +// typedef cub::BlockReduce BlockReduce; +// // Allocate shared memory for BlockReduce +// //__shared__ typename BlockReduce::TempStorage reduce; +// +// __shared__ union { +// typename BlockReduce::TempStorage reduce; +// typename LoadB::TempStorage loadb; +// typename LoadA::TempStorage loada; +// } temp_storage; +// +// +// T dataA[ITEMS]; +// T local_B[ITEMS]; +// T local_accC[ROWS]; +// int valid_items = 0; +// const int col_offset = blockIdx.x * 8; +// +// __shared__ T tileA[ROWS*THREADS*ITEMS]; +// __shared__ T accumulatorC[ROWS*8]; +// +// //#pragma unroll 8 +// //for(int i = 0; i < 8; i++) +// // tileA[threadIdx.x + (i*256)] = 0.0f; +// //__syncthreads(); +// if(threadIdx.x < 64) +// accumulatorC[threadIdx.x] = 0.0f; +// __syncthreads(); +// +// +// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) +// { +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// int baserow = 0; +// for(int row = baserow; row < (baserow+ROWS) && row < N; row++) +// { +// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); +// +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; +// +// __syncthreads(); +// } +// baserow += ROWS; +// +// // load 16 columns from B at a time. B is transposed, so its like loading rows +// // each warp loads one row +// // each thread loads 128 byte +// +// // col: inner_idx + warp_lane +// // row: ldb*(offset + warp_id) +// for(int col = 0; col < 8 && (col_offset + col) < M; col++) +// { +// int colB = col_offset + col; +// +// for(int k = 0; k < ROWS; k++) +// local_accC[k] = 0.0f; +// +// int base_idxB = ldb*colB; +// valid_items = K - inner_idx > THREADS*ITEMS ? 
THREADS*ITEMS : K - inner_idx; +// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); +// __syncthreads(); +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// { +// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); +// local_accC[row] += tileA[idxA]*local_B[k]; +// } +// +// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); +// if(threadIdx.x == 0) +// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); +// } +// } +// } +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// int out_idx = ldc*row + col_offset; +// +// //if(threadIdx.x < 8) +// // if(accumulatorC[row*8 + threadIdx.x] != 0.0) +// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); +// +// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) +// { +// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); +// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; +// } +// } +// +// +// +//} + + +template void kfunc(T *A, T *B, T value, long n, + const sycl::nd_item<3> &item_ct1) +{ + for (long i = (item_ct1.get_local_range(2) * item_ct1.get_group(2)) + + item_ct1.get_local_id(2); + i < n; i += (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void kfunc(float *A, float *B, float value, long n, + const sycl::nd_item<3> &item_ct1); +template void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n, + const sycl::nd_item<3> &item_ct1); +template void kfunc(float *A, float *B, float value, long n, + const sycl::nd_item<3> &item_ct1); +template void kfunc(float *A, float *B, float value, long n, + const sycl::nd_item<3> &item_ct1); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, 
sycl::half *out, + int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); +template void gemm_device(int M, int N, int K, + sycl::half *__restrict__ const A, + sycl::half *B, sycl::half *out, + int lda, int ldb, int ldc); + +template void kgemm_4bit_inference( + int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, + float *absmax, sycl::half *out, int lda, int ldb, int ldc, int blocksize); +template void kgemm_4bit_inference( + int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, + float *absmax, sycl::half *out, int lda, int ldb, int ldc, int blocksize); +template void kgemm_4bit_inference( + int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, + float *absmax, sycl::half *out, int lda, int ldb, int ldc, int blocksize); +template void kgemm_4bit_inference( + int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, + float *absmax, sycl::half *out, int lda, int ldb, int ldc, int blocksize); + +template void kgemm_4bit_inference_naive( + int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, + float *absmax, const float *datatype, sycl::half *out, int lda, int ldb, + int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, + sycl::half *quant_map); +template void kgemm_4bit_inference_naive( + int M, int N, int K, oneapi::mkl::bfloat16 *__restrict__ const A, + unsigned char *B, float *absmax, const float *datatype, + oneapi::mkl::bfloat16 *out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, oneapi::mkl::bfloat16 *quant_map); +template void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + float *quant_map); + +template void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, + const sycl::nd_item<3> &item_ct1); +template void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, + const sycl::nd_item<3> &item_ct1); + +template void kspmm_coo_very_sparse_naive( + int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, sycl::half *B, 
sycl::half *out, float *dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template void kspmm_coo_very_sparse_naive( + int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, sycl::half *B, sycl::half *out, float *dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template void kspmm_coo_very_sparse_naive( + int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, sycl::half *B, sycl::half *out, float *dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template void kspmm_coo_very_sparse_naive( + int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, signed char *B, sycl::half *out, float *dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template void kspmm_coo_very_sparse_naive( + int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, signed char *B, sycl::half *out, float *dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template void kspmm_coo_very_sparse_naive( + int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, + sycl::half *values, signed char *B, sycl::half *out, float *dequant_stats, + int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); + +template void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, + const sycl::nd_item<3> &item_ct1, + char *smem_data); +template void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, + const sycl::nd_item<3> &item_ct1, + char *smem_data); +template void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, + const sycl::nd_item<3> &item_ct1, + char *smem_data); +template void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, + const sycl::nd_item<3> &item_ct1, + char *smem_data); +template void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, + const sycl::nd_item<3> &item_ct1, + char *smem_data); +template void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, + const sycl::nd_item<3> &item_ct1, + char *smem_data); + +template void kdequant_mm_int32_fp16<4, 128, 512>( + int *__restrict__ const A, float *__restrict__ const rowStats, + float *__restrict__ const colStats, sycl::half *out, float *newRowStats, + float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, + const int numCols, const int tileCols, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_rowStats); + +template void kDoubleRowColQuant<64, 4, 16, 64 * 4, 0>( + sycl::half *__restrict__ const A, float 
*__restrict__ const rowStats,
+    float *__restrict__ const colStats, char *out_col_normed,
+    char *out_row_normed, int *rowidx, int *colidx, sycl::half *val,
+    int *__restrict__ nnz_block_ptr, float threshold, int rows, int cols,
+    int tiledCols, const sycl::nd_item<3> &item_ct1,
+    float *smem_row_stats,
+    unsigned int *smem_nnz_row_idx);
+template void kDoubleRowColQuant<64, 4, 16, 64 * 4, 1>(
+    sycl::half *__restrict__ const A, float *__restrict__ const rowStats,
+    float *__restrict__ const colStats, char *out_col_normed,
+    char *out_row_normed, int *rowidx, int *colidx, sycl::half *val,
+    int *__restrict__ nnz_block_ptr, float threshold, int rows, int cols,
+    int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats,
+    unsigned int *smem_nnz_row_idx);
+
+template unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
+template unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
+
+template void kEstimateQuantiles<float>(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n,
+                                        const sycl::nd_item<3> &item_ct1,
+                                        uint8_t *temp_storage_ct1);
+template void kEstimateQuantiles<sycl::half>(sycl::half *__restrict__ const A, float *code,
+                                             const float offset, const sycl::half max_val,
+                                             const int n, const sycl::nd_item<3> &item_ct1,
+                                             uint8_t *temp_storage_ct1);
+
+#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
+template void kPreconditionOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, \
+                float* state1, float *unorm, \
+                const float beta1, const float beta2, const float eps, const float weight_decay, \
+                const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1);
+
+MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half)
+MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
+MAKE_PreconditionOptimizer32bit1State(RMSPROP, sycl::half)
+MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
+MAKE_PreconditionOptimizer32bit1State(LION, sycl::half)
+MAKE_PreconditionOptimizer32bit1State(LION, float)
+MAKE_PreconditionOptimizer32bit1State(LION, oneapi::mkl::bfloat16)
+MAKE_PreconditionOptimizer32bit1State(ADAGRAD, sycl::half)
+MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
+
+#define MAKE_Optimizer32bit1State(oname, gtype) \
+template void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
+                const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1);
+
+MAKE_Optimizer32bit1State(MOMENTUM, sycl::half)
+MAKE_Optimizer32bit1State(MOMENTUM, float)
+MAKE_Optimizer32bit1State(RMSPROP, sycl::half)
+MAKE_Optimizer32bit1State(RMSPROP, float)
+MAKE_Optimizer32bit1State(LION, sycl::half)
+MAKE_Optimizer32bit1State(LION, float)
+MAKE_Optimizer32bit1State(LION, oneapi::mkl::bfloat16)
+MAKE_Optimizer32bit1State(ADAGRAD, sycl::half)
+MAKE_Optimizer32bit1State(ADAGRAD, float)
+
+#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
+template void kPreconditionOptimizer32bit2State<gtype, oname>(gtype* g, gtype* p, \
+                float* state1, float* state2, float *unorm, \
+                const float beta1, const float beta2, const float eps, const float weight_decay, \
+                const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1);
+
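+// Each MAKE_* macro above stamps out one explicit template instantiation of
+// the corresponding optimizer kernel, so host code can link against it
+// without seeing the kernel definition. As a minimal sketch (assuming the
+// <gtype, oname> template-argument order of the kernel declarations), the
+// line MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) expands to:
+//
+//   template void kPreconditionOptimizer32bit1State<sycl::half, MOMENTUM>(
+//       sycl::half* g, sycl::half* p, float* state1, float* unorm,
+//       const float beta1, const float beta2, const float eps,
+//       const float weight_decay, const int step, const float lr,
+//       const float gnorm_scale, const int n,
+//       const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1);
+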
+MAKE_PreconditionOptimizer32bit2State(ADAM, float)
+MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half)
+MAKE_PreconditionOptimizer32bit2State(ADAM, oneapi::mkl::bfloat16)
+
+template void kOptimizer32bit2State<float, ADAM>(
+    float *g, float *p, float *state1, float *state2, float *unorm,
+    const float max_unorm, const float param_norm, const float beta1,
+    const float beta2, const float eps, const float weight_decay,
+    const int step, const float lr, const float gnorm_scale,
+    const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,
+    uint8_t *temp_storage_ct1);
+template void kOptimizer32bit2State<sycl::half, ADAM>(
+    sycl::half *g, sycl::half *p, float *state1, float *state2, float *unorm,
+    const float max_unorm, const float param_norm, const float beta1,
+    const float beta2, const float eps, const float weight_decay,
+    const int step, const float lr, const float gnorm_scale,
+    const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,
+    uint8_t *temp_storage_ct1);
+template void kOptimizer32bit2State<oneapi::mkl::bfloat16, ADAM>(
+    oneapi::mkl::bfloat16 *g, oneapi::mkl::bfloat16 *p, float *state1,
+    float *state2, float *unorm, const float max_unorm, const float param_norm,
+    const float beta1, const float beta2, const float eps,
+    const float weight_decay, const int step, const float lr,
+    const float gnorm_scale, const bool skip_zeros, const int n,
+    const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1);
+
+#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
+template void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
+                float *unorm, \
+                const float beta1, \
+                const float beta2, \
+                const float eps, const int step, \
+                float* __restrict__ const quantiles1, \
+                float* max1, float* new_max1, \
+                const float weight_decay, \
+                const float gnorm_scale, \
+                const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, float *smem_quantiles1);
+
+MAKE_PreconditionStatic8bit1State(MOMENTUM, sycl::half)
+MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
+MAKE_PreconditionStatic8bit1State(RMSPROP, sycl::half)
+MAKE_PreconditionStatic8bit1State(RMSPROP, float)
+MAKE_PreconditionStatic8bit1State(LION, sycl::half)
+MAKE_PreconditionStatic8bit1State(LION, float)
+
+#define MAKE_optimizerStatic8bit1State(oname, gtype) \
+template void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
+                const float *unorm, const float max_unorm, const float param_norm, \
+                const float beta1, \
+                const float beta2, \
+                const float eps, const int step, const float lr, \
+                float* __restrict__ const quantiles1, \
+                float* max1, float* new_max1, \
+                float weight_decay, \
+                const float gnorm_scale, \
+                const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, uint8_t *temp_storage_ct1);
+
+MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half)
+MAKE_optimizerStatic8bit1State(MOMENTUM, float)
+MAKE_optimizerStatic8bit1State(RMSPROP, sycl::half)
+MAKE_optimizerStatic8bit1State(RMSPROP, float)
+MAKE_optimizerStatic8bit1State(LION, sycl::half)
+MAKE_optimizerStatic8bit1State(LION, float)
+
+#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
+template void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
+                float *unorm, \
+                const float beta1, const float beta2, \
+                const float eps, const int step, \
+                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
+                float* max1, float* max2, float* new_max1, float* new_max2, \
+                const float gnorm_scale, \
+                const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, float *smem_quantiles1, float *smem_quantiles2);
+
+MAKE_PreconditionStatic8bit2State(ADAM, sycl::half)
+MAKE_PreconditionStatic8bit2State(ADAM, float)
+
+#define MAKE_optimizerStatic8bit2State(oname, gtype) \
+template void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \
+                const float *unorm, const float max_unorm, const float param_norm, \
+                const float beta1, const float beta2, \
+                const float eps, const int step, const float lr, \
+                float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
+                float* max1, float* max2, float* new_max1, float* new_max2, \
+                float weight_decay, \
+                const float gnorm_scale, \
+                const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, uint8_t *temp_storage_ct1);
+
+MAKE_optimizerStatic8bit2State(ADAM, sycl::half)
+MAKE_optimizerStatic8bit2State(ADAM, float)
+
+template void kPercentileClipping(float *__restrict__ g, float *gnorm_vec, int step, const int n, const sycl::nd_item<3> &item_ct1);
+template void kPercentileClipping(sycl::half *__restrict__ g, float *gnorm_vec, int step, const int n, const sycl::nd_item<3> &item_ct1);
+
+#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
+template void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value);
+
+MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit)
+MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
+MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 1, General8bit)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 2048, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 1024, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 512, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 256, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 128, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 64, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 2048, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 1024, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 512, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 256, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 128, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 64, 2, 0, 
FP4) +MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 64, 2, 0, NF4) + +template void kDequantizeBlockwise(float *code,unsigned char *A,float *absmax,sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); +template void kDequantizeBlockwise( + float *code, unsigned char *A, float *absmax, sycl::half *out, + const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); +template void kDequantizeBlockwise( + float *code, unsigned char *A, float *absmax, sycl::half *out, + const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); +template void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, + const sycl::nd_item<3> &item_ct1); +template void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, + const sycl::nd_item<3> &item_ct1); +template void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, + const sycl::nd_item<3> &item_ct1); +template void kDequantizeBlockwise( + float *code, unsigned char *A, float *absmax, oneapi::mkl::bfloat16 *out, + const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); +template void +kDequantizeBlockwise( + float *code, unsigned char *A, float *absmax, oneapi::mkl::bfloat16 *out, + const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); +template void kDequantizeBlockwise( + float *code, unsigned char *A, float *absmax, oneapi::mkl::bfloat16 *out, + const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* absmax1, float* absmax2, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2, uint8_t *temp_storage_ct1); + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, oneapi::mkl::bfloat16, 2048, 8) + + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1, uint8_t *temp_storage_ct1); + 
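+// A host-side launch sketch for one of the blockwise instantiations below.
+// The queue q, the kernel arguments, the local-memory sizes, and the
+// <T, OPTIMIZER, BLOCK_SIZE, N_PER_TH> template-argument order are
+// illustrative assumptions, not part of this patch. With block_size 2048 and
+// 8 values per thread, each work-group runs 2048/8 = 256 work-items and
+// processes one 2048-element block of the parameter tensor:
+//
+//   int num_blocks = (n + 2048 - 1) / 2048;
+//   q.submit([&](sycl::handler &cgh) {
+//     sycl::local_accessor<float, 2> smem_quantiles1(sycl::range<2>(1, 256), cgh);
+//     sycl::local_accessor<float, 1> smem_exchange1(sycl::range<1>(1), cgh);
+//     sycl::local_accessor<uint8_t, 1> temp_storage(sycl::range<1>(4096), cgh);
+//     cgh.parallel_for(
+//         sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks * 256),
+//                           sycl::range<3>(1, 1, 256)),
+//         [=](sycl::nd_item<3> item_ct1) {
+//           kOptimizerStatic8bit1StateBlockwise<float, MOMENTUM, 2048, 8>(
+//               p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1,
+//               weight_decay, gnorm_scale, skip_zeros, n, item_ct1,
+//               smem_quantiles1,
+//               smem_exchange1.get_multi_ptr<sycl::access::decorated::no>().get(),
+//               temp_storage.get_multi_ptr<sycl::access::decorated::no>().get());
+//         });
+//   });
+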
+MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, oneapi::mkl::bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) + MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, sycl::half, 2048, 8) From 1f2c8bc6d3ece0aea6beb2b1f63d5d9470e1045c Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 20 Mar 2024 05:27:16 -0700 Subject: [PATCH 05/66] add partial kernels --- csrc/{kernels.dp.cpp => sycl/kernels.cpp} | 3005 ++++++++++----------- csrc/{kernels.dp.hpp => sycl/kernels.h} | 0 csrc/{ops.dp.cpp => sycl/ops.cpp} | 0 csrc/{ops.dp.hpp => sycl/ops.h} | 0 4 files changed, 1386 insertions(+), 1619 deletions(-) rename csrc/{kernels.dp.cpp => sycl/kernels.cpp} (55%) rename csrc/{kernels.dp.hpp => sycl/kernels.h} (100%) rename csrc/{ops.dp.cpp => sycl/ops.cpp} (100%) rename csrc/{ops.dp.hpp => sycl/ops.h} (100%) diff --git a/csrc/kernels.dp.cpp b/csrc/sycl/kernels.cpp similarity index 55% rename from csrc/kernels.dp.cpp rename to csrc/sycl/kernels.cpp index 51b1b7abf..6dcefca23 100644 --- a/csrc/kernels.dp.cpp +++ b/csrc/sycl/kernels.cpp @@ -7,13 +7,14 @@ #include #include #include -#include -#include -#include -#include +#include #include "kernels.dp.hpp" #include +#include + + + #define HLF_MAX 65504 #define TH 1024 #define NUM 4 @@ -26,10 +27,7 @@ float atomicMax(float* address, float val) { int old = *address_as_i, assumed; do { assumed = old; - old = dpct::atomic_compare_exchange_strong< - sycl::access::address_space::generic_space>( - reinterpret_cast(address), assumed, - sycl::bit_cast(sycl::fmax(val, sycl::bit_cast(assumed)))); + old = dpct::atomic_compare_exchange_strong(reinterpret_cast(address), assumed, sycl::bit_cast(sycl::fmax(val, sycl::bit_cast(assumed)))); } while (assumed != old); return sycl::bit_cast(old); } @@ -39,10 +37,7 @@ float atomicMin(float* address, float val) { int old = *address_as_i, assumed; do { assumed = old; - old = dpct::atomic_compare_exchange_strong< - sycl::access::address_space::generic_space>( - reinterpret_cast(address), assumed, - sycl::bit_cast(sycl::fmin(val, sycl::bit_cast(assumed)))); + old = dpct::atomic_compare_exchange_strong(reinterpret_cast(address), assumed, sycl::bit_cast(sycl::fmin(val, sycl::bit_cast(assumed)))); } while (assumed != old); return sycl::bit_cast(old); } @@ -109,7 +104,7 @@ float dDequantizeFP4Tree(unsigned char val, float absmax) return 1.00000000f*absmax*sign; // 1011 else return 0.66666667f*absmax*sign; // 1010 - else + else if((val & 0b0001) == 1) // 100 return 5.208333333e-03f*absmax*sign; // 1001 else @@ -133,10 +128,10 @@ unsigned char dQuantizeFP4(float x) // we do a binary search // the pivots are divided by 12 (the FP4 absmax) - // since we assum input data is in [-1.0, 1.0] + // since we assume input data is in [-1.0, 1.0] // !be careful here, its easy to make a mistake - // that is difficult to noice if you add an extra + // that is difficult to notice if you add an extra // zero somewhere! int sign = x < 0 ? 
0b1000 : 0b0000; @@ -173,36 +168,36 @@ sycl::half dhDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -210,12 +205,12 @@ sycl::half dhDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -228,36 +223,36 @@ float dDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -265,12 +260,12 @@ float dDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -393,14 +388,14 @@ unsigned char dQuantize(float* smem_code, const float rand, float x) { if(x > val) { - float dist_to_upper = sycl::fabs(upper - x); + float dist_to_upper = sycl::fabs(upper-x); float dist_full = upper-val; if(rand >= dist_to_upper/dist_full) return upper_pivot; else return pivot; } else { - float dist_to_lower = sycl::fabs(lower - x); + float dist_to_lower = sycl::fabs(lower-x); float dist_full = val-lower; if(rand >= dist_to_lower/dist_full) return lower_pivot; else return pivot; @@ -409,9 +404,7 @@ unsigned char dQuantize(float* smem_code, const float rand, float x) } template -__dpct_inline__ unsigned char quantize_2D(float *__restrict__ quadrants, - float 
*__restrict__ const smem_code, - float x) +__dpct_inline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) { int pivot = 127; int upper_pivot = 255; @@ -466,9 +459,7 @@ __dpct_inline__ unsigned char quantize_2D(float *__restrict__ quadrants, } template -__dpct_inline__ unsigned char -quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, - float lower, float midpoint, float upper) +__dpct_inline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper) { int lower_pivot = QUADRANT*16-1 - 0; int pivot = QUADRANT*16-1 + 16; @@ -512,96 +503,42 @@ quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, } } -void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, +SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, const sycl::nd_item<3> &item_ct1) { - const int tid = item_ct1.get_local_id(2) + - (item_ct1.get_local_range(2) * item_ct1.get_group(2)); - const int numThreads = - item_ct1.get_local_range(2) * item_ct1.get_group_range(2); + const int tid = item_ct1.get_local_id(2) + (item_ct1.get_local_range(2)*item_ct1.get_group(2)); + const int numThreads = item_ct1.get_local_range(2)*item_ct1.get_group_range(2); for(int i = tid; i < n; i+=numThreads) { int idx = (index1[i]*maxidx1) + index2[i]; - dpct::atomic_fetch_add( - &histogram[idx], src[i]); + dpct::atomic_fetch_add(&histogram[idx], src[i]); } } +void warpreduceKernelMax(int* data, const sycl::nd_item<3> &item_ct1) { + int threadid = item_ct1.get_local_id(2); + int input = data[threadid]; + int output = 0; + output = sycl::reduce_over_group(item_ct1.get_sub_group(), input, sycl::maximum<>()); + data[threadid] = output; +} -/* -template -class BlockLoadKernel { -public: - BlockLoadKernel(sycl::queue& queue, const float* input, float* output, size_t numElements) - : queue_(queue), input_(input), output_(output), numElements_(numElements) { - } - - void operator()() { - queue_.submit([&](sycl::handler& cgh) { - auto input = input_; - auto output = output_; - - cgh.parallel_for( - sycl::nd_range<1>(sycl::range<1>(numElements_), sycl::range<1>(BlockSize)), - [=](sycl::nd_item<1> item) { - // Create a local memory buffer to hold a block of data - dpct::local localBuffer[BlockSize / NumPerThread]; - - // Compute global index - size_t globalIndex = item.get_global_linear_id(); - - // Compute local index within the block - size_t localIndex = item.get_local_id(0) / NumPerThread; - - // Compute index within the block for each thread - size_t blockIndex = item.get_local_id(0) % (BlockSize / NumPerThread); - - // Load data from global memory into local memory - localBuffer[blockIndex] = input[globalIndex]; - - // Synchronize to ensure all work-items have loaded data - item.barrier(); - - // Use the loaded data from local memory (e.g., perform some computation) - float result = localBuffer[localIndex]; - - // Store the result back to global memory - output[globalIndex] = result; - }); - }); - } - -private: - sycl::queue& queue_; - const float* input_; - float* output_; - size_t numElements_; -}; -*/ - - -template -void kCompressMax(T *__restrict__ const A, T *out, unsigned char *out_idx, - const int n, const sycl::nd_item<3> &item_ct1) +template +void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n, + const 
sycl::nd_item<3> &item_ct1, int *smem_max_indices, + float *smem_max_values) { - //typedef cub::WarpReduce WarpReduce; - //__shared__ typename WarpReduce::TempStorage temp_storage; - typedef cub::BlockLoad LoadT; - __shared__ typename LoadT::TempStorage loadt; + + + //typename WarpReduce::TempStorage temp_storage; + //typedef cub::BlockLoad LoadT; + //typename LoadT::TempStorage loadt; - const int warp_idx = item_ct1.get_local_id(2) / 32; - const int valid_items = n - (item_ct1.get_group(2) * BLOCK_SIZE) > BLOCK_SIZE - ? BLOCK_SIZE - : n - (item_ct1.get_group(2) * BLOCK_SIZE); + const int warp_idx = item_ct1.get_local_id(2)/32; + const int valid_items = n - (item_ct1.get_group(2)*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (item_ct1.get_group(2)*BLOCK_SIZE); // BLOCK_SIZE/32 == number of warps - //sycl::local_accessor smem_max_indices - //__shared__ int smem_max_indices[8*BLOCK_SIZE/32]; - //__shared__ float smem_max_values[8*BLOCK_SIZE/32]; - int smem_max_indices[8*BLOCK_SIZE/32]; - float smem_max_values[8*BLOCK_SIZE/32]; - T values[8]; T max1 = -64000.0f; T max2 = -64000.0f; @@ -610,16 +547,36 @@ void kCompressMax(T *__restrict__ const A, T *out, unsigned char *out_idx, int sign1 = -1; int sign2 = -1; - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - - LoadT(loadt).Load(&(A[(item_ct1.get_group(2) * BLOCK_SIZE)]), values, - valid_items, (T)0.0f); -#pragma unroll 8 + sycl::buffer buff_indices(smem_max_indices, sycl::range<1>(8*BLOCK_SIZE/32)); + sycl::buffer buff_values(smem_max_values, sycl::range<1>(8*BLOCK_SIZE/32)); + sycl::buffer buff_A(A,sycl::range<1>(8*BLOCK_SIZE/32)); + + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(8*BLOCK_SIZE/32); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_A[(item_ct1.get_local_id(2)*BLOCK_SIZE)], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, values); + + }); + + }); + + #pragma unroll 8 for(int i = 0; i < 8; i++) { T absval = fabsf(values[i]); @@ -627,29 +584,32 @@ void kCompressMax(T *__restrict__ const A, T *out, unsigned char *out_idx, { max1 = values[i]; sign1 = signbit(values[i]); - max_idx1 = 8 * item_ct1.get_local_id(2) + i; + max_idx1 = 8*item_ct1.get_local_id(2) + i; } else if(absval > max2) { max2 = values[i]; sign2 = signbit(values[i]); - max_idx2 = 8 * item_ct1.get_local_id(2) + i; + max_idx2 = 8*item_ct1.get_local_id(2) + i; } } - + float warp_max; + sycl::host_accessor hacc_values{buff_values}; + sycl::host_accessor hacc_indices{buff_indices}; for(int i = 0; i < 8; i++) { // 3. 
-    //warp_max = WarpReduce(temp_storage).Reduce(max1, sycl::maximum<>());
-    warp_max = dpct::group::reduce(item_ct1, max1, sycl::maximum<>());
+
+    warp_max = sycl::reduce_over_group(item_ct1.get_sub_group(), max1, sycl::maximum<>());
     warp_max = item_ct1.get_sub_group().shuffle(warp_max, 0);

     // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
     if(warp_max == max1)
     {
-      smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1;
-      smem_max_indices[warp_idx*8 + i] = max_idx1;
+
+      hacc_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1;
+      hacc_indices[warp_idx*8 + i] = max_idx1;

       sign1 = sign2;
       max1 = max2;
@@ -660,11 +620,11 @@ void kCompressMax(T *__restrict__ const A, T *out, unsigned char *out_idx,
     sycl::group_barrier(item_ct1.get_sub_group());
   }

-  if (item_ct1.get_local_id(2) % 32 < 8)
+  if(item_ct1.get_local_id(2) % 32 < 8)
   {
     // offset: 8 values per 256 input values
     //
-    int offset = BLOCK_SIZE * item_ct1.get_group(2) * BLOCK_SIZE / 32 * 8;
+    int offset = BLOCK_SIZE*item_ct1.get_group(2)*BLOCK_SIZE/32*8;
   }
 }

@@ -674,32 +634,28 @@ void kCompressMax(T *__restrict__ const A, T *out, unsigned char *out_idx,
 #define BLOCK_ESTIMATE 4096

 template
-
+SYCL_EXTERNAL
 void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n,
-                        const sycl::nd_item<3> &item_ct1,
-                        uint8_t *temp_storage_ct1)
+                        const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1)
 {
   const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE);
-  int valid_items = (item_ct1.get_group(2) + 1 == item_ct1.get_group_range(2))
-                        ? n - (item_ct1.get_group(2) * BLOCK_ESTIMATE)
-                        : BLOCK_ESTIMATE;
+  int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*BLOCK_ESTIMATE) : BLOCK_ESTIMATE;
   const int base_idx = (item_ct1.get_group(2) * BLOCK_ESTIMATE);
   const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE));

   T vals[NUM_ESTIMATE];

-  typedef cub::BlockRadixSort BlockRadixSort;
+  typedef cub::BlockRadixSort BlockRadixSort;
   typedef cub::BlockLoad LoadFloat;

-  union type_ct1 {
+  union type_ct1{
     typename LoadFloat::TempStorage loadf;
     typename BlockRadixSort::TempStorage sort;
     int smem_qidx[BLOCK_ESTIMATE];
   };
   type_ct1 &temp_storage = *(type_ct1 *)temp_storage_ct1;

-  for (unsigned int i = base_idx; i < n_full;
-       i += item_ct1.get_group_range(2) * BLOCK_ESTIMATE)
+  for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_ESTIMATE)
  {
     valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i;

@@ -711,100 +667,100 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset
        vals[j] = max_val;

    /*
-    DPCT1065:0: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
+    DPCT1065:76: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
     */
     item_ct1.barrier();

     /*
-    DPCT1007:96: Migration of cub::BlockLoad.Load is not supported.
+    DPCT1007:81: Migration of cub::BlockLoad::Load is not supported.
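+
+    Migration note: absent the cub helper, the guarded block load can be written as a
+    plain per-work-item copy bounded by valid_items. A minimal sketch, assuming a
+    blocked element arrangement rather than cub's warp-transpose striping, and reusing
+    A, vals, i, valid_items and max_val from this scope:
+
+        int t = item_ct1.get_local_id(2);
+        for (int j = 0; j < NUM_ESTIMATE; j++)
+        {
+          int k = t * NUM_ESTIMATE + j;
+          vals[j] = k < valid_items ? A[i + k] : max_val; // tail keeps the max_val fill from above
+        }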
*/ + + LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); -#pragma unroll 4 + #pragma unroll 4 for(int j = 0; j < NUM_ESTIMATE; j++) vals[j] = ((float)vals[j]) * reciprocal_num_blocks; + /* - DPCT1065:1: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:77: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); // sort into striped pattern to mitigate bank conflicts // striped pattern index for thread 0 [0, 1024, 2048, 3096] // striped pattern index for thread 1 [1, 1025, 2049, 3097] /* - DPCT1007:97: Migration of cub::BlockRadixSort.SortBlockedToStriped is not - supported. + DPCT1007:82: Migration of cub::BlockRadixSort::SortBlockedToStriped is not supported. */ BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); /* - DPCT1065:2: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:78: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); - for (int j = item_ct1.get_local_id(2); j < BLOCK_ESTIMATE; - j += item_ct1.get_local_range(2)) + for(int j = item_ct1.get_local_id(2); j < BLOCK_ESTIMATE; j+=item_ct1.get_local_range(2)) temp_storage.smem_qidx[j] = -1; - if (item_ct1.get_local_id(2) < 256) + /* + DPCT1065:79: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. + */ + item_ct1.barrier(); + + if(item_ct1.get_local_id(2) < 256) { float q_interval = (1.0f-(2.0f*offset))/255.0f; - int local_idx = - sycl::round(((offset + (item_ct1.get_local_id(2) * q_interval)) * - (valid_items - 1))); + /* + DPCT1064:83: Migrated round call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code. + */ + int local_idx = sycl::round(((offset+(item_ct1.get_local_id(2)*q_interval))*(valid_items-1))); temp_storage.smem_qidx[local_idx] = item_ct1.get_local_id(2); } /* - DPCT1065:3: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:80: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
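+
+    Where the barrier only protects local (shared) data, as with smem_qidx here, the
+    cheaper form already used elsewhere in this patch is:
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);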
*/
  item_ct1.barrier();

-  for (int i = item_ct1.get_local_id(2); i < BLOCK_ESTIMATE;
-       i += item_ct1.get_local_range(2))
+  for(int i = item_ct1.get_local_id(2); i < BLOCK_ESTIMATE; i+=item_ct1.get_local_range(2))
  {
    if(temp_storage.smem_qidx[i] != -1)
-      dpct::atomic_fetch_add<
-          sycl::access::address_space::generic_space>(
-          &code[temp_storage.smem_qidx[i]], vals[i / THREADS_ESTIMATE]);
+        dpct::atomic_fetch_add(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]);
  }
}

-void kQuantize(float *code, float *__restrict__ const A, unsigned char *out,
-               const int n, const sycl::nd_item<3> &item_ct1, float *smem_code)
+SYCL_EXTERNAL
+void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n,
+               const sycl::nd_item<3> &item_ct1, float *smem_code)
{
  const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
-  int valid_items = (item_ct1.get_group(2) + 1 == item_ct1.get_group_range(2))
-                        ? n - (item_ct1.get_group(2) * NUM_BLOCK)
-                        : NUM_BLOCK;
+  int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*NUM_BLOCK) : NUM_BLOCK;
  const int base_idx = (item_ct1.get_group(2) * NUM_BLOCK);

  float vals[NUM];
  unsigned char qvals[NUM];
  //const int lane_id = threadIdx.x % 2;

-  typedef cub::BlockLoad LoadFloat;
-  typedef cub::BlockStore StoreChar;
+  //typedef cub::BlockLoad LoadFloat;
+  //typedef cub::BlockStore StoreChar;
+  sycl::buffer buff_smem_code(smem_code,sycl::range<1>(257));
+  sycl::buffer buff_vals(vals, sycl::range<1>(NUM));
+  sycl::buffer buff_A(A,sycl::range<1>(NUM));
+  sycl::buffer buff_out(out,sycl::range<1>(NUM));
+
  //__shared__ float smem_code[2][257];

-  if (item_ct1.get_local_id(2) < 256)
+  if(item_ct1.get_local_id(2) < 256)
  {
    smem_code[item_ct1.get_local_id(2)] = code[item_ct1.get_local_id(2)];
    //smem_code[0][threadIdx.x] = code[threadIdx.x];
    //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x];
  }

-  for (unsigned int i = base_idx; i < n_full;
-       i += item_ct1.get_group_range(2) * NUM_BLOCK)
+
+  for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_BLOCK)
  {
      // number of values already processed in blocks +
      // number of values already processed in this block +
@@ -812,42 +768,89 @@ void kQuantize(float *code, float *__restrict__ const A, unsigned char *out,
      valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i;

      /*
-    DPCT1065:89: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
+    DPCT1118:50: SYCL group functions and algorithms must be encountered in converged control flow. You may need to adjust the code.
      */
-    item_ct1.barrier();
      /*
-    DPCT1007:160: Migration of cub::BlockLoad.Load is not supported.
+    DPCT1065:224: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
      */
-    LoadFloat(loadf).Load(&(A[i]), vals, valid_items);
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+
+      using group_load = dpct::group::workgroup_load;
+      size_t temp_storage_size = group_load::get_local_memory_size(NUM);
+      sycl::local_accessor tacc(
+        temp_storage_size, h);
+      sycl::accessor dacc(buff_A[i], h, sycl::read_write);
+
+      // 1. load 8 values per thread
+      // 2. compute 2-max in registers (64 max per warp)
+      // 3. do warp reduction + broadcast back
+      // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+      // 5. Repeat (3) 8 times for top 8 values in 256
+      // 6. store with byte index
+      h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+        [=](sycl::nd_item<3> item) {
+           auto *d = dacc.get_multi_ptr().get();
+           auto *tmp = tacc.get_multi_ptr().get();
+           group_load(tmp).load(item,item.get_local_linear_id(), d, vals);
+        });
+
+      });
+
+      //LoadFloat(loadf).Load(&(A[i]), vals, valid_items);

-#pragma unroll 4
+
+    #pragma unroll 4
     for(int j = 0; j < NUM; j++)
       qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]);

     /*
-    DPCT1065:90: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
+    DPCT1118:51: SYCL group functions and algorithms must be encountered in converged control flow. You may need to adjust the code.
     */
-    item_ct1.barrier();
     /*
-    DPCT1007:161: Migration of cub::BlockStore.Store is not supported.
+    DPCT1065:225: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
     */
-    StoreChar(storec).Store(&(out[i]), qvals, valid_items);
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+
+    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+
+      using group_store = dpct::group::workgroup_store;
+      size_t temp_storage_size = group_store::get_local_memory_size(NUM);
+      sycl::local_accessor tacc(
+        temp_storage_size, h);
+      sycl::accessor dacc(buff_out[i], h, sycl::read_write);
+      // 1. load 8 values per thread
+      // 2. compute 2-max in registers (64 max per warp)
+      // 3. do warp reduction + broadcast back
+      // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+      // 5. Repeat (3) 8 times for top 8 values in 256
+      // 6. store with byte index
+      h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+        [=](sycl::nd_item<3> item) {
+           auto *d = dacc.get_multi_ptr().get();
+           auto *tmp = tacc.get_multi_ptr().get();
+           group_store(tmp).store(item,item.get_local_linear_id(), d, qvals);
+        });
+
+      });
+
+
+    //StoreChar(storec).Store(&(out[i]), qvals, valid_items);
  }
}

-template
+template
 //__launch_bounds__(TH, 4)
-void kQuantizeBlockwise(float *code, T *__restrict__ const A, float *absmax,
-                        unsigned char *out, float *__restrict__ const rand,
-                        const int rand_offset, const int n,
-                        const sycl::nd_item<3> &item_ct1
-                        float *smem_code, float *smem_absmax_value)
+SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n,
+                                      const sycl::nd_item<3> &item_ct1, float *smem_code,
+                                      float *smem_absmax_value)
{
+
+
  const int n_full = item_ct1.get_group_range(2) * BLOCK_SIZE;
  int valid_items = 0;
  const int base_idx = (item_ct1.get_group(2) * BLOCK_SIZE);
@@ -859,33 +862,62 @@ void kQuantizeBlockwise(float *code, T *__restrict__ const A, float *absmax,
  float local_abs_max = 0.0f;
  int local_rand_idx = 0;

-  typedef cub::BlockLoad LoadT;
-  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
-
-  typedef cub::BlockLoad LoadFloat;
+
+  sycl::buffer buff_A(A,sycl::range<1>(NUM_PER_TH));
+  sycl::buffer buff_out(out,sycl::range<1>(NUM_PER_TH));
+  sycl::buffer buff_rand(rand,sycl::range<1>(NUM_PER_TH));
+
+
+  //typedef cub::BlockLoad LoadT;
+  //typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+
+  //typedef cub::BlockLoad LoadFloat;
+
  if(DATA_TYPE == General8bit)
-    for (int i = item_ct1.get_local_id(2); i < 256;
-         i += item_ct1.get_local_range(2))
+    for(int i = item_ct1.get_local_id(2); i < 256; i+=item_ct1.get_local_range(2))
      smem_code[i] = code[i];

-  for (unsigned int i = base_idx; i < n_full;
-       i += item_ct1.get_group_range(2) * BLOCK_SIZE)
+  for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE)
  {
    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
    local_abs_max = -FLT_MAX;

    /*
-    DPCT1065:4: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
+    DPCT1065:84: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
    */
-    item_ct1.barrier();
+    item_ct1.barrier(sycl::access::fence_space::local_space);
    /*
-    DPCT1007:98: Migration of cub::BlockLoad.Load is not supported.
+    DPCT1007:87: Migration of cub::BlockLoad::Load is not supported.
    */
-    LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);
-
+    //LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);
+
+    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+
+      using group_load = dpct::group::workgroup_load;
+      size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_TH);
+      sycl::local_accessor tacc(
+        temp_storage_size, h);
+      sycl::accessor dacc(buff_A[i], h, sycl::read_write);
+
+      // 1. load 8 values per thread
+      // 2. compute 2-max in registers (64 max per warp)
+      // 3. do warp reduction + broadcast back
+      // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+      // 5. Repeat (3) 8 times for top 8 values in 256
+      // 6. store with byte index
+      h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+        [=](sycl::nd_item<3> item) {
+           auto *d = dacc.get_multi_ptr().get();
+           auto *tmp = tacc.get_multi_ptr().get();
+           group_load(tmp).load(item,item.get_local_linear_id(), d, vals);
+        });
+
+      });
+
+
    // 1. compute local max
    // 2. broadcast local max
    // 3. normalize inputs and quantize
@@ -895,23 +927,20 @@ void kQuantizeBlockwise(float *code, T *__restrict__ const A, float *absmax,
      local_abs_max = sycl::fmax(local_abs_max, sycl::fabs((float)vals[j]));

    /*
-    DPCT1007:7: Migration of cub::Reduce is not supported.
+    DPCT1007:0: Migration of cub::Reduce is not supported.
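+
+    Migration note: the line below swaps cub's BlockReduce::Reduce for dpct::group::reduce
+    over the whole work-group. A plain-SYCL equivalent is sketched here; it drops cub's
+    valid_items argument, which is safe only because local_abs_max starts at -FLT_MAX, a
+    neutral element for the maximum:
+
+        local_abs_max = sycl::reduce_over_group(item_ct1.get_group(), local_abs_max, sycl::maximum<>());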
*/
-    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, sycl::maximum<>(),
-                                               valid_items);
-    //local_abs_max = dpct::group::reduce(item_ct1, local_abs_max, sycl::maximum<>()
+    //local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, sycl::maximum<>(), valid_items);
+    local_abs_max = dpct::group::reduce(item_ct1, local_abs_max, sycl::maximum<>());

-    if (item_ct1.get_local_id(2) == 0)
+    if(item_ct1.get_local_id(2) == 0)
      smem_absmax_value[0] = local_abs_max;

    /*
-    DPCT1065:5: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
+    DPCT1065:85: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
    */
-    item_ct1.barrier();
+    item_ct1.barrier(sycl::access::fence_space::local_space);

-    if (item_ct1.get_local_id(2) == 0)
+    if(item_ct1.get_local_id(2) == 0)
      absmax[i/BLOCK_SIZE] = local_abs_max;
    else
      local_abs_max = smem_absmax_value[0];
@@ -922,13 +951,38 @@ void kQuantizeBlockwise(float *code, T *__restrict__ const A, float *absmax,

    if(STOCHASTIC)
    {
-      local_rand_idx = ((item_ct1.get_group(2) * NUM_BLOCK) +
-                        (item_ct1.get_local_id(2) * NUM) + rand_offset) %
-                       (1024 - 4);
+      local_rand_idx = ((item_ct1.get_group(2)*NUM_BLOCK) + (item_ct1.get_local_id(2)*NUM) + rand_offset) % (1024-4);
      /*
-      DPCT1007:99: Migration of cub::BlockLoad.Load is not supported.
+      DPCT1007:88: Migration of cub::BlockLoad::Load is not supported.
      */
-      LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
+
+      //LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
+
+      dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+
+        using group_load = dpct::group::workgroup_load;
+        size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_TH);
+        sycl::local_accessor tacc(
+          temp_storage_size, h);
+        sycl::accessor dacc(buff_rand[local_rand_idx], h, sycl::read_write);
+
+        // 1. load 8 values per thread
+        // 2. compute 2-max in registers (64 max per warp)
+        // 3. do warp reduction + broadcast back
+        // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+        // 5. Repeat (3) 8 times for top 8 values in 256
+        // 6. store with byte index
+        h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+          [=](sycl::nd_item<3> item) {
+             auto *d = dacc.get_multi_ptr().get();
+             auto *tmp = tacc.get_multi_ptr().get();
+             group_load(tmp).load(item,item.get_local_linear_id(), d, rand_vals);
+          });
+
+        });
+
+
    }

    unsigned char packed_4bit = 0;
@@ -965,23 +1019,43 @@ void kQuantizeBlockwise(float *code, T *__restrict__ const A, float *absmax,
    }

    /*
-    DPCT1065:6: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
+    DPCT1065:86: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
    */
-    item_ct1.barrier();
+    item_ct1.barrier(sycl::access::fence_space::local_space);
    /*
-    DPCT1007:100: Migration of cub::BlockStore.Store is not supported.
+    DPCT1007:89: Migration of cub::BlockStore::Store is not supported.
    */
-    StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i / 2 : i]), qvals,
-                            (DATA_TYPE > 0) ? (valid_items + 1) / 2
-                                              : valid_items);
+
+    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+
+      using group_store = dpct::group::workgroup_store;
+      size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_TH);
+      sycl::local_accessor tacc(
+        temp_storage_size, h);
+      sycl::accessor dacc(buff_out[(DATA_TYPE > 0) ? i/2 : i], h, sycl::read_write);
+
+      // 1. load 8 values per thread
+      // 2. compute 2-max in registers (64 max per warp)
+      // 3. do warp reduction + broadcast back
+      // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+      // 5. Repeat (3) 8 times for top 8 values in 256
+      // 6. store with byte index
+      h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+        [=](sycl::nd_item<3> item) {
+           auto *d = dacc.get_multi_ptr().get();
+           auto *tmp = tacc.get_multi_ptr().get();
+           group_store(tmp).store(item,item.get_local_linear_id(), d, qvals);
+        });
+
+      });
+
+    //StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
  }
}

-template
-void kDequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out,
-                          const int blocksize, const int n,
+template
+SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n,
                          const sycl::nd_item<3> &item_ct1)
{

@@ -994,11 +1068,17 @@ void kDequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out,
  unsigned char qvals[NUM_PER_TH];
  float local_abs_max = -FLT_MAX;

-  typedef cub::BlockLoad LoadChar;
-  typedef cub::BlockStore<T, THREADS, NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
+  sycl::buffer buff_out(out, sycl::range<1>(NUM_PER_TH));
+  sycl::buffer buff_A(A,sycl::range<1>(NUM_PER_TH));
+
+
+  //typedef cub::BlockLoad LoadChar;
+  //typedef cub::BlockStore<T, THREADS, NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

-  for (unsigned int i = base_idx; i < n_load;
-       i += item_ct1.get_group_range(2) * TILE_SIZE)
+
+
+
+  for (unsigned int i = base_idx; i < n_load; i += item_ct1.get_group_range(2)*TILE_SIZE)
  {
    if(DATA_TYPE > 0)
    {
@@ -1011,22 +1091,46 @@ void kDequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out,
      valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i;
    }
    /*
-    DPCT1026:101: The call to __ldg was removed because there is no
-    corresponding API in SYCL.
+    DPCT1098:92: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues.
+    */
+    /*
+    DPCT1064:96: Migrated __ldg call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
    */
-    local_abs_max =
-        absmax[(i + item_ct1.get_local_id(2) * NUM_PER_TH) / (blocksize)];
+    local_abs_max = absmax[(i+item_ct1.get_local_id(2)*NUM_PER_TH)/(blocksize)];

    /*
-    DPCT1065:8: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
+    DPCT1065:90: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
    */
-    item_ct1.barrier();
+    item_ct1.barrier(sycl::access::fence_space::local_space);
    /*
-    DPCT1007:102: Migration of cub::BlockLoad.Load is not supported.
+    DPCT1007:93: Migration of cub::BlockLoad::Load is not supported.
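+
+    Migration note: the cub call being replaced filled out-of-range slots with 128 (the
+    unsigned zero point) via valid_items_load; the workgroup_load below has no such guard.
+    A guarded sketch, assuming a blocked element arrangement:
+
+        int t = item_ct1.get_local_id(2);
+        for (int j = 0; j < NUM_PER_TH; j++)
+        {
+          int k = t * NUM_PER_TH + j;
+          qvals[j] = k < valid_items_load ? A[i + k] : 128;
+        }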
*/
-    LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);
+    //LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);
+    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+
+      using group_load = dpct::group::workgroup_load;
+      size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_TH);
+      sycl::local_accessor tacc(
+        temp_storage_size, h);
+      sycl::accessor dacc(buff_A[i], h, sycl::read_write);
+
+      // 1. load 8 values per thread
+      // 2. compute 2-max in registers (64 max per warp)
+      // 3. do warp reduction + broadcast back
+      // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+      // 5. Repeat (3) 8 times for top 8 values in 256
+      // 6. store with byte index
+      h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+        [=](sycl::nd_item<3> item) {
+           auto *d = dacc.get_multi_ptr().get();
+           auto *tmp = tacc.get_multi_ptr().get();
+           group_load(tmp).load(item,item.get_local_linear_id(), d, qvals);
+        });
+
+      });
+

    switch(DATA_TYPE)
    {
@@ -1035,10 +1139,12 @@ void kDequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out,
        #pragma unroll NUM_PER_TH
        for(int j = 0; j < NUM_PER_TH; j++)
          /*
-          DPCT1026:103: The call to __ldg was removed because there is no
-          corresponding API in SYCL.
+          DPCT1098:94: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues.
+          */
+          /*
+          DPCT1064:228: Migrated __ldg call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
          */
-          vals[j] = code[qvals[j]] * local_abs_max;
+          vals[j] = code[qvals[j]]*local_abs_max;
        break;
      case FP4:
        #pragma unroll NUM_PER_TH
@@ -1059,35 +1165,55 @@ void kDequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out,
    }

    /*
-    DPCT1065:9: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
+    DPCT1065:91: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
    */
    item_ct1.barrier();
    /*
-    DPCT1007:104: Migration of cub::BlockStore.Store is not supported.
+    DPCT1007:95: Migration of cub::BlockStore::Store is not supported.
    */
-    StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i * 2 : i]), vals,
-                         valid_items_store);
+    //StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store);
+
+    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+
+      using group_store = dpct::group::workgroup_store;
+      size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_TH);
+      sycl::local_accessor tacc(
+        temp_storage_size, h);
+      sycl::accessor dacc(buff_out[(DATA_TYPE > 0) ? i*2 : i], h, sycl::read_write);
+
+      // 1. load 8 values per thread
+      // 2. compute 2-max in registers (64 max per warp)
+      // 3. do warp reduction + broadcast back
+      // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+      // 5. Repeat (3) 8 times for top 8 values in 256
+      // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, vals); + }); + + }); } } -void kDequantize(float *code, unsigned char *A, float *out, const int n, +SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code) { - const unsigned int numThreads = - item_ct1.get_local_range(2) * item_ct1.get_group_range(2); - const int idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + - item_ct1.get_local_id(2); + const unsigned int numThreads = item_ct1.get_local_range(2) * item_ct1.get_group_range(2); + const int idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); - if (item_ct1.get_local_id(2) < 256) - { - smem_code[item_ct1.get_local_id(2)] = code[item_ct1.get_local_id(2)]; - } + + if(item_ct1.get_local_id(2) < 256) + { + smem_code[item_ct1.get_local_id(2)] = code[item_ct1.get_local_id(2)]; + } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::local_space); - for (int i = idx;i < n; i += numThreads) + for (int i = idx;i < n; i += numThreads) { out[i] = smem_code[A[i]]; } @@ -1096,7 +1222,10 @@ void kDequantize(float *code, unsigned char *A, float *out, const int n, template - +/* +DPCT1110:1: The total declared local variable size in device function kPreconditionOptimizer32bit2State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, @@ -1105,8 +1234,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); - const int base_idx = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_VALS); + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_VALS); int valid_items = 0; T g_vals[NUM_VALS]; @@ -1114,59 +1242,50 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, float s1_vals[NUM_VALS]; float s2_vals[NUM_VALS]; - const float correction1 = 1.0f / (1.0f - sycl::pow(beta1, step)); - const float correction2 = 1.0f / (1.0f - sycl::pow(beta2, step)); + const float correction1 = 1.0f/(1.0f - dpct::pow(beta1, step)); + const float correction2 = 1.0f/(1.0f - dpct::pow(beta2, step)); typedef cub::BlockLoad Load; typedef cub::BlockLoad LoadFloat; typedef sycl::group<3> BlockReduce; - union type_ct2 { + union type_ct2{ typename Load::TempStorage load; typename LoadFloat::TempStorage loadf; typename BlockReduce::TempStorage reduce; }; type_ct2 &temp_storage = *(type_ct2 *)temp_storage_ct1; - for (unsigned int i = base_idx; i < n_full; - i += item_ct1.get_group_range(2) * BLOCK_SIZE) + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; /* - DPCT1065:10: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. 
+ DPCT1065:97: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:105: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:101: Migration of cub::BlockLoad::Load is not supported. */ Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); /* - DPCT1065:11: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:98: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:106: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:102: Migration of cub::BlockLoad::Load is not supported. */ - LoadFloat(temp_storage.loadf) - .Load(&(state1[i]), s1_vals, valid_items, 0.0f); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); /* - DPCT1065:12: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:99: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:107: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:103: Migration of cub::BlockLoad::Load is not supported. */ - LoadFloat(temp_storage.loadf) - .Load(&(state2[i]), s2_vals, valid_items, 0.0f); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); -# pragma unroll NUM_VALS + # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) g_vals[j] = gnorm_scale*((float)g_vals[j]); @@ -1180,8 +1299,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); s1_vals[j] *= correction1; s2_vals[j] *= correction2; - s1_vals[j] = - s1_vals[j] / (sycl::sqrt(s2_vals[j]) + eps); // update + s1_vals[j] = s1_vals[j]/(sycl::sqrt(s2_vals[j])+eps); // update s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) break; } @@ -1192,17 +1310,13 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, s1_vals[0] += s1_vals[j]; /* - DPCT1065:13: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:100: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
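+
+    The barrier/reduce/atomic sequence below is the usual stand-in for cub's
+    BlockReduce::Sum followed by atomicAdd: reduce the per-thread partial across the
+    work-group, then let one work-item publish it. Note that reduce_over_group has no
+    valid_items argument, so the zero fill applied when state1 was loaded is what keeps
+    the tail lanes neutral:
+
+        float block_sum = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>());
+        if (item_ct1.get_local_id(2) == 0)
+          dpct::atomic_fetch_add(&unorm[0], block_sum);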
*/
  item_ct1.barrier();

-    s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0],
-                                         sycl::plus<>());
+    s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>());

-    if (item_ct1.get_local_id(2) == 0)
-      dpct::atomic_fetch_add(
-          &unorm[0], s1_vals[0]);
+    if(item_ct1.get_local_id(2) == 0)
+      dpct::atomic_fetch_add(&unorm[0], s1_vals[0]);

    sycl::group_barrier(item_ct1.get_sub_group());
  }
}
@@ -1213,7 +1327,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p,
 #define NUM_PER_THREAD 4

 template
-
+SYCL_EXTERNAL
 void kOptimizer32bit2State(T* g, T* p,
                float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                const float beta1, const float beta2, const float eps, const float weight_decay,
@@ -1222,8 +1336,7 @@ void kOptimizer32bit2State(T* g, T* p,
{
  const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
-  const int base_idx =
-      (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD);
+  const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD);
  int valid_items = 0;
  float update_scale = 0.0f;
  T g_vals[NUM_PER_THREAD];
@@ -1232,8 +1345,8 @@ void kOptimizer32bit2State(T* g, T* p,
  float s1_vals[NUM_PER_THREAD];
  float s2_vals[NUM_PER_THREAD];

-  const float correction1 = 1.0f - sycl::pow(beta1, step);
-  const float correction2 = sycl::sqrt(1.0f - sycl::pow(beta2, step));
+  const float correction1 = 1.0f - dpct::pow(beta1, step);
+  const float correction2 = sycl::sqrt(1.0f - dpct::pow(beta2, step));
  const float step_size = -lr*correction2/correction1;

  if(max_unorm > 0.0f)
@@ -1244,13 +1357,20 @@ void kOptimizer32bit2State(T* g, T* p,
  }
  else{ update_scale = 1.0f; }

  typedef cub::BlockLoad Load;
  typedef cub::BlockStore Store;

  typedef cub::BlockLoad LoadFloat;
  typedef cub::BlockStore StoreFloat;

-  union type_ct3 {
+  union type_ct3{
    typename Load::TempStorage load;
    typename Store::TempStorage store;
    typename LoadFloat::TempStorage loadf;
@@ -1258,53 +1378,44 @@ void kOptimizer32bit2State(T* g, T* p,
  };
  type_ct3 &temp_storage = *(type_ct3 *)temp_storage_ct1;

-  for (unsigned int i = base_idx; i < n_full;
-       i += item_ct1.get_group_range(2) * TH * NUM_PER_THREAD)
+  for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*TH*NUM_PER_THREAD)
  {
    valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;

    /*
-    DPCT1065:14: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
+    DPCT1065:104: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
    */
-    item_ct1.barrier();
+    item_ct1.barrier(sycl::access::fence_space::local_space);
    /*
-    DPCT1007:108: Migration of cub::BlockLoad.Load is not supported.
+    DPCT1007:111: Migration of cub::BlockLoad::Load is not supported.
*/ Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); /* - DPCT1065:15: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:105: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:109: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:112: Migration of cub::BlockLoad::Load is not supported. */ LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); /* - DPCT1065:16: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:106: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:110: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:113: Migration of cub::BlockLoad::Load is not supported. */ LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); /* - DPCT1065:17: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:107: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:111: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:114: Migration of cub::BlockLoad::Load is not supported. */ Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); -# pragma unroll 4 + # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) g_vals[j] = gnorm_scale*((float)g_vals[j]); @@ -1318,18 +1429,7 @@ void kOptimizer32bit2State(T* g, T* p, { s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); - p_vals[j] = - ((float)p_vals - [j]) + - (update_scale * - step_size * - (s1_vals - [j] / - (sycl::sqrt( - s2_vals - [j]) + - (eps * - correction2)))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sycl::sqrt(s2_vals[j])+(eps*correction2)))); if(weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); @@ -1339,40 +1439,34 @@ void kOptimizer32bit2State(T* g, T* p, } /* - DPCT1065:18: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:108: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:112: Migration of cub::BlockStore.Store is not supported. + DPCT1007:115: Migration of cub::BlockStore::Store is not supported. */ Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); /* - DPCT1065:19: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. 
+ DPCT1065:109: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:113: Migration of cub::BlockStore.Store is not supported. + DPCT1007:116: Migration of cub::BlockStore::Store is not supported. */ StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); /* - DPCT1065:20: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:110: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:114: Migration of cub::BlockStore.Store is not supported. + DPCT1007:117: Migration of cub::BlockStore::Store is not supported. */ StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); } } template - +SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, @@ -1381,8 +1475,7 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); - const int base_idx = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_VALS); + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_VALS); int valid_items = 0; T g_vals[NUM_VALS]; @@ -1393,41 +1486,35 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, typedef cub::BlockLoad LoadFloat; typedef sycl::group<3> BlockReduce; - union type_ct4 { + union type_ct4{ typename Load::TempStorage load; typename LoadFloat::TempStorage loadf; typename BlockReduce::TempStorage reduce; }; type_ct4 &temp_storage = *(type_ct4 *)temp_storage_ct1; - for (unsigned int i = base_idx; i < n_full; - i += item_ct1.get_group_range(2) * BLOCK_SIZE) + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; /* - DPCT1065:21: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:118: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:115: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:121: Migration of cub::BlockLoad::Load is not supported. */ Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); /* - DPCT1065:22: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:119: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:116: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:122: Migration of cub::BlockLoad::Load is not supported. 
*/ - LoadFloat(temp_storage.loadf) - .Load(&(state1[i]), s1_vals, valid_items, 0.0f); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); -# pragma unroll NUM_VALS + # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) g_vals[j] = gnorm_scale*((float)g_vals[j]); @@ -1448,14 +1535,12 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update - s1_vals[j] = (float)g_vals[j] / - (sycl::sqrt(s1_vals[j]) + eps); // update value + s1_vals[j] = (float)g_vals[j] / (sycl::sqrt(s1_vals[j])+eps); // update value s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; case ADAGRAD: s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update - s1_vals[j] = (float)g_vals[j] / - (sycl::sqrt(s1_vals[j]) + eps); // update value + s1_vals[j] = (float)g_vals[j] / (sycl::sqrt(s1_vals[j])+eps); // update value s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; } @@ -1466,27 +1551,23 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, s1_vals[0] += s1_vals[j]; /* - DPCT1065:23: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:120: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:24: Migration of cub::Sum is not supported. + DPCT1007:2: Migration of cub::Sum is not supported. */ - s1_vals[0] = - BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); - if (item_ct1.get_local_id(2) == 0) - dpct::atomic_fetch_add( - &unorm[0], s1_vals[0]); + if(item_ct1.get_local_id(2) == 0) + dpct::atomic_fetch_add(&unorm[0], s1_vals[0]); sycl::group_barrier(item_ct1.get_sub_group()); } } template - +SYCL_EXTERNAL void kOptimizer32bit1State(T *g, T *p, float *state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, @@ -1495,8 +1576,7 @@ void kOptimizer32bit1State(T *g, T *p, { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); - const int base_idx = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); int valid_items = 0; float update_scale = 0.0f; @@ -1519,7 +1599,7 @@ void kOptimizer32bit1State(T *g, T *p, typedef cub::BlockLoad LoadFloat; typedef cub::BlockStore StoreFloat; - union type_ct5 { + union type_ct5{ typename Load::TempStorage load; typename Store::TempStorage store; typename LoadFloat::TempStorage loadf; @@ -1527,43 +1607,36 @@ void kOptimizer32bit1State(T *g, T *p, }; type_ct5 &temp_storage = *(type_ct5 *)temp_storage_ct1; - for (unsigned int i = base_idx; i < n_full; - i += item_ct1.get_group_range(2) * TH * NUM_PER_THREAD) + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*TH*NUM_PER_THREAD) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; /* - DPCT1065:25: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:123: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:117: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:128: Migration of cub::BlockLoad::Load is not supported. */ Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); /* - DPCT1065:26: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:124: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:118: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:129: Migration of cub::BlockLoad::Load is not supported. */ LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); /* - DPCT1065:27: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:125: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:119: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:130: Migration of cub::BlockLoad::Load is not supported. */ Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); -# pragma unroll 4 + # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) { g_vals[j] = gnorm_scale*((float)g_vals[j]); @@ -1592,53 +1665,30 @@ void kOptimizer32bit1State(T *g, T *p, break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); - p_vals[j] = - ((float)p_vals - [j]) - - update_scale * - (lr * - (float)g_vals - [j] / - (sycl::sqrt( - (float)s1_vals - [j]) + - eps)); - break; + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*(float)g_vals[j] / (sycl::sqrt((float)s1_vals[j])+eps)); + break; case ADAGRAD: s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); - p_vals[j] = - ((float)p_vals - [j]) - - lr * - (float)g_vals - [j] / - (sycl::sqrt( - (float)s1_vals - [j]) + - eps); - break; + p_vals[j] = ((float)p_vals[j]) - lr*(float)g_vals[j] / (sycl::sqrt((float)s1_vals[j])+eps); + break; } } } /* - DPCT1065:28: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:126: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:120: Migration of cub::BlockStore.Store is not supported. + DPCT1007:131: Migration of cub::BlockStore::Store is not supported. 
*/ Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); /* - DPCT1065:29: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:127: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:121: Migration of cub::BlockStore.Store is not supported. + DPCT1007:132: Migration of cub::BlockStore::Store is not supported. */ StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); } @@ -1650,7 +1700,10 @@ void kOptimizer32bit1State(T *g, T *p, #define NUM_PER_BLOCK 4096 template -void +/* +DPCT1110:6: The total declared local variable size in device function kPreconditionOptimizerStatic8bit2State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +SYCL_EXTERNAL void kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, float *unorm, @@ -1663,12 +1716,8 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c float *smem_quantiles1, float *smem_quantiles2) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; - const int base_idx = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); - int valid_items = - n - (item_ct1.get_group(2) * NUM_PER_BLOCK) > NUM_PER_BLOCK - ? NUM_PER_BLOCK - : n - (item_ct1.get_group(2) * NUM_PER_BLOCK); + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = n - (item_ct1.get_group(2)*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (item_ct1.get_group(2)*NUM_PER_BLOCK); float g_val = 0.0f; float local_max_s1 = -FLT_MAX; float local_max_s2 = -FLT_MAX; @@ -1684,67 +1733,58 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c typedef cub::BlockLoad LoadUInt8; typedef sycl::group<3> BlockReduce; - union type_ct6 { + + union type_ct6{ typename LoadT::TempStorage loadh; typename LoadUInt8::TempStorage loadc; typename BlockReduce::TempStorage reduce; }; type_ct6 &temp_storage = *(type_ct6 *)temp_storage_ct1; - if (item_ct1.get_local_id(2) < 256) + + + + if(item_ct1.get_local_id(2) < 256) { - smem_quantiles1[item_ct1.get_local_id(2)] = - quantiles1[item_ct1.get_local_id(2)]; - smem_quantiles2[item_ct1.get_local_id(2)] = - quantiles2[item_ct1.get_local_id(2)]; + smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles2[item_ct1.get_local_id(2)] = quantiles2[item_ct1.get_local_id(2)]; } /* - DPCT1065:42: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:150: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
*/ item_ct1.barrier(); - for (unsigned int i = base_idx; i < n_full; - i += NUM_THREADS * item_ct1.get_group_range(2) * NUM8BIT) + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*item_ct1.get_group_range(2)*NUM8BIT) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; /* - DPCT1007:129: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:156: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); /* - DPCT1065:45: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:153: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:130: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:157: Migration of cub::BlockLoad::Load is not supported. */ - LoadUInt8(temp_storage.loadc) - .Load(&(state1[i]), m_c1, valid_items, 128); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); /* - DPCT1065:46: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:154: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:131: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:158: Migration of cub::BlockLoad::Load is not supported. */ - LoadUInt8(temp_storage.loadc) - .Load(&(state2[i]), r_c2, valid_items, 128); + LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); /* - DPCT1065:47: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:155: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); -#pragma unroll 16 + #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; @@ -1769,62 +1809,49 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) { - float correction1 = 1.0f / (1.0f - sycl::pow(beta1, step)); - float correction2 = 1.0f / (1.0f - sycl::pow(beta2, step)); + float correction1 = 1.0f / (1.0f - dpct::pow(beta1, step)); + float correction2 = 1.0f / (1.0f - dpct::pow(beta2, step)); s1_vals[j] *= correction1; s2_vals[j] *= correction2; - float update_val = - s1_vals[j] / (sycl::sqrt(s2_vals[j]) + eps); // update + float update_val = s1_vals[j]/(sycl::sqrt(s2_vals[j])+eps); // update local_unorm += update_val*update_val; } } } /* - DPCT1065:43: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:151: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
*/ item_ct1.barrier(); /* - DPCT1007:48: Migration of cub::Reduce is not supported. + DPCT1007:7: Migration of cub::Reduce is not supported. */ - local_max_s1 = BlockReduce(temp_storage.reduce) - .Reduce(local_max_s1, sycl::maximum<>(), valid_items); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, sycl::maximum<>(), valid_items); /* - DPCT1065:44: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:152: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:49: Migration of cub::Reduce is not supported. + DPCT1007:8: Migration of cub::Reduce is not supported. */ - local_max_s2 = BlockReduce(temp_storage.reduce) - .Reduce(local_max_s2, sycl::maximum<>(), valid_items); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, sycl::maximum<>(), valid_items); if(unorm != NULL) { /* - DPCT1065:50: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:159: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:51: Migration of cub::Reduce is not supported. + DPCT1007:9: Migration of cub::Reduce is not supported. */ - local_unorm = BlockReduce(temp_storage.reduce) - .Reduce(local_unorm, sycl::plus<>(), valid_items); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, sycl::plus<>(), valid_items); } - if (item_ct1.get_local_id(2) == 0) + if(item_ct1.get_local_id(2) == 0) { atomicMax(&new_max1[0], local_max_s1); atomicMax(&new_max2[0], local_max_s2); - if (unorm != NULL) { - dpct::atomic_fetch_add( - &unorm[0], local_unorm); - } + if(unorm != NULL){ dpct::atomic_fetch_add(&unorm[0], local_unorm); } } } @@ -1833,7 +1860,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c #define NUM_PER_BLOCK2 4096 template -void +SYCL_EXTERNAL void kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, const float *unorm, const float max_unorm, const float param_norm, \ @@ -1847,17 +1874,14 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha float *smem_quantiles2, uint8_t *temp_storage_ct1) { - const int n_full = - (item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) * - NUM_PER_THREAD2; - const int base_idx = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD2); + const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2; + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD2); int valid_items = 0; float g_val = 0.0f; float s1_vals[NUM_PER_THREAD2]; float s2_vals[NUM_PER_THREAD2]; - const float correction1 = 1.0f - sycl::pow(beta1, step); - const float correction2 = sycl::sqrt(1.0f - sycl::pow(beta2, step)); + const float correction1 = 1.0f - dpct::pow(beta1, step); + const float correction2 = sycl::sqrt(1.0f - dpct::pow(beta2, step)); const float step_size = -lr*correction2/correction1; //const float step_size = -lr*correction2/correction1; float new_max_val1 = 1.0f/new_max1[0]; 
@@ -1882,7 +1906,10 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha typedef cub::BlockStore StoreChar; typedef cub::BlockStore StoreT; - union type_ct7 { + + + + union type_ct7{ typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; typename StoreChar::TempStorage storec; @@ -1890,68 +1917,54 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha }; type_ct7 &temp_storage = *(type_ct7 *)temp_storage_ct1; - if (item_ct1.get_local_id(2) < 512) + if(item_ct1.get_local_id(2) < 512) { - if (item_ct1.get_local_id(2) < 256) - smem_quantiles1[item_ct1.get_local_id(2)] = - quantiles1[item_ct1.get_local_id(2)]; + if(item_ct1.get_local_id(2) < 256) + smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; else - smem_quantiles2[item_ct1.get_local_id(2) - 256] = - quantiles2[item_ct1.get_local_id(2) - 256]; + smem_quantiles2[item_ct1.get_local_id(2)-256] = quantiles2[item_ct1.get_local_id(2)-256]; } /* - DPCT1065:52: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:160: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); - for (unsigned int i = base_idx; i < n_full; - i += item_ct1.get_group_range(2) * NUM_THREADS2 * NUM_PER_THREAD2) + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_THREADS2*NUM_PER_THREAD2) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; /* - DPCT1007:132: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:167: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); /* - DPCT1065:53: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:161: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:133: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:168: Migration of cub::BlockLoad::Load is not supported. */ LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); /* - DPCT1065:54: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:162: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:134: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:169: Migration of cub::BlockLoad::Load is not supported. */ LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); /* - DPCT1065:55: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. 
+ DPCT1065:163: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:135: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:170: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - if ((i + (item_ct1.get_local_id(2) * NUM_PER_THREAD2) + - NUM_PER_THREAD2) > n) { - continue; - } + if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } -# pragma unroll 4 + # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) { g_val = float(g_vals[j]); @@ -1965,8 +1978,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) - if (sycl::signbit(smem_quantiles1[c1s[j]]) != - sycl::signbit(s1_vals[j])) + if(sycl::signbit(smem_quantiles1[c1s[j]]) != sycl::signbit(s1_vals[j])) { if(s1_vals[j] > 0.0f) c1s[j] += 1; @@ -1983,42 +1995,33 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) { - p_vals[j] = (T)(((float)p_vals[j]) + - ((update_scale * step_size * - (s1_vals[j] / (sycl::sqrt(s2_vals[j]) + - (correction2 * eps)))))); + p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sycl::sqrt(s2_vals[j])+(correction2*eps)))))); if(weight_decay > 0.0f) p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); } /* - DPCT1007:136: Migration of cub::BlockStore.Store is not supported. + DPCT1007:171: Migration of cub::BlockStore::Store is not supported. */ StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); /* - DPCT1065:56: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:164: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:137: Migration of cub::BlockStore.Store is not supported. + DPCT1007:172: Migration of cub::BlockStore::Store is not supported. */ StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); /* - DPCT1065:57: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:165: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:138: Migration of cub::BlockStore.Store is not supported. + DPCT1007:173: Migration of cub::BlockStore::Store is not supported. */ StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); /* - DPCT1065:58: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:166: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
*/ item_ct1.barrier(); } @@ -2026,7 +2029,10 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha template -void +/* +DPCT1110:3: The total declared local variable size in device function kPreconditionOptimizerStatic8bit1State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +SYCL_EXTERNAL void kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, float *unorm, @@ -2040,12 +2046,8 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c float *smem_quantiles1) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; - const int base_idx = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); - int valid_items = - n - (item_ct1.get_group(2) * NUM_PER_BLOCK) > NUM_PER_BLOCK - ? NUM_PER_BLOCK - : n - (item_ct1.get_group(2) * NUM_PER_BLOCK); + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); + int valid_items = n - (item_ct1.get_group(2)*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (item_ct1.get_group(2)*NUM_PER_BLOCK); float g_val = 0.0f; float local_max_s1 = -FLT_MAX; float local_unorm = 0.0f; @@ -2058,52 +2060,46 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c typedef cub::BlockLoad LoadUInt8; typedef sycl::group<3> BlockReduce; - union type_ct8 { + + union type_ct8{ typename LoadT::TempStorage loadh; typename LoadUInt8::TempStorage loadc; typename BlockReduce::TempStorage reduce; }; type_ct8 &temp_storage = *(type_ct8 *)temp_storage_ct1; - if (item_ct1.get_local_id(2) < 256) - smem_quantiles1[item_ct1.get_local_id(2)] = - quantiles1[item_ct1.get_local_id(2)]; + + + if(item_ct1.get_local_id(2) < 256) + smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; /* - DPCT1065:30: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:133: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); - for (unsigned int i = base_idx; i < n_full; - i += item_ct1.get_group_range(2) * NUM_THREADS * NUM8BIT) + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_THREADS*NUM8BIT) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; /* - DPCT1065:32: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:135: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:122: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:137: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); /* - DPCT1065:33: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. 
+ DPCT1065:136: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:123: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:138: Migration of cub::BlockLoad::Load is not supported. */ - LoadUInt8(temp_storage.loadc) - .Load(&(state1[i]), m_c1, valid_items, 128); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); -#pragma unroll 16 + #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; @@ -2132,42 +2128,31 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c } /* - DPCT1065:31: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:134: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:34: Migration of cub::Reduce is not supported. + DPCT1007:4: Migration of cub::Reduce is not supported. */ - local_max_s1 = BlockReduce(temp_storage.reduce) - .Reduce(local_max_s1, sycl::maximum<>(), valid_items); - if (item_ct1.get_local_id(2) == 0) { - atomicMax(&new_max1[0], local_max_s1); - } + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, sycl::maximum<>(), valid_items); + if(item_ct1.get_local_id(2) == 0){ atomicMax(&new_max1[0], local_max_s1); } if(unorm != NULL) { /* - DPCT1065:35: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:139: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:36: Migration of cub::Reduce is not supported. + DPCT1007:5: Migration of cub::Reduce is not supported. 
*/ - local_unorm = BlockReduce(temp_storage.reduce) - .Reduce(local_unorm, sycl::plus<>(), valid_items); - if (item_ct1.get_local_id(2) == 0) { - dpct::atomic_fetch_add( - &unorm[0], local_unorm); - } + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, sycl::plus<>(), valid_items); + if(item_ct1.get_local_id(2) == 0){ dpct::atomic_fetch_add(&unorm[0], local_unorm); } } } template -void +SYCL_EXTERNAL void kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float *unorm, const float max_unorm, const float param_norm, @@ -2181,11 +2166,8 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, uint8_t *temp_storage_ct1) { - const int n_full = - (item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) * - NUM_PER_THREAD2; - const int base_idx = - (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD2); + const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2; + const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD2); int valid_items = 0; float g_val = 0.0f; float s1_vals[NUM_PER_THREAD2]; @@ -2209,7 +2191,9 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, typedef cub::BlockStore StoreChar; typedef cub::BlockStore StoreT; - union type_ct9 { + + + union type_ct9{ typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; typename StoreChar::TempStorage storec; @@ -2217,52 +2201,41 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, }; type_ct9 &temp_storage = *(type_ct9 *)temp_storage_ct1; - if (item_ct1.get_local_id(2) < 256) - smem_quantiles1[item_ct1.get_local_id(2)] = - quantiles1[item_ct1.get_local_id(2)]; + if(item_ct1.get_local_id(2) < 256) + smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; /* - DPCT1065:37: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:140: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); - for (unsigned int i = base_idx; i < n_full; - i += item_ct1.get_group_range(2) * NUM_THREADS2 * NUM_PER_THREAD2) + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_THREADS2*NUM_PER_THREAD2) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; /* - DPCT1007:124: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:145: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); /* - DPCT1065:38: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:141: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:125: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:146: Migration of cub::BlockLoad::Load is not supported. 
*/ LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); /* - DPCT1065:39: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:142: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:126: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:147: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - if ((i + (item_ct1.get_local_id(2) * NUM_PER_THREAD2) + - NUM_PER_THREAD2) > n) { - continue; - } + if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } -# pragma unroll 4 + # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) { g_val = float(g_vals[j]); @@ -2298,16 +2271,14 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); - p_vals[j] = ((float)p_vals[j]) - - (lr * g_val / (sycl::sqrt(s1_vals[j]) + eps)); + p_vals[j] = ((float)p_vals[j]) - (lr*g_val / (sycl::sqrt(s1_vals[j])+eps)); break; } c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); // make sure state1 term has still the same sign after quantization - if (sycl::signbit(smem_quantiles1[c1s[j]]) != - sycl::signbit(s1_vals[j])) + if(sycl::signbit(smem_quantiles1[c1s[j]]) != sycl::signbit(s1_vals[j])) { if(s1_vals[j] > 0.0f) c1s[j] += 1; @@ -2317,79 +2288,74 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, } /* - DPCT1007:127: Migration of cub::BlockStore.Store is not supported. + DPCT1007:148: Migration of cub::BlockStore::Store is not supported. */ StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); /* - DPCT1065:40: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:143: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:128: Migration of cub::BlockStore.Store is not supported. + DPCT1007:149: Migration of cub::BlockStore::Store is not supported. */ StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); /* - DPCT1065:41: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:144: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); } } -template -void kPercentileClipping(T *__restrict__ g, float *gnorm_vec, int step, - const int n, const sycl::nd_item<3> &item_ct1) + +template +SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n, + const sycl::nd_item<3> &item_ct1) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 
0 : BLOCK_SIZE); int valid_items = 0; + typedef cub::BlockLoad LoadT; - typename BlockReduce::TempStorage &reduce; + + + T vals[NUM_VALS]; float local_sum = 0.0f; - for (unsigned int i = (item_ct1.get_group(2) * BLOCK_SIZE); i < n_full; - i += item_ct1.get_group_range(2) * BLOCK_SIZE) + for (unsigned int i = (item_ct1.get_group(2) * BLOCK_SIZE); i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; local_sum = 0.0f; /* - DPCT1065:75: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:202: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:151: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:203: Migration of cub::BlockLoad::Load is not supported. */ LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); -#pragma unroll NUM_VALS + #pragma unroll NUM_VALS for(int j = 0; j < NUM_VALS; j++) local_sum += ((float)vals[j])*((float)vals[j]); /* - DPCT1007:76: Migration of cub::Sum is not supported. + DPCT1007:12: Migration of cub::Sum is not supported. */ local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); - if (item_ct1.get_local_id(2) == 0) + if(item_ct1.get_local_id(2) == 0) { if(step == 1) { // initialize with the same norm for all positions //#pragma unroll 10 for(int j = 0; j < 100; j++) - dpct::atomic_fetch_add( - &gnorm_vec[j], local_sum); + dpct::atomic_fetch_add(&gnorm_vec[j], local_sum); } else - dpct::atomic_fetch_add( - &gnorm_vec[step % 100], local_sum); + dpct::atomic_fetch_add(&gnorm_vec[step % 100], local_sum); } } @@ -2398,19 +2364,24 @@ void kPercentileClipping(T *__restrict__ g, float *gnorm_vec, int step, #define LANES 2 #define QUAD 3 -template - -void kOptimizerStatic8bit2StateBlockwise( - T *p, T *__restrict__ const g, unsigned char *state1, unsigned char *state2, - const float beta1, const float beta2, const float eps, const int step, - const float lr, float *__restrict__ const quantiles1, - float *__restrict__ const quantiles2, float *absmax1, float *absmax2, - float weight_decay, const float gnorm_scale, const bool skip_zeros, - const int n, const sycl::nd_item<3> &item_ct1, - sycl::local_accessor smem_quantiles1, - sycl::local_accessor smem_quantiles2, - float *smem_exchange1, - float *smem_exchange2, uint8_t *temp_storage_ct1) +template +/* +DPCT1110:10: The total declared local variable size in device function kOptimizerStatic8bit2StateBlockwise exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
+*/ +SYCL_EXTERNAL +void +kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + sycl::local_accessor smem_quantiles2, + float *smem_exchange1, float *smem_exchange2, + uint8_t *temp_storage_ct1) { //const int n_full = n + (n%BLOCK_SIZE); @@ -2421,9 +2392,9 @@ void kOptimizerStatic8bit2StateBlockwise( float s1_vals[N_PER_TH]; float s2_vals[N_PER_TH]; // 2-5% - const float correction1 = 1.0f - sycl::pow(beta1, step); - const float correction2 = sycl::sqrt(1.0f - sycl::pow(beta2, step)); - const float step_size = (-lr * correction2) / correction1; + const float correction1 = 1.0f - dpct::pow(beta1, step); + const float correction2 = sycl::sqrt(1.0f -dpct::pow(beta2, step)); + const float step_size = (-lr*correction2) / correction1; const int lane_id = item_ct1.get_local_id(2) % LANES; float new_local_abs_max1 = -FLT_MAX; float new_local_abs_max2 = -FLT_MAX; @@ -2440,16 +2411,16 @@ void kOptimizerStatic8bit2StateBlockwise( typedef cub::BlockStore StoreChar; typedef cub::BlockStore StoreT; - - //__shared__ float smem_quantiles1[LANES][257]; - //__shared__ float smem_quantiles2[LANES][257]; - typedef cub::BlockReduce BlockReduce1; - typedef cub::BlockReduce BlockReduce2; - typename BlockReduce1::TempStorage reduce1; - typename BlockReduce2::TempStorage reduce2; - - union type_ct10 { + + + + + + + + + union type_ct10{ typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; typename StoreChar::TempStorage storec; @@ -2459,66 +2430,54 @@ void kOptimizerStatic8bit2StateBlockwise( // init: 0.2 -> 0.23 // 0.23 -> 0.23 - smem_quantiles1[0][item_ct1.get_local_id(2)] = - quantiles1[item_ct1.get_local_id(2)]; - smem_quantiles2[0][item_ct1.get_local_id(2)] = - quantiles2[item_ct1.get_local_id(2)]; -# pragma unroll + smem_quantiles1[0][item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles2[0][item_ct1.get_local_id(2)] = quantiles2[item_ct1.get_local_id(2)]; + # pragma unroll for(unsigned int j = 1; j < LANES; j++) { - smem_quantiles1[j][item_ct1.get_local_id(2)] = - smem_quantiles1[0][item_ct1.get_local_id(2)]; - smem_quantiles2[j][item_ct1.get_local_id(2)] = - smem_quantiles2[0][item_ct1.get_local_id(2)]; + smem_quantiles1[j][item_ct1.get_local_id(2)] = smem_quantiles1[0][item_ct1.get_local_id(2)]; + smem_quantiles2[j][item_ct1.get_local_id(2)] = smem_quantiles2[0][item_ct1.get_local_id(2)]; } /* - DPCT1065:59: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:174: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
*/ item_ct1.barrier(); -#pragma unroll + #pragma unroll for(int k = 0; k < QUAD; k++) { quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; } - for (unsigned int i = base_idx; i < n_full; - i += item_ct1.get_group_range(2) * BLOCK_SIZE) + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { // loads: 0.23 -> 0.85/1.44 valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; /* - DPCT1065:60: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:175: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:139: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:183: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); /* - DPCT1065:61: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:176: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:140: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:184: Migration of cub::BlockLoad::Load is not supported. */ LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); /* - DPCT1065:62: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:177: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:141: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:185: Migration of cub::BlockLoad::Load is not supported. */ LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); @@ -2529,15 +2488,14 @@ void kOptimizerStatic8bit2StateBlockwise( # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - if (!sycl::isnan((float)g_vals[j]) && - !sycl::isinf((float)g_vals[j])) - { + if(!sycl::isnan((float)g_vals[j]) && !sycl::isinf((float)g_vals[j])) + { s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; g_val = g_vals[j]; //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); //g_val = ratio > 2.0f ? 
2.0f*g_val/ratio : g_val; g_val *= gnorm_scale; - + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; @@ -2549,33 +2507,27 @@ void kOptimizerStatic8bit2StateBlockwise( s2_vals[j] = 0.0f; } - new_local_abs_max1 = - sycl::fmax(new_local_abs_max1, sycl::fabs(s1_vals[j])); - new_local_abs_max2 = - sycl::fmax(new_local_abs_max2, sycl::fabs(s2_vals[j])); + new_local_abs_max1 = sycl::fmax(new_local_abs_max1, sycl::fabs(s1_vals[j])); + new_local_abs_max2 = sycl::fmax(new_local_abs_max2, sycl::fabs(s2_vals[j])); } // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = sycl::reduce_over_group( - item_ct1.get_group(), new_local_abs_max1, sycl::maximum<>()); - new_local_abs_max2 = sycl::reduce_over_group( - item_ct1.get_group(), new_local_abs_max2, sycl::maximum<>()); + new_local_abs_max1 = sycl::reduce_over_group(item_ct1.get_group(), new_local_abs_max1, sycl::maximum<>()); + new_local_abs_max2 = sycl::reduce_over_group(item_ct1.get_group(), new_local_abs_max2, sycl::maximum<>()); - if (item_ct1.get_local_id(2) == 0) + if(item_ct1.get_local_id(2) == 0) { smem_exchange1[0] = new_local_abs_max1; smem_exchange2[0] = new_local_abs_max2; } /* - DPCT1065:63: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:178: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); - if (item_ct1.get_local_id(2) == 0) + if(item_ct1.get_local_id(2) == 0) { absmax1[i/BLOCK_SIZE] = new_local_abs_max1; absmax2[i/BLOCK_SIZE] = new_local_abs_max2; @@ -2587,13 +2539,11 @@ void kOptimizerStatic8bit2StateBlockwise( } /* - DPCT1065:64: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:179: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:142: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:186: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); // reduce: 2.67/1.69 -> 2.67/1.70 @@ -2601,33 +2551,21 @@ void kOptimizerStatic8bit2StateBlockwise( for(unsigned int j = 0; j < N_PER_TH; j++) { //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - if (!sycl::isnan((float)g_vals[j]) && - !sycl::isinf((float)g_vals[j])) - { - p_vals[j] = - (T)(((float)p_vals - [j]) + - ((step_size * - (s1_vals[j] / - (sycl::sqrt( - s2_vals - [j]) + - (correction2 * - eps)))))); - if(weight_decay > 0.0f) + if(!sycl::isnan((float)g_vals[j]) && !sycl::isinf((float)g_vals[j])) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(s1_vals[j] / (sycl::sqrt(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } } // store: 0.85/1.44 -> 2.48/1.57 /* - DPCT1065:65: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. 
+ DPCT1065:180: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:143: Migration of cub::BlockStore.Store is not supported. + DPCT1007:187: Migration of cub::BlockStore::Store is not supported. */ StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); @@ -2635,15 +2573,12 @@ void kOptimizerStatic8bit2StateBlockwise( # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], - s1_vals[j] / new_local_abs_max1); - c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], - s2_vals[j] / new_local_abs_max2); + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], s1_vals[j] / new_local_abs_max1); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], s2_vals[j] / new_local_abs_max2); // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) - if (sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) != - sycl::signbit(s1_vals[j])) + if(sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) != sycl::signbit(s1_vals[j])) { if(s1_vals[j] > 0.0f) c1s[j] += 1; @@ -2653,23 +2588,19 @@ void kOptimizerStatic8bit2StateBlockwise( } /* - DPCT1065:66: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:181: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:144: Migration of cub::BlockStore.Store is not supported. + DPCT1007:188: Migration of cub::BlockStore::Store is not supported. */ StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); /* - DPCT1065:67: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:182: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:145: Migration of cub::BlockStore.Store is not supported. + DPCT1007:189: Migration of cub::BlockStore::Store is not supported. */ StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); } @@ -2678,17 +2609,22 @@ void kOptimizerStatic8bit2StateBlockwise( #define LANES 2 #define QUAD 3 -template - -void kOptimizerStatic8bit1StateBlockwise( - T *p, T *__restrict__ const g, unsigned char *state1, const float beta1, - const float beta2, const float eps, const int step, const float lr, - float *__restrict__ const quantiles1, float *absmax1, float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, - sycl::local_accessor smem_quantiles1, - float *smem_exchange1, - uint8_t *temp_storage_ct1) +template +/* +DPCT1110:11: The total declared local variable size in device function kOptimizerStatic8bit1StateBlockwise exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
+*/ +SYCL_EXTERNAL +void +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + float *smem_exchange1, uint8_t *temp_storage_ct1) { //const int n_full = n + (n%BLOCK_SIZE); @@ -2712,12 +2648,12 @@ void kOptimizerStatic8bit1StateBlockwise( typedef cub::BlockStore StoreChar; typedef cub::BlockStore StoreT; - //__shared__ float smem_quantiles1[LANES][257]; - typedef cub::BlockReduce BlockReduce1; - __shared__ typename BlockReduce1::TempStorage reduce1; - //__shared__ float smem_exchange1[1]; + + + + - union type_ct11 { + union type_ct11{ typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; typename StoreChar::TempStorage storec; @@ -2727,57 +2663,46 @@ void kOptimizerStatic8bit1StateBlockwise( // init: 0.2 -> 0.23 // 0.23 -> 0.23 - smem_quantiles1[0][item_ct1.get_local_id(2)] = - quantiles1[item_ct1.get_local_id(2)]; -# pragma unroll + smem_quantiles1[0][item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; + # pragma unroll for(unsigned int j = 1; j < LANES; j++) - smem_quantiles1[j][item_ct1.get_local_id(2)] = - smem_quantiles1[0][item_ct1.get_local_id(2)]; + smem_quantiles1[j][item_ct1.get_local_id(2)] = smem_quantiles1[0][item_ct1.get_local_id(2)]; /* - DPCT1065:68: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:190: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); -#pragma unroll + #pragma unroll for(int k = 0; k < QUAD; k++) quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; - for (unsigned int i = base_idx; i < n_full; - i += item_ct1.get_group_range(2) * BLOCK_SIZE) + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { // loads: 0.23 -> 0.85/1.44 valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; /* - DPCT1065:69: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:191: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:146: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:197: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); /* - DPCT1065:70: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:192: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:147: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:198: Migration of cub::BlockLoad::Load is not supported. 
*/ LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); /* - DPCT1065:71: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:193: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:148: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:199: Migration of cub::BlockLoad::Load is not supported. */ LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); @@ -2828,26 +2753,22 @@ void kOptimizerStatic8bit1StateBlockwise( } } - new_local_abs_max1 = - sycl::fmax(new_local_abs_max1, sycl::fabs(s1_vals[j])); + new_local_abs_max1 = sycl::fmax(new_local_abs_max1, sycl::fabs(s1_vals[j])); } // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = sycl::reduce_over_group( - item_ct1.get_group(), new_local_abs_max1, sycl::maximum<>()); + new_local_abs_max1 = sycl::reduce_over_group(item_ct1.get_group(), new_local_abs_max1, sycl::maximum<>()); - if (item_ct1.get_local_id(2) == 0) + if(item_ct1.get_local_id(2) == 0) smem_exchange1[0] = new_local_abs_max1; /* - DPCT1065:72: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:194: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); - if (item_ct1.get_local_id(2) == 0) + if(item_ct1.get_local_id(2) == 0) absmax1[i/BLOCK_SIZE] = new_local_abs_max1; else new_local_abs_max1 = smem_exchange1[0]; @@ -2868,41 +2789,23 @@ void kOptimizerStatic8bit1StateBlockwise( break; case RMSPROP: g_val = g_vals[j]; - p_vals[j] = - ((float)p_vals - [j]) - - lr * - (g_val / - (sycl::sqrt( - s1_vals - [j]) + - eps)); - break; + p_vals[j] = ((float)p_vals[j]) - lr*(g_val / (sycl::sqrt(s1_vals[j])+eps)); + break; case ADAGRAD: g_val = g_vals[j]; - p_vals[j] = - ((float)p_vals - [j]) - - lr * - (g_val / - (sycl::sqrt( - s1_vals - [j]) + - eps)); - break; + p_vals[j] = ((float)p_vals[j]) - lr*(g_val / (sycl::sqrt(s1_vals[j])+eps)); + break; } } } // store: 0.85/1.44 -> 2.48/1.57 /* - DPCT1065:73: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:195: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:149: Migration of cub::BlockStore.Store is not supported. + DPCT1007:200: Migration of cub::BlockStore::Store is not supported. 
*/ StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); @@ -2910,13 +2813,11 @@ void kOptimizerStatic8bit1StateBlockwise( # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], - s1_vals[j] / new_local_abs_max1); + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], s1_vals[j] / new_local_abs_max1); // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) - if (sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) != - sycl::signbit(s1_vals[j])) + if(sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) != sycl::signbit(s1_vals[j])) { if(s1_vals[j] > 0.0f) c1s[j] += 1; @@ -2926,13 +2827,11 @@ void kOptimizerStatic8bit1StateBlockwise( } /* - DPCT1065:74: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. + DPCT1065:196: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:150: Migration of cub::BlockStore.Store is not supported. + DPCT1007:201: Migration of cub::BlockStore::Store is not supported. */ StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); } @@ -2953,10 +2852,9 @@ template BlockRowSum; typedef cub::BlockExchange BlockExchange; - union type_ct12 { + union type_ct12{ typename BlockExchange::TempStorage exchange; typename BlockRowReduce::TempStorage rowreduce; typename BlockRowSum::TempStorage rowsum; @@ -2973,6 +2871,9 @@ template{0.0f} - .convert()[0]); + LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, sycl::vec(0.0f).convert()[0]); -#pragma unroll ITEMS_PER_THREAD + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_data[j] = sycl::fabs(local_data[j]); + if(SPARSE_DECOMP) #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) @@ -3043,10 +2943,7 @@ template{local_data[j]} - .convert()[0]); + local_col_absmax_values[j] = sycl::fmax(local_col_absmax_values[j], sycl::vec(local_data[j]).convert()[0]); // 3. compute row max (per block); store in smem to accumulate full global mem transation @@ -3056,29 +2953,23 @@ template()); + row_absmax = (float)dpct::group::reduce(item_ct1, local_data_fp32, sycl::maximum<>()); if(SPARSE_DECOMP) { /* - DPCT1065:84: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:214: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); - local_row_nnz_count = sycl::reduce_over_group( - item_ct1.get_group(), local_row_nnz_count, sycl::plus<>()); + local_row_nnz_count = sycl::reduce_over_group(item_ct1.get_group(), local_row_nnz_count, sycl::plus<>()); } // we store the data temporarily in shared memory so we // can execute a full atomic block transaction into global memory later // we use a striped arrangement [0, 8, 16, 24, ..] 
for t0 for faster stores - if (item_ct1.get_local_id(2) == 0) + if(item_ct1.get_local_id(2) == 0) { smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block @@ -3086,80 +2977,63 @@ template( - sycl::half *__restrict__ A, float *rowStats, float *colStats, - int *nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, - int tiledCols, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, - float *smem_row_absmax_values, int *smem_row_nnz_values); -template void kgetColRowStats( - sycl::half *__restrict__ A, float *rowStats, float *colStats, - int *nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, - int tiledCols, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, - float *smem_row_absmax_values, int *smem_row_nnz_values); +template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, + const sycl::nd_item<3> &item_ct1, + uint8_t *temp_storage_ct1, + float *smem_row_absmax_values, + int *smem_row_nnz_values); +template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, + const sycl::nd_item<3> &item_ct1, + uint8_t *temp_storage_ct1, + float *smem_row_absmax_values, + int *smem_row_nnz_values); #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) -template -void kdequant_mm_int32_fp16( - int *__restrict__ const A, float *__restrict__ const rowStats, - float *__restrict__ const colStats, sycl::half *out, float *newRowStats, - float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, - const int numCols, const int tileCols, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_rowStats) +template void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, + const sycl::nd_item<3> &item_ct1, + float *smem_rowStats) { // Strategy: To dequantize we need to load col/row statistics. This can be very expensive @@ -3177,7 +3051,7 @@ void kdequant_mm_int32_fp16( // data is in 32 column-tile major with tile width 32 columns and numRows rows // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) // C2. Compute normalization values and store col values in register // S1. 
Store C1 into 16-bit output @@ -3193,11 +3067,9 @@ void kdequant_mm_int32_fp16( // we have tiles of size numRows*32, thus col only increases every numRows // num_row_tiles is the tiles after which the column increases by 32 // blockIdx.x is the index of the current tile - int col = ((item_ct1.get_local_id(2) % 32) + - ((item_ct1.get_group(2) / num_row_tiles) * 32)); + int col = ((item_ct1.get_local_id(2) % 32) + ((item_ct1.get_group(2)/num_row_tiles)*32)); // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached - int base_row = - (item_ct1.get_group(2) * SUBTILE_ROWS) % (num_row_tiles * SUBTILE_ROWS); + int base_row = (item_ct1.get_group(2)*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD @@ -3213,20 +3085,19 @@ void kdequant_mm_int32_fp16( int local_values[ITEMS_PER_THREAD]; sycl::half local_output[ITEMS_PER_THREAD]; float local_rowStats[ITEMS_PER_THREAD]; + typedef cub::BlockLoad LoadInt32; typedef cub::BlockExchange ExchangeInt32; + + + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. float colStat = col >= numCols ? 0.0f : colStats[col]; - float local_biasValue = - ((bias == NULL) || (col >= numCols)) - ? 0.0f - : sycl::vec{bias[col]} - .convert()[0]; + float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : sycl::vec(bias[col]).convert()[0]; // no block loads for rows for now -- keep it simple - for (int j = item_ct1.get_local_id(2); j < SUBTILE_ROWS; - j += item_ct1.get_local_range(2)) + for(int j = item_ct1.get_local_id(2); j < SUBTILE_ROWS; j+=item_ct1.get_local_range(2)) { // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? int row = (base_row+j) % numRows; // wrap around @@ -3236,22 +3107,19 @@ void kdequant_mm_int32_fp16( smem_rowStats[j] = rowStats[row]; } /* - DPCT1065:78: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:205: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); + // each block processes SUBTILE_ROWS*32 elements const int items_per_load = THREADS*ITEMS_PER_THREAD; const int rows_per_load = items_per_load/32; - int subtile_base_row = - (item_ct1.get_local_id(2) / 32) * ITEMS_PER_THREAD; // row within the tile + int subtile_base_row = (item_ct1.get_local_id(2) / 32)*ITEMS_PER_THREAD; // row within the tile int row_offset = 0; // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed - int subtile_start = (item_ct1.get_group(2) / num_row_tiles) * (numRows * 32) + - (base_row * 32); + int subtile_start = (item_ct1.get_group(2)/num_row_tiles)*(numRows*32) + (base_row*32); for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) { int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); @@ -3259,29 +3127,23 @@ void kdequant_mm_int32_fp16( if(valid_items <= 0) // the sub-tile might have more elements than the tile itself break; - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. 
Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) /* - DPCT1007:152: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:206: Migration of cub::BlockLoad::Load is not supported. */ LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); /* - DPCT1007:153: Migration of cub::BlockExchange.BlockedToWarpStriped is not - supported. + DPCT1007:207: Migration of cub::BlockExchange::BlockedToWarpStriped is not supported. */ - ExchangeInt32(exchangeint32) - .BlockedToWarpStriped(local_values, local_values); + ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); -#pragma unroll ITEMS_PER_THREAD + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_output[j] = - sycl::vec{((local_values[j] * MM_DEQUANT_CONST * - local_rowStats[j] * colStat) + - local_biasValue)} - .convert()[0]; + local_output[j] = sycl::vec((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue).convert()[0]; //absmax_col = fmax(fabsf(local_output[j]), absmax_col); // we store data in row major @@ -3301,17 +3163,11 @@ void kdequant_mm_int32_fp16( } } -template -void kDoubleRowColQuant(sycl::half *__restrict__ const A, - float *__restrict__ const rowStats, - float *__restrict__ const colStats, - char *out_col_normed, char *out_row_normed, int *rowidx, - int *colidx, sycl::half *val, - int *__restrict__ nnz_block_ptr, float threshold, - int rows, int cols, int tiledCols, - const sycl::nd_item<3> &item_ct1, - float *smem_row_stats, unsigned int *smem_nnz_row_idx) + +template void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, + const sycl::nd_item<3> &item_ct1, + float *smem_row_stats, + unsigned int *smem_nnz_row_idx) { // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD // Each thread reads the same column but multiple rows @@ -3325,18 +3181,19 @@ void kDoubleRowColQuant(sycl::half *__restrict__ const A, // each block loads TILE_COLs columns and TILE_ROW rows // after reading a tile the row counter increase by TILE_ROWS // the col counter reset after reading TILE_COL elements - const int base_row = - ((item_ct1.get_group(2) * TILE_COLS) / tiledCols) * TILE_ROWS; + const int base_row = ((item_ct1.get_group(2)*TILE_COLS)/tiledCols)*TILE_ROWS; // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached - const int base_col = (item_ct1.get_group(2) * TILE_COLS) % tiledCols; + const int base_col = (item_ct1.get_group(2)*TILE_COLS) % tiledCols; const int base_idx = (base_row*cols) + base_col; const int items_per_load = ITEMS_PER_THREAD*THREADS; - typedef cub::BlockLoad - LoadHalf; - + typedef cub::BlockLoad LoadHalf; + typedef cub::BlockStore StoreInt8; + + + + sycl::half local_data[ITEMS_PER_THREAD]; float local_col_stats[ITEMS_PER_THREAD]; @@ -3345,25 +3202,22 @@ void kDoubleRowColQuant(sycl::half *__restrict__ const A, // 0. 
Load row stats data into shared memory; load col stat (1 fixed per thread) #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) - if (base_col + (item_ct1.get_local_id(2) * ITEMS_PER_THREAD) + j < cols) - local_col_stats[j] = - 127.0f / colStats[base_col + - (item_ct1.get_local_id(2) * ITEMS_PER_THREAD) + j]; + if(base_col+(item_ct1.get_local_id(2)*ITEMS_PER_THREAD) + j < cols) + /* + DPCT1064:221: Migrated __fdividef call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code. + */ + local_col_stats[j] = 127.0f / colStats[base_col+(item_ct1.get_local_id(2)*ITEMS_PER_THREAD)+j]; - for (int i = item_ct1.get_local_id(2); i < TILE_ROWS; - i += item_ct1.get_local_range(2)) + for(int i = item_ct1.get_local_id(2); i < TILE_ROWS; i+=item_ct1.get_local_range(2)) { if(base_row + i < rows) smem_row_stats[i] = rowStats[base_row+i]; if(SPARSE_DECOMP) - smem_nnz_row_idx[i] = - nnz_block_ptr[(TILE_ROWS * item_ct1.get_group(2)) + i]; + smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*item_ct1.get_group(2)) + i]; } /* - DPCT1065:85: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:216: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); @@ -3375,8 +3229,9 @@ void kDoubleRowColQuant(sycl::half *__restrict__ const A, int i = base_idx + (row*cols); int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + /* - DPCT1007:156: Migration of cub::BlockLoad.Load is not supported. + DPCT1007:218: Migration of cub::BlockLoad::Load is not supported. */ LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); float row_stat = 127.0f / smem_row_stats[row]; @@ -3389,47 +3244,29 @@ void kDoubleRowColQuant(sycl::half *__restrict__ const A, // what this does is float/absmax*127 = int8 if(SPARSE_DECOMP) { - if (sycl::fabs((float)local_data[j]) >= threshold) + if(sycl::fabs((float)local_data[j]) >= threshold) { local_quantized_data[j] = 0; - int old_idx = - dpct::atomic_fetch_compare_inc< - sycl::access::address_space:: - generic_space>( - &smem_nnz_row_idx[row], - UINT_MAX); + int old_idx = dpct::atomic_fetch_compare_inc(&smem_nnz_row_idx[row], UINT_MAX); rowidx[old_idx] = base_row+row; - colidx[old_idx] = - base_col + (item_ct1.get_local_id(2) * ITEMS_PER_THREAD) + j; + colidx[old_idx] = base_col+(item_ct1.get_local_id(2)*ITEMS_PER_THREAD)+j; val[old_idx] = local_data[j]; } else { - local_quantized_data[j] = - (char)(sycl::rint( - sycl::vec{ - local_data[j]} - .convert< - float, - sycl::rounding_mode:: - automatic>()[0] * - row_stat)); - } + local_quantized_data[j] = (char)(sycl::rint(sycl::vec(local_data[j]).convert()[0]*row_stat)); + } } else - local_quantized_data[j] = (char)(sycl::rint( - sycl::vec{local_data[j]} - .convert()[0] * - row_stat)); + local_quantized_data[j] = (char)(sycl::rint(sycl::vec(local_data[j]).convert()[0]*row_stat)); } /* - DPCT1007:157: Migration of cub::BlockStore.Store is not supported. + DPCT1007:219: Migration of cub::BlockStore::Store is not supported. */ - StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, - valid_items); + StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); // 2. 
quantize data with row/col stats #pragma unroll ITEMS_PER_THREAD @@ -3437,28 +3274,27 @@ void kDoubleRowColQuant(sycl::half *__restrict__ const A, { // we already pre-normalized the col/row stat: // what this does is float/absmax*127 = int8 - local_quantized_data[j] = (char)(sycl::rint( - sycl::vec{local_data[j]} - .convert()[0] * - local_col_stats[j])); + local_quantized_data[j] = (char)(sycl::rint(sycl::vec(local_data[j]).convert()[0]*local_col_stats[j])); } /* - DPCT1065:86: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:217: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); /* - DPCT1007:158: Migration of cub::BlockStore.Store is not supported. + DPCT1007:220: Migration of cub::BlockStore::Store is not supported. */ - StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, - valid_items); + StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); + } } -template void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1,char *smem_data) +/* +DPCT1110:14: The total declared local variable size in device function kTransformRowToFormat exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +template SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, + const sycl::nd_item<3> &item_ct1, + char *smem_data) { // 0. Load data into 32*32 shared memory tiles @@ -3499,23 +3335,22 @@ template BlockExchange; // we load row after row from the base_position // Load data row by row - int warps = item_ct1.get_local_range(2) / 32; - int warp_id = item_ct1.get_local_id(2) / 32; + int warps = item_ct1.get_local_range(2)/32; + int warp_id = item_ct1.get_local_id(2)/32; int warp_lane = item_ct1.get_local_id(2) % 32; int offset = 0; @@ -3574,12 +3409,7 @@ template -void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, - int *offset_rowidx, int *rowidx, int *colidx, - sycl::half *values, T *B, sycl::half *out, - float *__restrict__ const dequant_stats, - int nnz, int rowsA, int rowsB, int colsB, +/* +DPCT1110:13: The total declared local variable size in device function kspmm_coo_very_sparse_naive exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, sycl::half *smem_dequant_stats) { @@ -3838,11 +3666,7 @@ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, // 2. Load A into registers for(int j = 0; j < MAX_SPARSE_COUNT; j++) { - local_valA[j] = - j < count - ? values[offset + j] - : sycl::vec{0.0f} - .convert()[0]; + local_valA[j] = j < count ? 
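Throughout these hunks the migration rewrites CUDA built-ins (threadIdx.x, blockIdx.x, blockDim.x, gridDim.x) into sycl::nd_item accessors, with CUDA's x axis landing on index 2 of the 3-D range. A compact reference for the mapping, matching the warps/warp_id/warp_lane arithmetic above:

    #include <sycl/sycl.hpp>

    // dpct maps CUDA's x dimension to index 2 of a 3-D sycl::nd_item
    void indices(const sycl::nd_item<3>& it)
    {
        int tid       = it.get_local_id(2);     // threadIdx.x
        int bid       = it.get_group(2);        // blockIdx.x
        int bdim      = it.get_local_range(2);  // blockDim.x
        int gdim      = it.get_group_range(2);  // gridDim.x
        int warp_id   = tid / 32;               // warp index within the work-group
        int warp_lane = tid % 32;               // lane within the warp/sub-group
        (void)bid; (void)bdim; (void)gdim; (void)warp_id; (void)warp_lane;
    }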
values[offset+j] : sycl::vec(0.0f).convert()[0]; local_colidxA[j] = j < count ? colidx[offset+j] : 0; } @@ -3850,21 +3674,20 @@ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, // we expect each warp to be SPMM_ITEMS*32 apart // we have a total of 128 bytes for the bank with a bank size of 4 bytes // added 3 bytes = 6 values between warps should reduce bank conflicts + + while(idx_col_B < colsB) { if(dequant_stats != NULL) { - for (int i = item_ct1.get_local_id(2); i < SMEM_SIZE; - i += item_ct1.get_local_range(2)) + for(int i = item_ct1.get_local_id(2); i < SMEM_SIZE; i+=item_ct1.get_local_range(2)) if((idx_col_B+i-local_idx_col_B_offset) < colsB) smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; /* - DPCT1065:77: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. + DPCT1065:204: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); } @@ -3888,13 +3711,9 @@ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, if((idx+num_items < colsB)) { if(BITS == 8) - reinterpret_cast(local_valsB)[0] = - reinterpret_cast( - B)[(row_offset + idx) / num_items]; + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; else - reinterpret_cast(local_valsB)[0] = - reinterpret_cast( - B)[(row_offset + idx) / num_items]; + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; } else { @@ -3935,16 +3754,13 @@ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, { // load outputs to do inplace addition - reinterpret_cast(local_valOut)[0] = - reinterpret_cast(out)[idx_val / num_items]; + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; -#pragma unroll num_items + #pragma unroll num_items for(int k = 0; k < num_items; k++) local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; - reinterpret_cast(out)[idx_val / num_items] = - reinterpret_cast( - local_valC)[j / num_items]; + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; } else { @@ -3955,17 +3771,17 @@ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, } } - idx_col_B += item_ct1.get_local_range(2) * SPMM_ITEMS; - local_idx_col_B_offset += item_ct1.get_local_range(2) * SPMM_ITEMS; + idx_col_B += item_ct1.get_local_range(2)*SPMM_ITEMS; + local_idx_col_B_offset += item_ct1.get_local_range(2)*SPMM_ITEMS; } } -template void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1) { - int local_colidx = idx[item_ct1.get_group(2)]; + int local_colidx = idx[item_ct1.get_group(2)]; - if(FORMAT==COL_TURING) + if(FORMAT==COL_TURING) { // TURING FORMAT: // 8*32 tiles with 4*4 subtiles @@ -3979,9 +3795,8 @@ template void kExtractOutliers(char *A, int *idx, char *out, int id // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] 
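kspmm_coo_very_sparse_naive walks the COO nonzeros of a very sparse A and, for each nonzero A[r, k], accumulates values * B[k, :] into out[r, :], dequantizing int8 columns of B through smem_dequant_stats on the fly. A dense host-side sketch of the accumulation pattern, assuming float B (no dequantization) stored row-major as rowsB x colsB:

    // out is rowsA x colsB, row-major, pre-zeroed by the caller
    void spmm_coo_ref(int nnz, const int* rowidx, const int* colidx, const float* values,
                      const float* B, int colsB, float* out)
    {
        for (int n = 0; n < nnz; n++)
        {
            int r = rowidx[n], k = colidx[n];
            for (int c = 0; c < colsB; c++)
                out[r * colsB + c] += values[n] * B[k * colsB + c]; // rank-1 update per nonzero
        }
    }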
// each thread reads 1 element = 1 row - for (int row = item_ct1.get_local_id(2); row < rowsA; - row += item_ct1.get_local_range(2)) - { + for(int row = item_ct1.get_local_id(2); row < rowsA; row+= item_ct1.get_local_range(2)) + { int offset_per_col_tile = ((rowsA+7)/8)*32*8; int tile_offset_rows = (row/8)*32*8; int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; @@ -3998,16 +3813,15 @@ template void kExtractOutliers(char *A, int *idx, char *out, int id char val = A[offset]; - int out_idx = (row * idx_size) + item_ct1.get_group(2); - out[out_idx] = val; + int out_idx = (row*idx_size) + item_ct1.get_group(2); + out[out_idx] = val; } } else if(FORMAT == COL_AMPERE) { - for (int row = item_ct1.get_local_id(2); row < rowsA; - row += item_ct1.get_local_range(2)) - { + for(int row = item_ct1.get_local_id(2); row < rowsA; row+= item_ct1.get_local_range(2)) + { // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element // within each tile. int offset_per_col_tile = ((rowsA+31)/32)*32*32; @@ -4020,8 +3834,8 @@ template void kExtractOutliers(char *A, int *idx, char *out, int id offset += tile_offset_cols + tile_offset_rows; char val = A[offset]; - int out_idx = (row * idx_size) + item_ct1.get_group(2); - out[out_idx] = val; + int out_idx = (row*idx_size) + item_ct1.get_group(2); + out[out_idx] = val; } } } @@ -4040,11 +3854,11 @@ template void kExtractOutliers(char *A, int *idx, char *out, int id //// use k warps per thread block //// 1. threadblock use read-only cache to read in register tile for A into shared memory //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments -//// 3. each warp reads a segment of values 16x32 from B +//// 3. each warp reads a segment of values 16x32 from B //// 4. do dequantization from register of B into second pair of registers //// 5. store (4) into fragment //// 6. matmul aggregate into fragment C -//// 7. aggreecate files of C into shared memroy block C +//// 7. aggregate files of C into shared memory block C //// 8. sum (7) //// 9. write outputs to matmul output matrix //} @@ -4066,17 +3880,23 @@ template inline void vector_load(T *loca } #define WARPS 3 -template void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +/* +DPCT1110:15: The total declared local variable size in device function gemm_device exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
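Stripped of the COL_TURING/COL_AMPERE offset algebra, kExtractOutliers is a column gather: work-group g copies column idx[g] of A into column g of out, with threads striding over rows (out_idx = row*idx_size + group). A row-major reference that ignores the tiled input layouts:

    // A: rows x cols row-major; out: rows x idx_size row-major
    void extract_outliers_ref(const signed char* A, const int* idx, int idx_size,
                              int rows, int cols, signed char* out)
    {
        for (int g = 0; g < idx_size; g++)        // one work-group per outlier index
            for (int row = 0; row < rows; row++)  // threads stride over rows
                out[row * idx_size + g] = A[row * cols + idx[g]];
    }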
+*/ +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + T *smem_A, + T *smem_B) { #if DPCT_COMPATIBILITY_TEMP >= 750 - //using namespace nvcuda; - int col_offset = blockIdx.x *32; - const int warp_id = threadIdx.x / 32; - const int half_warp_id = threadIdx.x / 16; - const int half_warp_lane = threadIdx.x % 16; + + int col_offset = item_ct1.get_group(2) *32; + const int warp_id = item_ct1.get_local_id(2) / 32; + const int half_warp_id = item_ct1.get_local_id(2) / 16; + const int half_warp_lane = item_ct1.get_local_id(2) % 16; const int batch_size_warps = (WARPS-1)*2; - const int val_per_iter = blockDim.x-32; + const int val_per_iter = item_ct1.get_local_range(2)-32; T local_A[4]; T local_B[128]; @@ -4084,17 +3904,23 @@ template void gemm_device(int M, int N, int const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); - __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + + //__shared__ T smem_C[8*32]; - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); + /* + DPCT1082:16: Migration of nvcuda::wmma::fragment type is not supported. + */ + /* + DPCT1082:17: Migration of nvcuda::wmma::matrix_a type is not supported. + */ + /* + DPCT1082:18: Migration of nvcuda::wmma::row_major type is not supported. + */ + wmma::fragment 8, 32, 16, wmma::fragmentor, 8, 32, 16, wmma::fragment; int ticktock = 0; - int idx = 0 + threadIdx.x; + int idx = 0 + item_ct1.get_local_id(2); int loaded_values = 0; // prefetch if(idx < K && warp_id < (WARPS-1)) @@ -4165,11 +3991,11 @@ template void gemm_device(int M, int N, int ticktock = ticktock == 0 ? 1 : 0; //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = item_ct1.get_local_range(2)-32; base_idx < K; base_idx+=item_ct1.get_local_range(2)-32) { - idx = base_idx + threadIdx.x; + idx = base_idx + item_ct1.get_local_id(2); - __syncthreads(); + item_ct1.barrier(sycl::access::fence_space::local_space); if(idx < K && warp_id < (WARPS-1)) { //local_A[0] = A[idx]; @@ -4246,27 +4072,48 @@ template void gemm_device(int M, int N, int if(warp_id == (WARPS-1)) for(int k = 0; k < batch_size_warps; k++) { + /* + DPCT1007:25: Migration of nvcuda::wmma::load_matrix_sync is not supported. + */ wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + /* + DPCT1007:26: Migration of nvcuda::wmma::load_matrix_sync is not supported. + */ wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + /* + DPCT1007:27: Migration of nvcuda::wmma::mma_sync is not supported. + */ wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } } - __syncthreads(); + item_ct1.barrier(sycl::access::fence_space::local_space); if(warp_id != (WARPS-1)){ return; } // only warp_id == (WARPS-1) from here - int warp_lane = threadIdx.x % 32; + int warp_lane = item_ct1.get_local_id(2) % 32; ticktock = ticktock == 0 ? 1 : 0; for(int k = 0; k < batch_size_warps; k++) { + /* + DPCT1007:28: Migration of nvcuda::wmma::load_matrix_sync is not supported. + */ wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + /* + DPCT1007:29: Migration of nvcuda::wmma::load_matrix_sync is not supported. 
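gemm_device overlaps global loads with tensor-core math by double buffering shared memory: the ticktock flag names the buffer half being consumed while the other half is refilled for the next tile, with a barrier before the roles flip. A minimal sketch of that schedule, detached from the WMMA details:

    #include <array>

    void ticktock_pipeline(int num_tiles)
    {
        std::array<int, 2> buffer{};   // stands in for the two halves of smem_A/smem_B
        int ticktock = 0;

        buffer[ticktock] = 0;          // prefetch tile 0 before the main loop
        for (int t = 1; t < num_tiles; t++)
        {
            buffer[1 - ticktock] = t;  // load the next tile into the idle half
            // ... barrier, then compute on buffer[ticktock] ...
            ticktock = 1 - ticktock;   // swap producer/consumer roles
        }
        // ... final compute on buffer[ticktock] ...
    }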
+ */ wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + /* + DPCT1007:30: Migration of nvcuda::wmma::mma_sync is not supported. + */ wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } // 129 mu if(warp_id == (WARPS-1)) + /* + DPCT1007:31: Migration of nvcuda::wmma::store_matrix_sync is not supported. + */ wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); if(col_offset + warp_lane < M) @@ -4280,33 +4127,35 @@ template void printnonzero(T *A, int num_values, const char * strva { for(int i = 0; i < num_values; i++) if((float)A[i] != 0.0) - - stream_ct1 <<"Strval "<< strval << "index "<(float *A, int num_values, const char*strval, const sycl::stream &stream_ct1); -template void printnonzero(sycl::half *A, int num_values, - const char *strval, - const sycl::stream &stream_ct1); - -static dpct::global_memory nf4_data( - sycl::range<1>(16), - {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, - -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, - 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, - 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, - 0.7229568362236023, 1.0}); -template void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +template void printnonzero(sycl::half *A, int num_values, const char*strval, + const sycl::stream &stream_ct1); + +static dpct::global_memory nf4_data(sycl::range<1>(16), {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}); +/* +DPCT1110:32: The total declared local variable size in device function kgemm_4bit_inference exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
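The nf4_data table above is the 16-entry NF4 codebook; each quantized byte packs two codes, high nibble first, and the inference kernels expand them as quant_map[b >> 4] and quant_map[b & 0x0F], each scaled by the block's absmax. A self-contained decode of one packed byte:

    static const float nf4[16] = {
        -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
        -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
        0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f,
        0.33791524171829224f, 0.44070982933044434f, 0.5626170039176941f,
        0.7229568362236023f, 1.0f };

    void nf4_decode_byte(unsigned char packed, float absmax, float out[2])
    {
        out[0] = nf4[packed >> 4]   * absmax;  // high nibble decodes first
        out[1] = nf4[packed & 0x0F] * absmax;  // then the low nibble
    }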
+*/ +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + T *smem_A, + T *smem_B, + T *smem_C) { #if DPCT_COMPATIBILITY_TEMP >= 750 - using namespace nvcuda; - int col_offset = blockIdx.x *32; - const int warp_id = threadIdx.x / 32; - const int warp_idx = threadIdx.x % 32; - const int half_warp_id = threadIdx.x / 16; - const int half_warp_lane = threadIdx.x % 16; + + int col_offset = item_ct1.get_group(2) *32; + const int warp_id = item_ct1.get_local_id(2) / 32; + const int warp_idx = item_ct1.get_local_id(2) % 32; + const int half_warp_id = item_ct1.get_local_id(2) / 16; + const int half_warp_lane = item_ct1.get_local_id(2) % 16; const int batch_size_warps = (WARPS-1)*2; T quant_map[16]; @@ -4324,22 +4173,28 @@ template void kgemm_4bit_inference(int M, int N, int K const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); - __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_C[8*32]; - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); + + + - for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) + /* + DPCT1082:33: Migration of nvcuda::wmma::fragment type is not supported. + */ + /* + DPCT1082:34: Migration of nvcuda::wmma::matrix_a type is not supported. + */ + /* + DPCT1082:35: Migration of nvcuda::wmma::row_major type is not supported. + */ + wmma::fragment 8, 32, 16, wmma::fragmentor, 8, 32, 16, wmma::fragment; + + for(int i = item_ct1.get_local_id(2); i < (8*32); i+=item_ct1.get_local_range(2)) smem_C[i] = 0.0f; - __syncthreads(); + item_ct1.barrier(sycl::access::fence_space::local_space); int ticktock = 0; - int idx = 0 + threadIdx.x; + int idx = 0 + item_ct1.get_local_id(2); int loaded_values = 0; // prefetch if(idx < K && warp_id < (WARPS-1)) @@ -4347,7 +4202,7 @@ template void kgemm_4bit_inference(int M, int N, int K if(loaded_values == 0) { local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; + local_A[1] = A[idx+item_ct1.get_local_range(2)-32]; #pragma unroll 32 for(int col = 0; col < 32; col++) @@ -4401,9 +4256,9 @@ template void kgemm_4bit_inference(int M, int N, int K //printf("aa %i %i\n", idx, loaded_values); //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = item_ct1.get_local_range(2)-32; base_idx < K; base_idx+=item_ct1.get_local_range(2)-32) { - idx = base_idx + threadIdx.x; + idx = base_idx + item_ct1.get_local_id(2); //if(threadIdx.x == 0) //printf("%i %i\n", idx, loaded_values); @@ -4413,7 +4268,7 @@ template void kgemm_4bit_inference(int M, int N, int K if(loaded_values == 0) { local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; + local_A[1] = A[idx+item_ct1.get_local_range(2)-32]; #pragma unroll 32 for(int col = 0; col < 32; col++) @@ -4430,7 +4285,10 @@ template void kgemm_4bit_inference(int M, int N, int K loaded_values--; int absidx = (idx + col_offset)/blocksize; - half local_absmax = __ldg(&(absmax[absidx])); + /* + DPCT1098:222: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues. 
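Because dpct turns the CUDA __shared__ arrays into the smem_A/smem_B/smem_C pointer parameters seen above, each caller now has to allocate the backing local memory itself. A hedged sketch of the launch side, assuming sycl::local_accessor allocations (the sizes here are illustrative, not taken from the patch):

    #include <sycl/sycl.hpp>

    template <typename T>
    void launch_with_smem(sycl::queue& q, sycl::nd_range<3> range)
    {
        q.submit([&](sycl::handler& cgh) {
            sycl::local_accessor<T, 1> smem_A(sycl::range<1>(8 * 16), cgh);
            sycl::local_accessor<T, 1> smem_B(sycl::range<1>(16 * 32), cgh);
            cgh.parallel_for(range, [=](sycl::nd_item<3> it) {
                T* a = &smem_A[0];  // same pointers CUDA code got from __shared__ arrays
                T* b = &smem_B[0];
                // ... kernel body uses a/b where CUDA used the __shared__ names ...
                (void)a; (void)b; (void)it;
            });
        });
    }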
+ */ + sycl::half local_absmax = absmax[absidx]; #pragma unroll 64 for(int col = 0; col < 64; col+=2) @@ -4472,13 +4330,22 @@ template void kgemm_4bit_inference(int M, int N, int K if(warp_id == (WARPS-1)) for(int k = 0; k < batch_size_warps; k++) { + /* + DPCT1007:42: Migration of nvcuda::wmma::load_matrix_sync is not supported. + */ wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + /* + DPCT1007:43: Migration of nvcuda::wmma::load_matrix_sync is not supported. + */ wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + /* + DPCT1007:44: Migration of nvcuda::wmma::mma_sync is not supported. + */ wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } } - __syncthreads(); + item_ct1.barrier(sycl::access::fence_space::local_space); //if(threadIdx.x == 0) //{ // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); @@ -4486,20 +4353,32 @@ template void kgemm_4bit_inference(int M, int N, int K //} if(warp_id != (WARPS-1)){ return; } // only warp_id == (WARPS-1) from here - int warp_lane = threadIdx.x % 32; + int warp_lane = item_ct1.get_local_id(2) % 32; ticktock = ticktock == 0 ? 1 : 0; for(int k = 0; k < batch_size_warps; k++) { //if(warp_lane == 0) //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + /* + DPCT1007:45: Migration of nvcuda::wmma::load_matrix_sync is not supported. + */ wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + /* + DPCT1007:46: Migration of nvcuda::wmma::load_matrix_sync is not supported. + */ wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + /* + DPCT1007:47: Migration of nvcuda::wmma::mma_sync is not supported. + */ wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } // 129 mu if(warp_id == (WARPS-1)) + /* + DPCT1007:48: Migration of nvcuda::wmma::store_matrix_sync is not supported. + */ wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); //printnonzero(smem_C, 32, ""); @@ -4510,39 +4389,36 @@ template void kgemm_4bit_inference(int M, int N, int K } #define num_values_4bit 32 -template -void kgemm_4bit_inference_naive(int M, int N, int K, T *__restrict__ const A, - unsigned char *B, float *absmax, - const float *datatype, T *out, int lda, int ldb, - int ldc, int blocksize, - const sycl::nd_item<3> &item_ct1, T *quant_map) +/* +DPCT1110:49: The total declared local variable size in device function kgemm_4bit_inference_naive exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
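The replaced __ldg becomes a plain global load of the blockwise scale: element i of a blockwise-quantized tensor shares absmax[i / blocksize] with its blocksize-1 neighbors (this kernel offsets the index by col_offset before dividing). The general lookup, as a one-liner:

    // code: 256-entry dequant table; q: quantized bytes; absmax: one scale per block
    float dequant_blockwise_elem(const float* code, const unsigned char* q,
                                 const float* absmax, int i, int blocksize)
    {
        return code[q[i]] * absmax[i / blocksize];
    }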
+*/ +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + T *quant_map) { - // per threadblock: + // per threadblock: // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1x32 * 32x4 -> 1x4 outputs per thread block + + const int warp_idx = item_ct1.get_local_id(2) / 32; const int warp_lane = item_ct1.get_local_id(2) % 32; - const int row_B = (THREADS / 32) * item_ct1.get_group(2) + warp_idx; + const int row_B = (THREADS/32)*item_ct1.get_group(2) + warp_idx; const int num_values_8bit = num_values_4bit/2; float local_C = 0.0f; unsigned char local_B_4bit[num_values_8bit]; - T local_B[num_values_4bit]; - T local_A[num_values_4bit]; - - T local_absmax = T(0.0f); + T local_B[num_values_4bit/4]; + T local_A[num_values_4bit/4]; + + T local_absmax = T(0.0f); - for (int i = item_ct1.get_local_id(2); i < 16; i++) + for(int i = item_ct1.get_local_id(2); i < 16; i++) quant_map[i] = T(datatype[i]); - /* - DPCT1065:88: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. - */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); // A: [1, K] // B: [N, K] @@ -4551,20 +4427,17 @@ void kgemm_4bit_inference_naive(int M, int N, int K, T *__restrict__ const A, int inner_idx_halved = inner_idx/2; int offset_B = ldb*row_B; int absidx = ((2*offset_B)+inner_idx)/blocksize; - /* - DPCT1026:159: The call to __ldg was removed because there is no - corresponding API in SYCL. - */ - local_absmax = absmax[absidx]; + /* + DPCT1098:223: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues. 
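kgemm_4bit_inference_naive produces one output per row of B: a 1 x K activation row dotted with a 4-bit-quantized K-vector, 32 values at a time per warp. A host-side reference under the assumptions that K is even, B is row-major with two packed values per byte, and datatype is the 16-entry codebook passed to the kernel:

    // A: 1 x K floats; B: N rows of K/2 packed bytes; out: 1 x N
    void gemv_4bit_ref(int N, int K, const float* A, const unsigned char* B,
                       const float* absmax, const float* datatype, int blocksize,
                       float* out)
    {
        for (int n = 0; n < N; n++)
        {
            float acc = 0.0f; // accumulate in float, as the kernel does
            for (int k = 0; k < K; k += 2)
            {
                unsigned char byte = B[(n * K + k) / 2];
                float scale = absmax[(n * K + k) / blocksize]; // blockwise absmax
                acc += A[k]     * datatype[byte >> 4]   * scale;
                acc += A[k + 1] * datatype[byte & 0x0F] * scale;
            }
            out[n] = acc;
        }
    }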
+ */ + local_absmax = absmax[absidx]; if(row_B < M) { if((inner_idx_halved + num_values_8bit) < (K/2)) { // this is the most important for performance considerations - reinterpret_cast(local_B_4bit)[0] = - reinterpret_cast( - B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; } else { @@ -4583,90 +4456,59 @@ void kgemm_4bit_inference_naive(int M, int N, int K, T *__restrict__ const A, local_B_4bit[j] = 0b01110111; } - #pragma unroll - for(int k = 0; k < num_values_8bit; k++) - { -#if DPCT_COMPATIBILITY_TEMP >= 800 - local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; - local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax; - #else - // bf16 multipliation not supported - local_B[k*2] = T((float)quant_map[local_B_4bit[k] >> 4]*(float)local_absmax); - local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[k] & 0x0F]*(float)local_absmax); - #endif - } - - if(inner_idx+num_values_4bit < K) + for(int i = 0; i < 4; i++) { - // this is also relatively important for performance - if(BITS==16) - { - reinterpret_cast(local_A)[0] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 4) + 0]; - reinterpret_cast(local_A)[1] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 4) + 1]; - reinterpret_cast(local_A)[2] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 4) + 2]; - reinterpret_cast(local_A)[3] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 4) + 3]; - } - else + #pragma unroll + for(int k = 0; k < num_values_8bit/4; k++) { - reinterpret_cast(local_A)[0] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 8) + 0]; - reinterpret_cast(local_A)[1] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 8) + 1]; - reinterpret_cast(local_A)[2] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 8) + 2]; - reinterpret_cast(local_A)[3] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 8) + 3]; - reinterpret_cast(local_A)[4] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 8) + 4]; - reinterpret_cast(local_A)[5] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 8) + 5]; - reinterpret_cast(local_A)[6] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 8) + 6]; - reinterpret_cast(local_A)[7] = - reinterpret_cast( - A)[inner_idx / (num_values_4bit / 8) + 7]; + #if DPCT_COMPATIBILITY_TEMP >= 800 + local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; + local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; + #else + // bf16 multipliation not supported + local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); + local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); + #endif } - } - else - #pragma unroll - for(int k = 0; k < num_values_4bit; k++) - if(inner_idx + k < K) - local_A[k] = A[inner_idx + k]; + if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) + { + // this is also relatively important for performance + if(BITS==16) + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; + } else - local_A[k] = T(0.0f); + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; + } + + } + else + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + 
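The 0b01110111 padding byte used for the ragged tail above puts code 7 in both nibbles; in the NF4 table shown earlier, entry 7 is exactly 0.0, so padded lanes dequantize to zero and contribute nothing to the dot product (presumably the same neutral-padding intent holds for the other 4-bit codebooks). A quick check:

    #include <cassert>

    int main()
    {
        const unsigned char pad = 0b01110111;
        assert((pad >> 4) == 7 && (pad & 0x0F) == 7);
        // nf4[7] == 0.0f, so padded nibbles add 0 to the accumulator
        return 0;
    }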
if(inner_idx + (i*num_values_4bit/4) + k < K) + local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; + else + local_A[k] = T(0.0f); - // accumulate in float; small performance hit for Ampere, but lower error for outputs - #pragma unroll - for(int k = 0; k < num_values_4bit; k++) - { -#if DPCT_COMPATIBILITY_TEMP >= 800 - local_C += (float)(local_A[k]*local_B[k]); - #else - // bf16 multipliation not supported - local_C += ((float)local_A[k]*(float)local_B[k]); - #endif + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + { + #if DPCT_COMPATIBILITY_TEMP >= 800 + local_C += (float)(local_A[k]*local_B[k]); + #else + // bf16 multipliation not supported + local_C += ((float)local_A[k]*(float)local_B[k]); + #endif + } } } - local_C = sycl::reduce_over_group(item_ct1.get_sub_group(), local_C, - sycl::plus<>()); + local_C = sycl::reduce_over_group(item_ct1.get_sub_group(), local_C, sycl::plus<>()); if(row_B < M && warp_lane == 0) out[row_B] = T(local_C); @@ -4785,16 +4627,14 @@ void kgemm_4bit_inference_naive(int M, int N, int K, T *__restrict__ const A, //} -template void kfunc(T *A, T *B, T value, long n, +template SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, const sycl::nd_item<3> &item_ct1) { - for (long i = (item_ct1.get_local_range(2) * item_ct1.get_group(2)) + - item_ct1.get_local_id(2); - i < n; i += (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))) + for(long i = (item_ct1.get_local_range(2)*item_ct1.get_group(2)) + item_ct1.get_local_id(2); i < n; i+=(item_ct1.get_local_range(2)*item_ct1.get_group_range(2))) { switch(FUNC) { - case FILL: + case FILL: A[i] = (T)value; break; case ARANGE: @@ -4812,266 +4652,210 @@ template void kfunc(T *A, T *B, T value, long n, // TEMPLATE DEFINITIONS //============================================================== -template void kfunc(float *A, float *B, float value, long n, +template SYCL_EXTERNAL void kfunc(float *A, float *B, float value, long n, const sycl::nd_item<3> &item_ct1); -template void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n, +template SYCL_EXTERNAL void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n, const sycl::nd_item<3> &item_ct1); -template void kfunc(float *A, float *B, float value, long n, +template SYCL_EXTERNAL void kfunc(float *A, float *B, float value, long n, const sycl::nd_item<3> &item_ct1); -template void kfunc(float *A, float *B, float value, long n, +template SYCL_EXTERNAL void kfunc(float *A, float *B, float value, long n, const sycl::nd_item<3> &item_ct1); // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, 
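After the per-lane loops, the kernel combines the 32 partial sums with a sub-group reduction and only lane 0 writes the result (warp_lane == 0). A standalone SYCL kernel showing the same reduce_over_group idiom:

    #include <sycl/sycl.hpp>

    // data: device USM with one partial sum per work-item; data[0] receives the total
    void subgroup_sum(sycl::queue& q, float* data)
    {
        q.parallel_for(
            sycl::nd_range<1>(sycl::range<1>(32), sycl::range<1>(32)),
            [=](sycl::nd_item<1> it) {
                float partial = data[it.get_global_id(0)];
                float total = sycl::reduce_over_group(it.get_sub_group(), partial,
                                                      sycl::plus<>());
                if (it.get_sub_group().get_local_linear_id() == 0)
                    data[0] = total; // lane 0 publishes, like warp_lane == 0 above
            });
    }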
sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int 
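The long instantiation lists here (and the "the compiler needs them" comment above) exist because the kernel templates are defined only in this translation unit: every <T, THREADS, BITS> combination the host code can launch must be explicitly instantiated, or the device symbol would be missing at link time. The pattern in miniature, with hypothetical names:

    // explicit instantiation definitions force the compiler to emit code for
    // these specializations in this translation unit so other files can link
    template <typename T, int THREADS>
    void kernel_like(T* out, int n) { for (int i = 0; i < n; i++) out[i] = T(THREADS); }

    template void kernel_like<float, 96>(float*, int);
    template void kernel_like<float, 128>(float*, int);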
ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); -template void gemm_device(int M, int N, int K, - sycl::half *__restrict__ const A, - sycl::half *B, sycl::half *out, - int lda, int ldb, int ldc); - -template void kgemm_4bit_inference( - int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, - float *absmax, sycl::half *out, int lda, int ldb, int ldc, int blocksize); -template void kgemm_4bit_inference( - int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, - float *absmax, sycl::half *out, int lda, int ldb, int ldc, int blocksize); -template void kgemm_4bit_inference( - int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, - float *absmax, sycl::half *out, int lda, int ldb, int ldc, int blocksize); -template void kgemm_4bit_inference( - int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, - float *absmax, sycl::half *out, int lda, int ldb, int ldc, int blocksize); - -template void kgemm_4bit_inference_naive( - int M, int N, int K, sycl::half *__restrict__ const A, unsigned char *B, - float *absmax, const float *datatype, sycl::half *out, int lda, int ldb, - int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, - sycl::half *quant_map); -template void kgemm_4bit_inference_naive( - int M, int N, int K, oneapi::mkl::bfloat16 *__restrict__ const A, - unsigned char *B, float *absmax, const float *datatype, - oneapi::mkl::bfloat16 *out, int lda, int ldb, int ldc, int blocksize, - const sycl::nd_item<3> &item_ct1, oneapi::mkl::bfloat16 *quant_map); -template void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, sycl::half *smem_B); + +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, + sycl::half *smem_B, + sycl::half *smem_C); +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half 
* __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, + sycl::half *smem_B, + sycl::half *smem_C); +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, + sycl::half *smem_B, + sycl::half *smem_C); +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, + sycl::half *smem_B, + sycl::half *smem_C); + +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, sycl::half * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::half *quant_map); +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, sycl::ext::oneapi::bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, + const sycl::nd_item<3> &item_ct1, + sycl::ext::oneapi::bfloat16 *quant_map); +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, float *quant_map); -template void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1); -template void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1); -template void kspmm_coo_very_sparse_naive( - int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, - sycl::half *values, sycl::half *B, sycl::half *out, float *dequant_stats, - int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); -template void kspmm_coo_very_sparse_naive( - int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, - sycl::half *values, sycl::half *B, sycl::half *out, float *dequant_stats, - int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); -template void kspmm_coo_very_sparse_naive( - int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, - sycl::half *values, sycl::half *B, sycl::half *out, float *dequant_stats, - int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); -template void kspmm_coo_very_sparse_naive( - int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, - sycl::half *values, signed char *B, sycl::half *out, float *dequant_stats, - int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> 
&item_ct1, - sycl::half *smem_dequant_stats); -template void kspmm_coo_very_sparse_naive( - int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, - sycl::half *values, signed char *B, sycl::half *out, float *dequant_stats, - int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); -template void kspmm_coo_very_sparse_naive( - int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, - sycl::half *values, signed char *B, sycl::half *out, float *dequant_stats, - int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); - -template void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); +template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, + const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data); -template void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data); -template void kTransformRowToFormat<256, 8, 
32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data); -template void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data); -template void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data); -template void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data); -template void kdequant_mm_int32_fp16<4, 128, 512>( - int *__restrict__ const A, float *__restrict__ const rowStats, - float *__restrict__ const colStats, sycl::half *out, float *newRowStats, - float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, - const int numCols, const int tileCols, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_rowStats); - -template void kDoubleRowColQuant<64, 4, 16, 64 * 4, 0>( - sycl::half *__restrict__ const A, float *__restrict__ const rowStats, - float *__restrict__ const colStats, char *out_col_normed, - char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, - int *__restrict__ nnz_block_ptr, float threshold, int rows, int cols, - int tiledCols, const sycl::nd_item<3> &item_ct1, - float *smem_row_stats, - unsigned int *smem_nnz_row_idx); -template void kDoubleRowColQuant<64, 4, 16, 64 * 4, 1>( - sycl::half *__restrict__ const A, float *__restrict__ const rowStats, - float *__restrict__ const colStats, char *out_col_normed, - char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, - int *__restrict__ nnz_block_ptr, float threshold, int rows, int cols, - int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, - unsigned int *smem_nnz_row_idx); +template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, + const sycl::nd_item<3> &item_ct1, + float *smem_rowStats); + +template void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, + 
const sycl::nd_item<3> &item_ct1, + float *smem_row_stats, + unsigned int *smem_nnz_row_idx); +template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, + const sycl::nd_item<3> &item_ct1, + float *smem_row_stats, + unsigned int *smem_nnz_row_idx); template unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template unsigned char dQuantize<1>(float* smem_code, const float rand, float x); -template void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, +template SYCL_EXTERNAL void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); -template void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, - const float offset, const sycl::half max_val, - const int n, const sycl::nd_item<3> &item_ct1, +template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); - -MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) -MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) -MAKE_PreconditionOptimizer32bit1State(RMSPROP, sycl::half) -MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) -MAKE_PreconditionOptimizer32bit1State(LION, sycl::half) -MAKE_PreconditionOptimizer32bit1State(LION, float) -MAKE_PreconditionOptimizer32bit1State(LION, oneapi::mkl::bfloat16) -MAKE_PreconditionOptimizer32bit1State(ADAGRAD, sycl::half) -MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); \ + +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(RMSPROP, sycl::half) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(LION, sycl::half) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(LION, float) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(LION, sycl::ext::oneapi::bfloat16) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(ADAGRAD, sycl::half) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); - 
-MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) -MAKE_Optimizer32bit1State(MOMENTUM, float) -MAKE_Optimizer32bit1State(RMSPROP, sycl::half) -MAKE_Optimizer32bit1State(RMSPROP, float) -MAKE_Optimizer32bit1State(LION, sycl::half) -MAKE_Optimizer32bit1State(LION, float) -MAKE_Optimizer32bit1State(LION, oneapi::mkl::bfloat16) -MAKE_Optimizer32bit1State(ADAGRAD, sycl::half) -MAKE_Optimizer32bit1State(ADAGRAD, float) + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); \ + +SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) +SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, float) +SYCL_EXTERNAL MAKE_Optimizer32bit1State(RMSPROP, sycl::half) +SYCL_EXTERNAL MAKE_Optimizer32bit1State(RMSPROP, float) +SYCL_EXTERNAL MAKE_Optimizer32bit1State(LION, sycl::half) +SYCL_EXTERNAL MAKE_Optimizer32bit1State(LION, float) +SYCL_EXTERNAL MAKE_Optimizer32bit1State(LION, sycl::ext::oneapi::bfloat16) +SYCL_EXTERNAL MAKE_Optimizer32bit1State(ADAGRAD, sycl::half) +SYCL_EXTERNAL MAKE_Optimizer32bit1State(ADAGRAD, float) #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ template void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ float* state1, float* state2, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); - -MAKE_PreconditionOptimizer32bit2State(ADAM, float) -MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half) -MAKE_PreconditionOptimizer32bit2State(ADAM, oneapi::mkl::bfloat16) - -template void kOptimizer32bit2State( - float *g, float *p, - float *state1, - float *state2, float *unorm, - const float max_unorm, - const float param_norm, - const float beta1, - const float beta2, - const float eps, - const float weight_decay, - const int step, - const float lr, - const float gnorm_scale, - const bool skip_zeros, - const int n, - const sycl::nd_item<3> - &item_ct1, - uint8_t *temp_storage_ct1); -template void kOptimizer32bit2State( - sycl::half *g, sycl::half *p, float *state1, float *state2, float *unorm, - const float max_unorm, const float param_norm, const float beta1, - const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, - const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, - uint8_t *temp_storage_ct1); -template void kOptimizer32bit2State( - oneapi::mkl::bfloat16 *g, oneapi::mkl::bfloat16 *p, float *state1, - float *state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, - const float weight_decay, const int step, const float lr, - const float gnorm_scale, const bool skip_zeros, const int n, + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); \ + +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, float) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half) +SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::ext::oneapi::bfloat16) + +template SYCL_EXTERNAL void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float 
eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); +template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::half* g, sycl::half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, + const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); +template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::ext::oneapi::bfloat16* g, sycl::ext::oneapi::bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ @@ -5084,14 +4868,14 @@ template void kPreconditionOptimizerStatic8bit1State(gtype* p, gty float* max1, float* new_max1, \ const float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, float *smem_quantiles1); + const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, float *smem_quantiles1); \ -MAKE_PreconditionStatic8bit1State(MOMENTUM, sycl::half) -MAKE_PreconditionStatic8bit1State(MOMENTUM, float) -MAKE_PreconditionStatic8bit1State(RMSPROP, sycl::half) -MAKE_PreconditionStatic8bit1State(RMSPROP, float) -MAKE_PreconditionStatic8bit1State( LION, sycl::half) -MAKE_PreconditionStatic8bit1State(LION, float) +SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(MOMENTUM, sycl::half) +SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(RMSPROP, sycl::half) +SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(RMSPROP, float) +SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(LION, sycl::half) +SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(LION, float) #define MAKE_optimizerStatic8bit1State(oname, gtype) \ template void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ @@ -5103,14 +4887,14 @@ template void kOptimizerStatic8bit1State(gtype* p, gtype* const g, float* max1, float* new_max1, \ float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, uint8_t *temp_storage_ct1); + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, uint8_t *temp_storage_ct1); \ -MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half) -MAKE_optimizerStatic8bit1State(MOMENTUM, float) -MAKE_optimizerStatic8bit1State(RMSPROP, sycl::half) -MAKE_optimizerStatic8bit1State(RMSPROP, float) -MAKE_optimizerStatic8bit1State(LION, sycl::half) -MAKE_optimizerStatic8bit1State(LION, float) +SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half) +SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, float) +SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(RMSPROP, sycl::half) +SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(RMSPROP, float) +SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(LION, sycl::half) +SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(LION, float) #define MAKE_PreconditionStatic8bit2State(oname, gtype) \ template void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, 
@@ -5120,10 +4904,10 @@ template void kPreconditionOptimizerStatic8bit2State(gtype* p, gty
 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
 float* max1, float* max2, float* new_max1, float* new_max2, \
 const float gnorm_scale, \
- const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, float *smem_quantiles1, float *smem_quantiles2);
+ const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, float *smem_quantiles1, float *smem_quantiles2); \

-MAKE_PreconditionStatic8bit2State(ADAM, sycl::half)
-MAKE_PreconditionStatic8bit2State(ADAM, float)
+SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, sycl::half)
+SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, float)

#define MAKE_optimizerStatic8bit2State(oname, gtype) \
template void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \
@@ -5134,122 +4918,105 @@ template void kOptimizerStatic8bit2State(gtype* p, gtype* const g,
 float* max1, float* max2, float* new_max1, float* new_max2, \
 float weight_decay, \
 const float gnorm_scale, \
- const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, uint8_t *temp_storage_ct1);
+ const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, uint8_t *temp_storage_ct1); \

-MAKE_optimizerStatic8bit2State(ADAM, sycl::half)
-MAKE_optimizerStatic8bit2State(ADAM, float)
+SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, sycl::half)
+SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, float)

-template void kPercentileClipping<float, 2048, 4>(float *__restrict__ g, float *gnorm_vec, int step, const int n, const sycl::nd_item<3> &item_ct1);
-template void kPercentileClipping<sycl::half, 2048, 4>(sycl::half *__restrict__ g, float *gnorm_vec, int step, const int n, const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n,
+ const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kPercentileClipping<sycl::half, 2048, 4>(sycl::half * __restrict__ g, float *gnorm_vec, int step, const int n,
+ const sycl::nd_item<3> &item_ct1);

#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
-template void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value);
-
-MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit)
-MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
-MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 1, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 2048, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 1024, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 512, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 256, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 128, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 64, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 2048, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 1024, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 512, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 256, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 128, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 64, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 1, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 2048, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 1024, 4, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 512, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 256, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 128, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 64, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 2048, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 1024, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 512, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 256, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 128, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 64, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 4096, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 2048, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 1024, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 512, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 256, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 128, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(oneapi::mkl::bfloat16, 64, 2, 0, NF4)
-
-template void kDequantizeBlockwise<sycl::half, 512, 64, 8, General8bit>(float *code, unsigned char *A, float *absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1);
-template void kDequantizeBlockwise<sycl::half, 512, 64, 8, FP4>(float *code, unsigned char *A, float *absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1);
-template void kDequantizeBlockwise<sycl::half, 512, 64, 8, NF4>(float *code, unsigned char *A, float *absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1);
-template void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1);
+template void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value); \
+
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 1, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, General8bit)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, FP4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, NF4)
+SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, NF4)
+
+template SYCL_EXTERNAL void kDequantizeBlockwise<sycl::half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n,
+ const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kDequantizeBlockwise<sycl::half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n,
+ const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kDequantizeBlockwise<sycl::half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n,
+ const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n,
+ const sycl::nd_item<3> &item_ct1);
-template void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n,
+ const sycl::nd_item<3> &item_ct1);
-template void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n,
+ const sycl::nd_item<3> &item_ct1);
-template void kDequantizeBlockwise<oneapi::mkl::bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char *A, float *absmax, oneapi::mkl::bfloat16 *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1);
-template void kDequantizeBlockwise<oneapi::mkl::bfloat16, 512, 64, 8, FP4>(float *code, unsigned char *A, float *absmax, oneapi::mkl::bfloat16 *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1);
-template void kDequantizeBlockwise<oneapi::mkl::bfloat16, 512, 64, 8, NF4>(float *code, unsigned char *A, float *absmax, oneapi::mkl::bfloat16 *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kDequantizeBlockwise<sycl::ext::oneapi::bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,
+ const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kDequantizeBlockwise<sycl::ext::oneapi::bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,
+ const sycl::nd_item<3> &item_ct1);
+template SYCL_EXTERNAL void kDequantizeBlockwise<sycl::ext::oneapi::bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,
+ const sycl::nd_item<3> &item_ct1);

#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
template void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
@@ -5258,11 +5025,11 @@ template void kOptimizerStatic8bit2StateBlockwise
- const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor<float, 2> smem_quantiles1, sycl::local_accessor<float, 2> smem_quantiles2, float *smem_exchange1, float *smem_exchange2, uint8_t *temp_storage_ct1);
+ const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor<float, 2> smem_quantiles1, sycl::local_accessor<float, 2> smem_quantiles2, float *smem_exchange1, float *smem_exchange2, uint8_t *temp_storage_ct1); \

-MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, oneapi::mkl::bfloat16, 2048, 8)
+SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
+SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8)
+SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::ext::oneapi::bfloat16, 2048, 8)

#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
@@ -5273,14 +5040,14 @@ template void kOptimizerStatic8bit1StateBlockwise
- const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor<float, 2> smem_quantiles1, float *smem_exchange1, uint8_t *temp_storage_ct1);
-
-MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, sycl::half, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::half, 2048, 8)
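Each MAKE_* line in this hunk is shorthand for one explicit template instantiation, and the new form simply pastes SYCL_EXTERNAL in front of the expansion. As a reading aid, a sketch of what one new-style invocation would expand to, reconstructed from the macro tail shown in the hunk; the absmax1 and weight_decay parameters in the middle of the signature are assumptions carried over from the 2-state variant, since the hunk context elides them:

    // Assumed expansion of:
    //   SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
    SYCL_EXTERNAL template void kOptimizerStatic8bit1StateBlockwise<float, MOMENTUM, 2048, 8>(
        float* p, float* __restrict__ const g, unsigned char* state1,
        const float beta1, const float beta2, const float eps, const int step, const float lr,
        float* __restrict__ const quantiles1,
        float* absmax1, float weight_decay,            // assumed middle parameters
        const float gnorm_scale, const bool skip_zeros, const int n,
        const sycl::nd_item<3> &item_ct1, sycl::local_accessor<float, 2> smem_quantiles1,
        float *smem_exchange1, uint8_t *temp_storage_ct1);

Because the attribute lands in front of the template keyword, this relies on the compiler accepting SYCL_EXTERNAL in that position; a stricter alternative would be to move the attribute inside the macro body.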
-MAKE_OptimizerStatic8bit1StateBlockwise(LION, oneapi::mkl::bfloat16, 2048, 8) -MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) - MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, sycl::half, 2048, 8) + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1, uint8_t *temp_storage_ct1); \ + +SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) +SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) +SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) +SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, sycl::half, 2048, 8) +SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::half, 2048, 8) +SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::ext::oneapi::bfloat16, 2048, 8) +SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) +SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, sycl::half, 2048, 8) diff --git a/csrc/kernels.dp.hpp b/csrc/sycl/kernels.h similarity index 100% rename from csrc/kernels.dp.hpp rename to csrc/sycl/kernels.h diff --git a/csrc/ops.dp.cpp b/csrc/sycl/ops.cpp similarity index 100% rename from csrc/ops.dp.cpp rename to csrc/sycl/ops.cpp diff --git a/csrc/ops.dp.hpp b/csrc/sycl/ops.h similarity index 100% rename from csrc/ops.dp.hpp rename to csrc/sycl/ops.h From dd935ad57a0181cb8ce93f877ae540adbafaa405 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 20 Mar 2024 10:38:36 -0700 Subject: [PATCH 06/66] add quant kernels --- csrc/sycl/kernels.cpp | 549 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 492 insertions(+), 57 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 6dcefca23..fd55f0740 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -895,7 +895,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load; + using group_load = dpct::group::workgroup_load; size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_TH); sycl::local_accessor tacc( temp_storage_size, h); @@ -961,7 +961,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load; + using group_load = dpct::group::workgroup_load; size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_TH); sycl::local_accessor tacc( temp_storage_size, h); @@ -1029,7 +1029,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { - using group_store = dpct::group::workgroup_store; + using group_store = dpct::group::workgroup_store; size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_TH); sycl::local_accessor tacc( temp_storage_size, h); @@ -1109,7 +1109,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load; + using group_load = dpct::group::workgroup_load; size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_TH); sycl::local_accessor tacc( temp_storage_size, h); @@ -1176,7 
+1176,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { - using group_store = dpct::group::workgroup_store; + using group_store = dpct::group::workgroup_store; size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_TH); sycl::local_accessor tacc( temp_storage_size, h); @@ -1230,7 +1230,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1) + const sycl::nd_item<3> &item_ct1) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); @@ -1244,17 +1244,23 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, const float correction1 = 1.0f/(1.0f - dpct::pow(beta1, step)); const float correction2 = 1.0f/(1.0f - dpct::pow(beta2, step)); + + sycl::buffer buff_g(g, sycl::range<1>(NUM_VALS)); + sycl::buffer buff_state1(state1,sycl::range<1>(NUM_VALS)); + sycl::buffer buff_state2(state2,sycl::range<1>(NUM_VALS)); + - typedef cub::BlockLoad Load; - typedef cub::BlockLoad LoadFloat; - typedef sycl::group<3> BlockReduce; - + //typedef cub::BlockLoad Load; + //typedef cub::BlockLoad LoadFloat; + //typedef sycl::group<3> BlockReduce; + /* union type_ct2{ typename Load::TempStorage load; typename LoadFloat::TempStorage loadf; typename BlockReduce::TempStorage reduce; }; - type_ct2 &temp_storage = *(type_ct2 *)temp_storage_ct1; + */ + //type_ct2 &temp_storage = *(type_ct2 *)temp_storage_ct1; for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { @@ -1263,28 +1269,103 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, /* DPCT1065:97: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:101: Migration of cub::BlockLoad::Load is not supported. */ - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_g[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); + }); + + }); + + /* DPCT1065:98: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
*/ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:102: Migration of cub::BlockLoad::Load is not supported. */ - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, s1_vals); + }); + + }); + + + /* DPCT1065:99: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:103: Migration of cub::BlockLoad::Load is not supported. */ - LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); - + //LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state2[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, s2_vals); + }); + + }); + # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) g_vals[j] = gnorm_scale*((float)g_vals[j]); @@ -1312,7 +1393,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, /* DPCT1065:100: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
*/ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>()); if(item_ct1.get_local_id(2) == 0) @@ -1332,7 +1413,7 @@ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1) + const sycl::nd_item<3> &item_ct1) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -1357,11 +1438,10 @@ void kOptimizer32bit2State(T* g, T* p, } else{ update_scale = 1.0f; } - sycl::buffer buff_smem_code(smem_code,sycl::range<1>(257)); - sycl::buffer buff_vals(vals, sycl::range<1>(NUM_PER_TH)); - sycl::buffer buff_A(A,sycl::range<1>(NUM_PER_TH)); - sycl::buffer buff_out(out,sycl::range<1>(NUM_PER_TH)); - sycl::buffer buff_rand(rand,sycl::range<1>(NUM_PER_TH)); + sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD)); + sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD)); + sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD)); + sycl::buffer buff_state2(state2,sycl::range<1>(NUM_PER_THREAD)); //typedef cub::BlockLoad Load; @@ -1369,14 +1449,15 @@ void kOptimizer32bit2State(T* g, T* p, //typedef cub::BlockLoad LoadFloat; //typedef cub::BlockStore StoreFloat; - + /* union type_ct3{ typename Load::TempStorage load; typename Store::TempStorage store; typename LoadFloat::TempStorage loadf; typename StoreFloat::TempStorage storef; }; - type_ct3 &temp_storage = *(type_ct3 *)temp_storage_ct1; + */ + //type_ct3 &temp_storage = *(type_ct3 *)temp_storage_ct1; for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*TH*NUM_PER_THREAD) { @@ -1389,7 +1470,32 @@ void kOptimizer32bit2State(T* g, T* p, /* DPCT1007:111: Migration of cub::BlockLoad::Load is not supported. */ - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_g[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); + }); + + }); + + /* DPCT1065:105: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1397,7 +1503,32 @@ void kOptimizer32bit2State(T* g, T* p, /* DPCT1007:112: Migration of cub::BlockLoad::Load is not supported. 
*/ - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, s1_vals); + }); + + }); + + /* DPCT1065:106: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1405,7 +1536,32 @@ void kOptimizer32bit2State(T* g, T* p, /* DPCT1007:113: Migration of cub::BlockLoad::Load is not supported. */ - LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + //LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state2[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, s2_vals); + }); + + }); + /* DPCT1065:107: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1413,7 +1569,32 @@ void kOptimizer32bit2State(T* g, T* p, /* DPCT1007:114: Migration of cub::BlockLoad::Load is not supported. */ - Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + //Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); + }); + + }); + # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) @@ -1445,7 +1626,31 @@ void kOptimizer32bit2State(T* g, T* p, /* DPCT1007:115: Migration of cub::BlockStore::Store is not supported. */ - Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + //Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); + }); + + }); + /* DPCT1065:109: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1453,7 +1658,33 @@ void kOptimizer32bit2State(T* g, T* p, /* DPCT1007:116: Migration of cub::BlockStore::Store is not supported. */ - StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + //StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, s1_vals); + }); + + }); + + + /* DPCT1065:110: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1461,7 +1692,31 @@ void kOptimizer32bit2State(T* g, T* p, /* DPCT1007:117: Migration of cub::BlockStore::Store is not supported. 
*/ - StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + //StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state2[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, s2_vals); + }); + + }); + } } @@ -1471,7 +1726,7 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1) + const sycl::nd_item<3> &item_ct1) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); @@ -1482,17 +1737,21 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, float s1_vals[NUM_VALS]; - typedef cub::BlockLoad Load; - typedef cub::BlockLoad LoadFloat; - typedef sycl::group<3> BlockReduce; - + //typedef cub::BlockLoad Load; + //typedef cub::BlockLoad LoadFloat; + //typedef sycl::group<3> BlockReduce; + sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD)); + sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD)); + sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD)); + sycl::buffer buff_state2(state2,sycl::range<1>(NUM_PER_THREAD)); + /* union type_ct4{ typename Load::TempStorage load; typename LoadFloat::TempStorage loadf; typename BlockReduce::TempStorage reduce; }; type_ct4 &temp_storage = *(type_ct4 *)temp_storage_ct1; - + */ for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; @@ -1504,7 +1763,32 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, /* DPCT1007:121: Migration of cub::BlockLoad::Load is not supported. */ - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_g[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); + }); + + }); + + /* DPCT1065:119: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1512,7 +1796,31 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, /* DPCT1007:122: Migration of cub::BlockLoad::Load is not supported. */ - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, s1_vals); + }); + + }); + # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) @@ -1557,8 +1865,9 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, /* DPCT1007:2: Migration of cub::Sum is not supported. */ - s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); - + //s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>()); + if(item_ct1.get_local_id(2) == 0) dpct::atomic_fetch_add(&unorm[0], s1_vals[0]); @@ -1572,7 +1881,7 @@ void kOptimizer32bit1State(T *g, T *p, float *state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1) + const sycl::nd_item<3> &item_ct1) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 
0 : (TH*NUM_PER_THREAD)); @@ -1593,12 +1902,17 @@ void kOptimizer32bit1State(T *g, T *p, float s1_vals[NUM_PER_THREAD]; - typedef cub::BlockLoad Load; - typedef cub::BlockStore Store; + //typedef cub::BlockLoad Load; + //typedef cub::BlockStore Store; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockStore StoreFloat; + //typedef cub::BlockLoad LoadFloat; + //typedef cub::BlockStore StoreFloat; + sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD)); + sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD)); + sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD)); + sycl::buffer buff_state2(state2,sycl::range<1>(NUM_PER_THREAD)); + /* union type_ct5{ typename Load::TempStorage load; typename Store::TempStorage store; @@ -1606,7 +1920,7 @@ void kOptimizer32bit1State(T *g, T *p, typename StoreFloat::TempStorage storef; }; type_ct5 &temp_storage = *(type_ct5 *)temp_storage_ct1; - + */ for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*TH*NUM_PER_THREAD) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; @@ -1618,7 +1932,32 @@ void kOptimizer32bit1State(T *g, T *p, /* DPCT1007:128: Migration of cub::BlockLoad::Load is not supported. */ - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_g[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); + }); + + }); + + /* DPCT1065:124: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1626,7 +1965,32 @@ void kOptimizer32bit1State(T *g, T *p, /* DPCT1007:129: Migration of cub::BlockLoad::Load is not supported. */ - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, s1_vals); + }); + + }); + + /* DPCT1065:125: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1634,8 +1998,30 @@ void kOptimizer32bit1State(T *g, T *p, /* DPCT1007:130: Migration of cub::BlockLoad::Load is not supported. */ - Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - + //Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); + }); + + }); # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) { @@ -1682,7 +2068,31 @@ void kOptimizer32bit1State(T *g, T *p, /* DPCT1007:131: Migration of cub::BlockStore::Store is not supported. */ - Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + //Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); + }); + + }); + /* DPCT1065:127: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1690,7 +2100,31 @@ void kOptimizer32bit1State(T *g, T *p, /* DPCT1007:132: Migration of cub::BlockStore::Store is not supported. 
*/ - StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + //StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, s1_vals); + }); + + }); + } } @@ -1712,7 +2146,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; @@ -5051,3 +5485,4 @@ SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::half, 2048, 8) SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::ext::oneapi::bfloat16, 2048, 8) SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, sycl::half, 2048, 8) + From d93fc2ed9fc744cd2ea6a0a5c3653c1267912c00 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 20 Mar 2024 22:56:09 -0700 Subject: [PATCH 07/66] add precondition optimizer kernel sycl --- csrc/sycl/kernels.cpp | 308 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 278 insertions(+), 30 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index fd55f0740..e19295e85 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -570,7 +570,7 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons [=](sycl::nd_item<3> item) { auto *d = dacc.get_multi_ptr().get(); auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, values); + group_load(tmp).load(item_ct1,item_ct1.get_local_linear_id(), d, values); }); @@ -2163,18 +2163,22 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c unsigned char m_c1[NUM8BIT]; unsigned char r_c2[NUM8BIT]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadUInt8; - typedef sycl::group<3> BlockReduce; - - + //typedef cub::BlockLoad LoadT; + //typedef cub::BlockLoad LoadUInt8; + //typedef sycl::group<3> BlockReduce; + sycl::buffer buff_g(g, sycl::range<1>(NUM_THREADS)); + sycl::buffer buff_p(p, sycl::range<1>(NUM_THREADS)); + sycl::buffer buff_state1(state1,sycl::range<1>(NUM_THREADS)); + sycl::buffer buff_state2(state2,sycl::range<1>(NUM_THREADS)); + + /* union type_ct6{ typename LoadT::TempStorage loadh; typename LoadUInt8::TempStorage loadc; typename BlockReduce::TempStorage reduce; }; type_ct6 &temp_storage = *(type_ct6 
*)temp_storage_ct1; - + */ @@ -2196,7 +2200,30 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:156: Migration of cub::BlockLoad::Load is not supported. */ - LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_g[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); + }); + + }); /* DPCT1065:153: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2204,7 +2231,32 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:157: Migration of cub::BlockLoad::Load is not supported. */ - LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + //LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, m_c1); + }); + + }); + + /* DPCT1065:154: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2212,7 +2264,31 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:158: Migration of cub::BlockLoad::Load is not supported. */ - LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + //LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state2[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. 
compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, r_c2); + }); + + }); + /* DPCT1065:155: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2260,7 +2336,8 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:7: Migration of cub::Reduce is not supported. */ - local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, sycl::maximum<>(), valid_items); + //local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, sycl::maximum<>(), valid_items); + local_max_s1 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s1, sycl::maximum<>()); /* DPCT1065:152: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2268,7 +2345,8 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:8: Migration of cub::Reduce is not supported. */ - local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, sycl::maximum<>(), valid_items); + //local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, sycl::maximum<>(), valid_items); + local_max_s2 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s2, sycl::maximum<>()); if(unorm != NULL) { /* @@ -2278,7 +2356,8 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:9: Migration of cub::Reduce is not supported. 
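The replacements above and the one just below swap cub::BlockReduce::Reduce for sycl::reduce_over_group. One semantic difference is easy to miss: the CUB call reduces only the first valid_items inputs, while reduce_over_group always combines one value from every work-item in the group, so partially filled blocks must pre-fill out-of-range lanes with the identity of the operator. A minimal sketch of the pattern (item, my_val, and in_range are illustrative names, not from this patch):

    // Block-wide max and sum without CUB; out-of-range lanes contribute the identity.
    float v_max = in_range ? my_val : -FLT_MAX;   // identity of sycl::maximum<>()
    float block_max = sycl::reduce_over_group(item.get_group(), v_max, sycl::maximum<>());
    float v_sum = in_range ? my_val : 0.0f;       // identity of sycl::plus<>()
    float block_sum = sycl::reduce_over_group(item.get_group(), v_sum, sycl::plus<>());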
*/ - local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, sycl::plus<>(), valid_items); + //local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, sycl::plus<>(), valid_items); + local_unorm = sycl::reduce_over_group(item_ct1.get_group(), local_unorm, sycl::plus<>()); } if(item_ct1.get_local_id(2) == 0) @@ -2305,7 +2384,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha float weight_decay, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, - float *smem_quantiles2, uint8_t *temp_storage_ct1) + float *smem_quantiles2) { const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2; @@ -2334,15 +2413,17 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha unsigned char c2s[NUM_PER_THREAD2]; T p_vals[NUM_PER_THREAD2]; T g_vals[NUM_PER_THREAD2]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; - - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; - + sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD2)); + sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD2)); + sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD2)); + sycl::buffer buff_state2(state2,sycl::range<1>(NUM_PER_THREAD2)); - + //typedef cub::BlockLoad LoadT; + //typedef cub::BlockLoad LoadChar; + //typedef cub::BlockStore StoreChar; + //typedef cub::BlockStore StoreT; + /* union type_ct7{ typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; @@ -2350,7 +2431,8 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha typename StoreT::TempStorage storeh; }; type_ct7 &temp_storage = *(type_ct7 *)temp_storage_ct1; - + */ + if(item_ct1.get_local_id(2) < 512) { if(item_ct1.get_local_id(2) < 256) @@ -2370,7 +2452,31 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha /* DPCT1007:167: Migration of cub::BlockLoad::Load is not supported. */ - LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_g[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); + }); + + }); + /* DPCT1065:161: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2378,7 +2484,31 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha /* DPCT1007:168: Migration of cub::BlockLoad::Load is not supported. 
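Every cub::BlockLoad::Load call in these kernels is rewritten with the same dpct::group::workgroup_load boilerplate, so it is worth spelling the shape out once. A self-contained sketch follows; the template arguments (items per work-item, load algorithm, element type) were dropped from this copy of the patch and are elided below, and WG_SIZE, buff, and ITEMS are placeholders. Note that the handler parameter name must match the name used for the local_accessor and parallel_for:

    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
      using group_load = dpct::group::workgroup_load</* items-per-thread, algorithm, T */>;
      size_t temp_bytes = group_load::get_local_memory_size(WG_SIZE);
      sycl::local_accessor<uint8_t, 1> tacc(temp_bytes, h);   // scratch local memory
      sycl::accessor dacc(buff, h, sycl::read_only);          // source data
      h.parallel_for(sycl::nd_range<3>({1, 1, WG_SIZE}, {1, 1, WG_SIZE}),
        [=](sycl::nd_item<3> item) {
          float vals[ITEMS];                                  // per-work-item destination
          auto *d   = dacc.get_multi_ptr().get();
          auto *tmp = tacc.get_multi_ptr().get();
          group_load(tmp).load(item, item.get_local_linear_id(), d, vals);
        });
    });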
*/ - LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, c1s); + }); + + }); + /* DPCT1065:162: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2386,7 +2516,31 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha /* DPCT1007:169: Migration of cub::BlockLoad::Load is not supported. */ - LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + //LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state2[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, c2s); + }); + + }); + /* DPCT1065:163: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2394,8 +2548,30 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha /* DPCT1007:170: Migration of cub::BlockLoad::Load is not supported. */ - LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - + //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index
+                 h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+                   [=](sycl::nd_item<3> item) {
+                     auto *d = dacc.get_multi_ptr().get();
+                     auto *tmp = tacc.get_multi_ptr().get();
+                     group_load(tmp).load(item, item.get_local_linear_id(), d, p_vals);
+                  });
+
+        });
         if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }

         # pragma unroll 4
@@ -2437,7 +2613,31 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
         /*
         DPCT1007:171: Migration of cub::BlockStore::Store is not supported.
         */
-        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+        //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+        dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+          using group_store = dpct::group::workgroup_store;
+          size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD2);
+          sycl::local_accessor tacc(temp_storage_size, h);
+          sycl::accessor dacc(buff_p[i], h, sycl::read_write);
+
+          // write the updated parameter values back to p
+          h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+            [=](sycl::nd_item<3> item) {
+              auto *d = dacc.get_multi_ptr().get();
+              auto *tmp = tacc.get_multi_ptr().get();
+              group_store(tmp).store(item, item.get_local_linear_id(), d, p_vals);
+            });
+
+        });
+
         /*
         DPCT1065:164: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
         */
@@ -2445,7 +2645,31 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
         /*
         DPCT1007:172: Migration of cub::BlockStore::Store is not supported.
         */
-        StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
+        //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
+        dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+          using group_store = dpct::group::workgroup_store;
+          size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD2);
+          sycl::local_accessor tacc(temp_storage_size, h);
+          sycl::accessor dacc(buff_state1[i], h, sycl::read_write);
+
+          // write the re-quantized first optimizer state back to state1
+          h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+            [=](sycl::nd_item<3> item) {
+              auto *d = dacc.get_multi_ptr().get();
+              auto *tmp = tacc.get_multi_ptr().get();
+              group_store(tmp).store(item, item.get_local_linear_id(), d, c1s);
+            });
+
+        });
+
         /*
         DPCT1065:165: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
         */
@@ -2453,7 +2677,31 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
         /*
         DPCT1007:173: Migration of cub::BlockStore::Store is not supported.
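The matching store path mirrors the load path; the only differences are the write-direction accessor and that the scratch size must come from the store helper itself. A hedged sketch, with the same placeholder names as the load sketch earlier:

    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
      using group_store = dpct::group::workgroup_store</* items-per-thread, algorithm, T */>;
      size_t temp_bytes = group_store::get_local_memory_size(WG_SIZE);  // group_store, not group_load
      sycl::local_accessor<uint8_t, 1> tacc(temp_bytes, h);
      sycl::accessor dacc(buff, h, sycl::write_only);
      h.parallel_for(sycl::nd_range<3>({1, 1, WG_SIZE}, {1, 1, WG_SIZE}),
        [=](sycl::nd_item<3> item) {
          float vals[ITEMS] = {};                     // values produced by this work-item
          auto *d   = dacc.get_multi_ptr().get();
          auto *tmp = tacc.get_multi_ptr().get();
          group_store(tmp).store(item, item.get_local_linear_id(), d, vals);
        });
    });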
*/ - StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + //StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state2[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, c2s); + }); + + }); + /* DPCT1065:166: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ From 29e895f259026a25e037c02c80cdd040b25093fa Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 21 Mar 2024 01:09:22 -0700 Subject: [PATCH 08/66] add koptimizer 8bits sycl --- csrc/sycl/kernels.cpp | 657 ++++++++++++++++++++++++++++++++++++------ csrc/sycl/kernels.h | 28 +- 2 files changed, 580 insertions(+), 105 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index e19295e85..f70bcfe2d 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -2724,7 +2724,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; @@ -2738,11 +2738,15 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c T g_vals[NUM8BIT]; unsigned char m_c1[NUM8BIT]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadUInt8; - typedef sycl::group<3> BlockReduce; - + sycl::buffer buff_g(g, sycl::range<1>(NUM_THREADS)); + sycl::buffer buff_p(p, sycl::range<1>(NUM_THREADS)); + sycl::buffer buff_state1(state1,sycl::range<1>(NUM_THREADS)); + + //typedef cub::BlockLoad LoadT; + //typedef cub::BlockLoad LoadUInt8; + //typedef sycl::group<3> BlockReduce; + /* union type_ct8{ typename LoadT::TempStorage loadh; typename LoadUInt8::TempStorage loadc; @@ -2750,7 +2754,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c }; type_ct8 &temp_storage = *(type_ct8 *)temp_storage_ct1; - + */ if(item_ct1.get_local_id(2) < 256) smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; @@ -2771,7 +2775,30 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:137: Migration of cub::BlockLoad::Load is not supported. 
*/ - LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_g[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); + }); + + }); /* DPCT1065:136: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2779,7 +2806,30 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:138: Migration of cub::BlockLoad::Load is not supported. */ - LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + //LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, m_c1); + }); + + }); #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) @@ -2816,7 +2866,8 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:4: Migration of cub::Reduce is not supported. */ - local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, sycl::maximum<>(), valid_items); + local_max_s1 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s1, sycl::maximum<>()); + //local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, sycl::maximum<>(), valid_items); if(item_ct1.get_local_id(2) == 0){ atomicMax(&new_max1[0], local_max_s1); } if(unorm != NULL) { @@ -2827,7 +2878,8 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c /* DPCT1007:5: Migration of cub::Reduce is not supported. 
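The atomicMax on a float used for new_max1 here has no single SYCL built-in of that name; dpct supplies a helper, and plain SYCL 2020 can express it with sycl::atomic_ref, whose floating-point specialization provides fetch_max. A sketch, assuming new_max1 points into global memory:

    if (item_ct1.get_local_id(2) == 0) {
      sycl::atomic_ref<float, sycl::memory_order::relaxed,
                       sycl::memory_scope::device,
                       sycl::access::address_space::global_space> ref(new_max1[0]);
      ref.fetch_max(local_max_s1);   // atomic max over the per-block maxima
    }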
*/
-        local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, sycl::plus<>(), valid_items);
+        //local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, sycl::plus<>(), valid_items);
+        local_unorm = sycl::reduce_over_group(item_ct1.get_group(), local_unorm, sycl::plus<>());
       if(item_ct1.get_local_id(2) == 0){ dpct::atomic_fetch_add(&unorm[0], local_unorm); }
     }
 }

@@ -2844,8 +2896,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                 float* max1, float* new_max1,
                 float weight_decay,
                 const float gnorm_scale, const int n,
-                const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,
-                uint8_t *temp_storage_ct1)
+                const sycl::nd_item<3> &item_ct1, float *smem_quantiles1)
 {
     const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2;
@@ -2867,14 +2918,20 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
     unsigned char c1s[NUM_PER_THREAD2];
     T p_vals[NUM_PER_THREAD2];
     T g_vals[NUM_PER_THREAD2];
-    typedef cub::BlockLoad LoadT;
-    typedef cub::BlockLoad LoadChar;
+
+    sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD2));
+    sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD2));
+    sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD2));
+
+    //typedef cub::BlockLoad LoadT;
+    //typedef cub::BlockLoad LoadChar;

-    typedef cub::BlockStore StoreChar;
-    typedef cub::BlockStore StoreT;
+    //typedef cub::BlockStore StoreChar;
+    //typedef cub::BlockStore StoreT;

-
+    /*
     union type_ct9{
         typename LoadT::TempStorage loadh;
         typename LoadChar::TempStorage loadc;
         typename StoreChar::TempStorage storec;
         typename StoreT::TempStorage storeh;
     };
     type_ct9 &temp_storage = *(type_ct9 *)temp_storage_ct1;
-
+    */

     if(item_ct1.get_local_id(2) < 256)
         smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)];
@@ -2897,7 +2954,32 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
         /*
         DPCT1007:145: Migration of cub::BlockLoad::Load is not supported.
         */
-        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
+        //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
+        dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+          using group_load = dpct::group::workgroup_load;
+          size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2);
+          sycl::local_accessor tacc(temp_storage_size, h);
+          sycl::accessor dacc(buff_g[i], h, sycl::read_write);
+
+          // load NUM_PER_THREAD2 gradient values per work-item into g_vals
+          h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+            [=](sycl::nd_item<3> item) {
+              auto *d = dacc.get_multi_ptr().get();
+              auto *tmp = tacc.get_multi_ptr().get();
+              group_load(tmp).load(item, item.get_local_linear_id(), d, g_vals);
+            });
+
+        });
+
        /*
        DPCT1065:141: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
        */
        item_ct1.barrier();
        /*
        DPCT1007:146: Migration of cub::BlockLoad::Load is not supported.
*/ - LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, c1s); + }); + + }); + /* DPCT1065:142: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2913,8 +3019,30 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, /* DPCT1007:147: Migration of cub::BlockLoad::Load is not supported. */ - LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - + //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); + }); + + }); if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } # pragma unroll 4 @@ -2972,7 +3100,31 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, /* DPCT1007:148: Migration of cub::BlockStore::Store is not supported. */ - StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); + }); + + }); + /* DPCT1065:143: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2980,7 +3132,30 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, /* DPCT1007:149: Migration of cub::BlockStore::Store is not supported. */ - StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, c1s); + }); + + }); /* DPCT1065:144: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2997,9 +3172,9 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int int valid_items = 0; - typedef cub::BlockLoad LoadT; + //typedef cub::BlockLoad LoadT; - + sycl::buffer buff_g(g, sycl::range<1>(NUM_VALS)); T vals[NUM_VALS]; @@ -3017,8 +3192,30 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int /* DPCT1007:203: Migration of cub::BlockLoad::Load is not supported. */ - LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); - + //LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_g[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index
+                 h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+                   [=](sycl::nd_item<3> item) {
+                     auto *d = dacc.get_multi_ptr().get();
+                     auto *tmp = tacc.get_multi_ptr().get();
+                     group_load(tmp).load(item, item.get_local_linear_id(), d, vals);
+                  });
+
+        });
     #pragma unroll NUM_VALS
     for(int j = 0; j < NUM_VALS; j++)
       local_sum += ((float)vals[j])*((float)vals[j]);
@@ -3026,7 +3223,8 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int
    /*
    DPCT1007:12: Migration of cub::Sum is not supported.
    */
-    local_sum = BlockReduce(reduce).Sum(local_sum, valid_items);
+    local_sum = sycl::reduce_over_group(item_ct1.get_group(), local_sum, sycl::plus<>());
+    //local_sum = BlockReduce(reduce).Sum(local_sum, valid_items);
     if(item_ct1.get_local_id(2) == 0)
     {
       if(step == 1)
@@ -3087,21 +3285,21 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
     unsigned char c2s[N_PER_TH];
     T g_vals[N_PER_TH];
     T p_vals[N_PER_TH];
-    typedef cub::BlockLoad LoadT;
-    typedef cub::BlockLoad LoadChar;
-
-    typedef cub::BlockStore StoreChar;
-    typedef cub::BlockStore StoreT;
-
+    sycl::buffer buff_g(g, sycl::range<1>(N_PER_TH));
+    sycl::buffer buff_p(p, sycl::range<1>(N_PER_TH));
+    sycl::buffer buff_state1(state1,sycl::range<1>(N_PER_TH));
+    sycl::buffer buff_state2(state2,sycl::range<1>(N_PER_TH));
+    //typedef cub::BlockLoad LoadT;
+    //typedef cub::BlockLoad LoadChar;
+    //typedef cub::BlockStore StoreChar;
+    //typedef cub::BlockStore StoreT;
+
+    /*
     union type_ct10{
         typename LoadT::TempStorage loadh;
         typename LoadChar::TempStorage loadc;
         typename StoreChar::TempStorage storec;
         typename StoreT::TempStorage storeh;
     };
     type_ct10 &temp_storage = *(type_ct10 *)temp_storage_ct1;
+    */

     // init: 0.2 -> 0.23

     // 0.23 -> 0.23
@@ -3145,29 +3344,100 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
         /*
         DPCT1007:183: Migration of cub::BlockLoad::Load is not supported.
         */
-        LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
-        /*
-        DPCT1065:176: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
-        /*
-        DPCT1007:184: Migration of cub::BlockLoad::Load is not supported.
-        */
-        LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
-        /*
-        DPCT1065:177: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
-        /*
-        DPCT1007:185: Migration of cub::BlockLoad::Load is not supported.
-        */
-        LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
-
-        new_local_abs_max1 = -FLT_MAX;
-        new_local_abs_max2 = -FLT_MAX;
-
-        // update: 2.48/1.57 -> 2.51/1.60
-        # pragma unroll N_PER_TH
+        //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
+        dpct::get_in_order_queue().submit([&](sycl::handler &h) {
+
+          using group_load = dpct::group::workgroup_load;
+          size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH);
+          sycl::local_accessor tacc(temp_storage_size, h);
+          sycl::accessor dacc(buff_g[i], h, sycl::read_write);
+
+          // 1. load 8 values per thread
+          // 2. compute 2-max in registers (64 max per warp)
+          // 3. do warp reduction + broadcast back
+          // 4. 
Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); + }); + + }); + + /* + DPCT1065:176: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:184: Migration of cub::BlockLoad::Load is not supported. + */ + //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, c1s); + }); + + }); + + /* + DPCT1065:177: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. + */ + item_ct1.barrier(); + /* + DPCT1007:185: Migration of cub::BlockLoad::Load is not supported. + */ + //LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state2[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, c2s); + }); + + }); + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { if(!sycl::isnan((float)g_vals[j]) && !sycl::isinf((float)g_vals[j])) @@ -3227,7 +3497,33 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char /* DPCT1007:186: Migration of cub::BlockLoad::Load is not supported. 
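For orientation, the update loop this hunk feeds keeps each optimizer state as an 8-bit code plus one absmax scale per block: dequantization is a 256-entry table lookup times the scale, and re-quantization maps the updated value back to the nearest code under the new scale. In outline (nearest_code is a hypothetical helper standing in for the file's quantization routine):

    // Dequantize, update, and track the new block scale for one element.
    float s1 = smem_quantiles1[c1s[j]] * absmax1[i / BLOCK_SIZE];          // code -> float
    s1 = s1 * beta1 + ((1.0f - beta1) * (float)g_vals[j]);                 // first-moment update
    new_local_abs_max1 = sycl::fmax(new_local_abs_max1, sycl::fabs(s1));   // candidate scale
    // ... after a group-wide max publishes new_absmax1:
    c1s[j] = nearest_code(smem_quantiles1, s1 / new_absmax1);              // float -> code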
*/ - LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); + }); + + }); + + + // reduce: 2.67/1.69 -> 2.67/1.70 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) @@ -3249,8 +3545,30 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char /* DPCT1007:187: Migration of cub::BlockStore::Store is not supported. */ - StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - + //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); + }); + + }); // quantizaztion: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) @@ -3276,7 +3594,31 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char /* DPCT1007:188: Migration of cub::BlockStore::Store is not supported. */ - StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, c1s); + }); + + }); + /* DPCT1065:182: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -3284,7 +3626,30 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char /* DPCT1007:189: Migration of cub::BlockStore::Store is not supported. */ - StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + //StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state2[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, c2s); + }); + + }); } } @@ -3306,7 +3671,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, - float *smem_exchange1, uint8_t *temp_storage_ct1) + float *smem_exchange1) { //const int n_full = n + (n%BLOCK_SIZE); @@ -3324,24 +3689,28 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; + //typedef cub::BlockLoad LoadT; + //typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; + //typedef cub::BlockStore StoreChar; + //typedef cub::BlockStore StoreT; - - - - + sycl::buffer buff_g(g, sycl::range<1>(N_PER_TH)); + sycl::buffer buff_p(p, sycl::range<1>(N_PER_TH)); + sycl::buffer buff_state1(state1,sycl::range<1>(N_PER_TH)); + sycl::buffer buff_state2(state2,sycl::range<1>(N_PER_TH)); + + /* union type_ct11{ typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; typename StoreChar::TempStorage storec; typename StoreT::TempStorage storeh; }; + type_ct11 &temp_storage = *(type_ct11 *)temp_storage_ct1; + */ // init: 0.2 -> 0.23 // 0.23 -> 0.23 @@ -3370,7 +3739,30 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char /* DPCT1007:197: Migration of cub::BlockLoad::Load is not supported. 
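These hunks wrap the raw kernel pointers in sycl::buffer objects of only N_PER_TH elements and then index the buffer object itself (buff_g[i]) to obtain an offset view, which is not a buffer operation SYCL defines. The SYCL-native way to get a window into an existing buffer is a sub-buffer; a sketch, assuming total_n elements behind g and that the implementation's sub-buffer alignment rules are met:

    sycl::buffer<float, 1> buff_g(g, sycl::range<1>(total_n));                      // whole array
    sycl::buffer<float, 1> sub_g(buff_g, sycl::id<1>(i), sycl::range<1>(count));    // g[i .. i+count)
    sycl::accessor dacc(sub_g, h, sycl::read_write);                                // window accessor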
*/ - LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_g[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); + }); + + }); /* DPCT1065:192: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -3378,7 +3770,30 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char /* DPCT1007:198: Migration of cub::BlockLoad::Load is not supported. */ - LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, c1s); + }); + + }); /* DPCT1065:193: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -3386,8 +3801,30 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char /* DPCT1007:199: Migration of cub::BlockLoad::Load is not supported. */ - LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); - + //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); + }); + + }); new_local_abs_max1 = -FLT_MAX; // update: 2.48/1.57 -> 2.51/1.60 @@ -3489,8 +3926,30 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char /* DPCT1007:200: Migration of cub::BlockStore::Store is not supported. */ - StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - + //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_p[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); + }); + + }); // quantizaztion: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) @@ -3515,15 +3974,35 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char /* DPCT1007:201: Migration of cub::BlockStore::Store is not supported. */ - StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_store = dpct::group::workgroup_store; + size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_state1[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, c1s); + }); + + }); + } } -template void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, - const sycl::nd_item<3> &item_ct1, - uint8_t *temp_storage_ct1, - float *smem_row_absmax_values, - int *smem_row_nnz_values) +template void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values) { // 0. reset stats to -FLT_MAX // 1. 
load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 499493f23..2f9014b5a 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -16,8 +16,7 @@ //template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1, - uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, const sycl::nd_item<3> &item_ct1); @@ -28,36 +27,35 @@ template &item_ct1, float *smem_code, float *smem_absmax_value); -template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1); +template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); template extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); template extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); template extern SYCL_EXTERNAL void @@ -69,7 +67,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1); @@ -82,8 +80,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, - uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1); @@ -96,7 +93,7 @@ 
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                 float* max1, float* max2, float* new_max1, float* new_max2,
                 const float gnorm_scale, const int n,
-                const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1,
+                const sycl::nd_item<3> &item_ct1,
                 float *smem_quantiles1, float *smem_quantiles2);

@@ -110,7 +107,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned cha
                 float* max1, float* max2, float* new_max1, float* new_max2,
                 float weight_decay, const float gnorm_scale, const int n,
                 const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,
-                float *smem_quantiles2, uint8_t *temp_storage_ct1);
+                float *smem_quantiles2);

 template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise(
     T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
@@ -120,8 +117,7 @@ template extern SYCL_EX
     const sycl::nd_item<3> &item_ct1,
     sycl::local_accessor smem_quantiles1,
     sycl::local_accessor smem_quantiles2,
-    float *smem_exchange1, float *smem_exchange2,
-    uint8_t *temp_storage_ct1);
+    float *smem_exchange1, float *smem_exchange2);

 template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise(
     T* p, T* __restrict__ const g, unsigned char* state1,
@@ -133,7 +129,7 @@ template extern SYCL_EX
     const float gnorm_scale, const bool skip_zeros, const int n,
     const sycl::nd_item<3> &item_ct1,
     sycl::local_accessor smem_quantiles1,
-    float *smem_exchange1, uint8_t *temp_storage_ct1);
+    float *smem_exchange1);

 template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1);

From 98bd39ba91084f77aba0e510cf18f337c24683f0 Mon Sep 17 00:00:00 2001
From: abhilash1910
Date: Thu, 21 Mar 2024 05:37:35 -0700
Subject: [PATCH 09/66] add row col transform kernels

---
 csrc/sycl/kernels.cpp | 336 ++++++++++++++++++++++++++++++++----------
 csrc/sycl/kernels.h   |   2 +-
 2 files changed, 263 insertions(+), 75 deletions(-)

diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp
index f70bcfe2d..8f597b358 100644
--- a/csrc/sycl/kernels.cpp
+++ b/csrc/sycl/kernels.cpp
@@ -644,17 +644,20 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset
   const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE));

   T vals[NUM_ESTIMATE];
-
-  typedef cub::BlockRadixSort BlockRadixSort;
-  typedef cub::BlockLoad LoadFloat;
-
+
+  //typedef cub::BlockRadixSort BlockRadixSort;
+  //typedef cub::BlockLoad LoadFloat;
+  /*
   union type_ct1{
       typename LoadFloat::TempStorage loadf;
      typename BlockRadixSort::TempStorage sort;
      int smem_qidx[BLOCK_ESTIMATE];
   };
   type_ct1 &temp_storage = *(type_ct1 *)temp_storage_ct1;
-
+  */
+  int smem_qidx[BLOCK_ESTIMATE];
+  sycl::buffer buff_A(A,sycl::range<1>(NUM_ESTIMATE));

   for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_ESTIMATE)
   {
       valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i;

       /*
       DPCT1007:81: Migration of cub::BlockLoad::Load is not supported.
*/ + //LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(NUM_ESTIMATE); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_A[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, vals); + }); + + }); - - LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); - #pragma unroll 4 for(int j = 0; j < NUM_ESTIMATE; j++) vals[j] = ((float)vals[j]) * reciprocal_num_blocks; @@ -692,8 +716,29 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset /* DPCT1007:82: Migration of cub::BlockRadixSort::SortBlockedToStriped is not supported. */ - BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); - + //BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_radix_sort = dpct::group::radix_sort; + size_t temp_storage_size = group_radix_sort::get_local_memory_size(NUM_ESTIMATE); + sycl::local_accessor tacc( + temp_storage_size, h); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *tmp = tacc.get_multi_ptr().get(); + group_radix_sort(tmp).sort_blocked_to_striped(item, vals); + }); + + }); + /* DPCT1065:78: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -713,7 +758,7 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset DPCT1064:83: Migrated round call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code. 
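The rounding performed on the next line is the heart of the quantile estimate: after the block-wide sort, work-item t samples the empirical quantile q = offset + t*q_interval by taking the sorted element of rank round(q*(valid_items-1)). As a standalone function (names mirror the kernel's variables):

    // Rank of the order statistic sampled by work-item t.
    inline int quantile_rank(float offset, float q_interval, int t, int valid_items) {
      float q = offset + t * q_interval;                 // target quantile in [0, 1)
      return (int)sycl::round(q * (valid_items - 1));    // nearest sorted position
    }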
*/ int local_idx = sycl::round(((offset+(item_ct1.get_local_id(2)*q_interval))*(valid_items-1))); - temp_storage.smem_qidx[local_idx] = item_ct1.get_local_id(2); + smem_qidx[local_idx] = item_ct1.get_local_id(2); } /* @@ -723,8 +768,8 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset for(int i = item_ct1.get_local_id(2); i < BLOCK_ESTIMATE; i+=item_ct1.get_local_range(2)) { - if(temp_storage.smem_qidx[i] != -1) - dpct::atomic_fetch_add(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); + if(smem_qidx[i] != -1) + dpct::atomic_fetch_add(&code[smem_qidx[i]], vals[i/THREADS_ESTIMATE]); } } } @@ -3260,8 +3305,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, - float *smem_exchange1, float *smem_exchange2, - uint8_t *temp_storage_ct1) + float *smem_exchange1, float *smem_exchange2) { //const int n_full = n + (n%BLOCK_SIZE); @@ -4019,6 +4063,7 @@ template LoadT; typedef sycl::group<3> BlockRowReduce; typedef sycl::group<3> BlockRowSum; @@ -4031,8 +4076,8 @@ template buff_A(A, sycl::range<1>(ITEMS_PER_THREAD)); sycl::half local_data[ITEMS_PER_THREAD]; @@ -4080,8 +4125,30 @@ template(0.0f).convert()[0]); - + //LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, sycl::vec(0.0f).convert()[0]); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(ITEMS_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_A[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, local_data); + }); + + }); #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_data[j] = sycl::fabs(local_data[j]); @@ -4118,7 +4185,7 @@ template()); + row_absmax = (float)sycl::reduce_over_group(item_ct1.get_group(), local_data_fp32, sycl::maximum<>()); if(SPARSE_DECOMP) { /* @@ -4154,8 +4221,28 @@ template; + size_t temp_storage_size = group_exchange::get_local_memory_size(ITEMS_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *tmp = tacc.get_multi_ptr().get(); + group_exchange(tmp).blocked_to_striped(item, local_col_absmax_values); + }); + + }); #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) if(base_col+item_ct1.get_local_id(2)+(j*THREADS) < cols) @@ -4181,20 +4268,16 @@ template(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, - uint8_t *temp_storage_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values); template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, - uint8_t *temp_storage_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values); #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) -template void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, - const sycl::nd_item<3> &item_ct1, - float *smem_rowStats) +template void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats) { // Strategy: To dequantize we need to load col/row statistics. This can be very expensive @@ -4248,10 +4331,10 @@ template void kdequant_mm_i float local_rowStats[ITEMS_PER_THREAD]; - typedef cub::BlockLoad LoadInt32; - typedef cub::BlockExchange ExchangeInt32; - + //typedef cub::BlockLoad LoadInt32; + //typedef cub::BlockExchange ExchangeInt32; + sycl::buffer buff_A(A, sycl::range<1>(ITEMS_PER_THREAD)); // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. @@ -4292,12 +4375,56 @@ template void kdequant_mm_i /* DPCT1007:206: Migration of cub::BlockLoad::Load is not supported. */ - LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + //LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_load = dpct::group::workgroup_load; + size_t temp_storage_size = group_load::get_local_memory_size(ITEMS_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_A[subtile_idx], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, local_values); + }); + + }); /* DPCT1007:207: Migration of cub::BlockExchange::BlockedToWarpStriped is not supported. */ - ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); - + //ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + + + using group_exchange = dpct::group::exchange; + size_t temp_storage_size = group_exchange::get_local_memory_size(ITEMS_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *tmp = tacc.get_multi_ptr().get(); + group_exchange(tmp).blocked_to_warpstriped(item, local_values); + }); + + }); + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; @@ -4325,10 +4452,7 @@ template void kdequant_mm_i } -template void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, - const sycl::nd_item<3> &item_ct1, - float *smem_row_stats, - unsigned int *smem_nnz_row_idx) +template void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx) { // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD // Each thread reads the same column but multiple rows @@ -4347,12 +4471,15 @@ template LoadHalf; + //typedef cub::BlockLoad LoadHalf; - typedef cub::BlockStore StoreInt8; + //typedef cub::BlockStore StoreInt8; - + sycl::buffer buff_A(A,sycl::range<1>(ITEMS_PER_THREAD)); + sycl::buffer buff_out_row_normed(out_row_normed,sycl::range<1>(ITEMS_PER_THREAD)); + sycl::buffer buff_out_col_normed(out_col_normed,sycl::range<1>(ITEMS_PER_THREAD)); @@ -4394,7 +4521,31 @@ template ; + size_t temp_storage_size = group_load::get_local_memory_size(ITEMS_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_A[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), d, local_data); + }); + + }); + float row_stat = 127.0f / smem_row_stats[row]; // 2. quantize data with row/col stats @@ -4427,8 +4578,30 @@ template ; + size_t temp_storage_size = group_store::get_local_memory_size(ITEMS_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_out_row_normed[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, local_quantized_data); + }); + + }); // 2. quantize data with row/col stats #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) @@ -4445,17 +4618,37 @@ template ; + size_t temp_storage_size = group_store::get_local_memory_size(ITEMS_PER_THREAD); + sycl::local_accessor tacc( + temp_storage_size, h); + sycl::accessor dacc(buff_out_col_normed[i], h, sycl::read_write); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item) { + auto *d = dacc.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), d, local_quantized_data); + }); + + }); } } /* DPCT1110:14: The total declared local variable size in device function kTransformRowToFormat exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. */ -template SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, - const sycl::nd_item<3> &item_ct1, - char *smem_data) +template SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data) { // 0. Load data into 32*32 shared memory tiles @@ -4506,8 +4699,10 @@ template BlockExchange; + //typedef cub::BlockExchange BlockExchange; + + // we load row after row from the base_position // Load data row by row int warps = item_ct1.get_local_range(2)/32; @@ -4789,9 +4984,8 @@ template /* DPCT1110:13: The total declared local variable size in device function kspmm_coo_very_sparse_naive exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
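Note: one concrete way to act on this advice is to pin a smaller sub-group size on the kernel itself; a sketch under the assumption that the device lists 16 in info::device::sub_group_sizes (q, grid and block are illustrative names, not from this file):

  q.parallel_for(sycl::nd_range<3>(grid, block),
                 [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(16)]] {
                   // kernel body; a smaller sub-group leaves more registers
                   // per work-item on some devices
                 });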
*/ -SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, - const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats) +SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, + sycl::half *smem_dequant_stats) { // 0. load balancing: We process rows with most columns first (count_vec) and we process one row per block @@ -4937,8 +5131,7 @@ SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int } } -template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, - const sycl::nd_item<3> &item_ct1) +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1) { int local_colidx = idx[item_ct1.get_group(2)]; @@ -5044,10 +5237,7 @@ template inline void vector_load(T *loca /* DPCT1110:15: The total declared local variable size in device function gemm_device exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. */ -template SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, - const sycl::nd_item<3> &item_ct1, - T *smem_A, - T *smem_B) +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, T *smem_A, T *smem_B) { #if DPCT_COMPATIBILITY_TEMP >= 750 @@ -5963,17 +6153,15 @@ template unsigned char dQuantize<0>(float* smem_code, const float rand, float x) template unsigned char dQuantize<1>(float* smem_code, const float rand, float x); template SYCL_EXTERNAL void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, - const sycl::nd_item<3> &item_ct1, - uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, - const sycl::nd_item<3> &item_ct1, - uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1); \ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) @@ -5987,7 +6175,7 @@ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template void kOptimizer32bit1State(gtype*
g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1); \ SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, float) @@ -6003,7 +6191,7 @@ SYCL_EXTERNAL MAKE_Optimizer32bit1State(ADAGRAD, float) template void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ float* state1, float* state2, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1); \ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, float) SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half) @@ -6011,13 +6199,13 @@ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::ext::oneapi::bfl template SYCL_EXTERNAL void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::half* g, sycl::half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::ext::oneapi::bfloat16* g, sycl::ext::oneapi::bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1); + const sycl::nd_item<3> &item_ct1); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ @@ -6029,7 +6217,7 @@ template void kPreconditionOptimizerStatic8bit1State(gtype* p, gty float* max1, float* new_max1, \ const float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, float *smem_quantiles1); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1); \ SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(MOMENTUM, float) @@ -6048,7 +6236,7 @@ template void kOptimizerStatic8bit1State(gtype* p, gtype* const g, 
float* max1, float* new_max1, \ float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, uint8_t *temp_storage_ct1); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1); \ SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, float) @@ -6065,7 +6253,7 @@ template void kPreconditionOptimizerStatic8bit2State(gtype* p, gty float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, uint8_t *temp_storage_ct1, float *smem_quantiles1, float *smem_quantiles2); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2); \ SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, sycl::half) SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, float) @@ -6079,7 +6267,7 @@ template void kOptimizerStatic8bit2State(gtype* p, gtype* const g, float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, uint8_t *temp_storage_ct1); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2); \ SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, sycl::half) SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, float) @@ -6186,7 +6374,7 @@ template void kOptimizerStatic8bit2StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2, uint8_t *temp_storage_ct1); \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2); \ SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8) @@ -6201,7 +6389,7 @@ template void kOptimizerStatic8bit1StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1, uint8_t *temp_storage_ct1); \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1); \ SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 2f9014b5a..e334d1f6c 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -156,7 +156,7 @@ extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( const sycl::nd_item<3> &item_ct1, float *smem_rowStats); template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, - const sycl::nd_item<3> &item_ct1,uint8_t *temp_storage_ct1,float *smem_row_absmax_values,int *smem_row_nnz_values); + const sycl::nd_item<3> &item_ct1,float *smem_row_absmax_values,int *smem_row_nnz_values); template extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, From 8e15d5a871c4b368403fa37ab8c421bc9459b273 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 22 Mar 2024 00:33:54 -0700 
Subject: [PATCH 10/66] add gemm kernels --- csrc/sycl/kernels.cpp | 88 ++++++++++++++++++++++++++++++++----------- 1 file changed, 66 insertions(+), 22 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 8f597b358..9a35824d8 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -10,7 +10,7 @@ #include #include "kernels.dp.hpp" #include - +#include #include @@ -5268,8 +5268,16 @@ template SYCL_EXTERNAL void gemm_device(int /* DPCT1082:18: Migration of nvcuda::wmma::row_major type is not supported. */ - wmma::fragment 8, 32, 16, wmma::fragmentor, 8, 32, 16, wmma::fragment; - + sycl::ext::oneapi::experimental::matrix::joint_matrix a_frag; + sycl::ext::oneapi::experimental::matrix::joint_matrix b_frag; + sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag; + sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(item_ct1.get_sub_group(), c_frag, 0.0f); + + //wmma::fragment a_frag; + //wmma::fragment b_frag; + //wmma::fragment c_frag; + //wmma::fill_fragment(c_frag, 0.0f); + int ticktock = 0; int idx = 0 + item_ct1.get_local_id(2); int loaded_values = 0; @@ -5426,15 +5434,23 @@ template SYCL_EXTERNAL void gemm_device(int /* DPCT1007:25: Migration of nvcuda::wmma::load_matrix_sync is not supported. */ - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); + /* DPCT1007:26: Migration of nvcuda::wmma::load_matrix_sync is not supported. */ - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), b_frag, sycl::address_space_cast&(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); + /* DPCT1007:27: Migration of nvcuda::wmma::mma_sync is not supported. */ - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), c_frag, a_frag, b_frag, c_frag); + } + } } } @@ -5449,15 +5465,20 @@ template SYCL_EXTERNAL void gemm_device(int /* DPCT1007:28: Migration of nvcuda::wmma::load_matrix_sync is not supported. */ - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); + /* DPCT1007:29: Migration of nvcuda::wmma::load_matrix_sync is not supported. */ - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), b_frag, sycl::address_space_cast&(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); + /* DPCT1007:30: Migration of nvcuda::wmma::mma_sync is not supported.
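Note: the joint_matrix declarations above were migrated without their template arguments. For reference, a self-contained sketch of one tile step; the 8x16 / 16x32 shapes are assumptions for illustration (the original CUDA fragments were 8x32x16), and the loads want a sycl::multi_ptr rather than a raw pointer:

  using namespace sycl::ext::oneapi::experimental::matrix;
  sycl::sub_group sg = item_ct1.get_sub_group();
  joint_matrix<sycl::sub_group, sycl::half, use::a, 8, 16, layout::row_major> a;
  joint_matrix<sycl::sub_group, sycl::half, use::b, 16, 32, layout::row_major> b;
  joint_matrix<sycl::sub_group, sycl::half, use::accumulator, 8, 32> c;
  joint_matrix_fill(sg, c, 0.0f);
  auto pA = sycl::address_space_cast<sycl::access::address_space::local_space,
                                     sycl::access::decorated::no>(&smem_A[0]);
  auto pB = sycl::address_space_cast<sycl::access::address_space::local_space,
                                     sycl::access::decorated::no>(&smem_B[0]);
  joint_matrix_load(sg, a, pA, 16);  // the second argument names the destination fragment,
  joint_matrix_load(sg, b, pB, 16);  // so the A tile loads into a and the B tile into b
  joint_matrix_mad(sg, c, a, b, c);  // c = a*b + c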
*/ - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), c_frag, a_frag, b_frag, c_frag); } // 129 mu @@ -5465,8 +5486,9 @@ template SYCL_EXTERNAL void gemm_device(int /* DPCT1007:31: Migration of nvcuda::wmma::store_matrix_sync is not supported. */ - wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); - + //wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + sycl::ext::oneapi::experimental::matrix::joint_matrix_store(item_ct1.get_sub_group(), c_frag, sycl::address_space_cast&(smem_A[0]), 32, sycl::ext::oneapi::experimental::matrix::layout::row_major); + if(col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_A[warp_lane]; #endif @@ -5537,8 +5559,17 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M /* DPCT1082:35: Migration of nvcuda::wmma::row_major type is not supported. */ - wmma::fragment 8, 32, 16, wmma::fragmentor, 8, 32, 16, wmma::fragment; - + + sycl::ext::oneapi::experimental::matrix::joint_matrix a_frag; + sycl::ext::oneapi::experimental::matrix::joint_matrix b_frag; + sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag; + sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(item_ct1.get_sub_group(), c_frag, 0.0f); + + //wmma::fragment a_frag; + //wmma::fragment b_frag; + //wmma::fragment c_frag; + //wmma::fill_fragment(c_frag, 0.0f); + for(int i = item_ct1.get_local_id(2); i < (8*32); i+=item_ct1.get_local_range(2)) smem_C[i] = 0.0f; @@ -5684,15 +5715,21 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M /* DPCT1007:42: Migration of nvcuda::wmma::load_matrix_sync is not supported. */ - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); + /* DPCT1007:43: Migration of nvcuda::wmma::load_matrix_sync is not supported. */ - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), b_frag, sycl::address_space_cast&(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); + /* DPCT1007:44: Migration of nvcuda::wmma::mma_sync is not supported. */ - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), c_frag, a_frag, b_frag, c_frag); + } } @@ -5714,15 +5751,21 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M /* DPCT1007:45: Migration of nvcuda::wmma::load_matrix_sync is not supported. */ - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); + /* DPCT1007:46: Migration of nvcuda::wmma::load_matrix_sync is not supported.
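Note: where these calls read sycl::address_space_cast&(...), the cast's template arguments are missing and joint_matrix_load needs the resulting multi_ptr. A sketch of the intended form for the B tile; the address space and decoration are assumptions (local_space because the tiles live in shared local memory):

  auto tileB = sycl::address_space_cast<sycl::access::address_space::local_space,
                                        sycl::access::decorated::no>(
      &smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]);
  joint_matrix_load(item_ct1.get_sub_group(), b_frag, tileB, 16);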
*/ - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), b_frag, sycl::address_space_cast&(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); + /* DPCT1007:47: Migration of nvcuda::wmma::mma_sync is not supported. */ - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), c_frag, a_frag, b_frag, c_frag); + } // 129 mu @@ -5730,10 +5773,11 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M /* DPCT1007:48: Migration of nvcuda::wmma::store_matrix_sync is not supported. */ - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); - - //printnonzero(smem_C, 32, ""); + //wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); + sycl::ext::oneapi::experimental::matrix::joint_matrix_store(item_ct1.get_sub_group(), c_frag, sycl::address_space_cast&(smem_C[0]), 32, sycl::ext::oneapi::experimental::matrix::layout::row_major); //printnonzero(smem_C, 32, ""); + if(col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_C[warp_lane]; #endif From bd9e8290f75dc2a0862e3d6fc91a68cc9ee8beb9 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 22 Mar 2024 04:52:44 -0700 Subject: [PATCH 11/66] fix header path --- csrc/sycl/kernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 9a35824d8..3614aad92 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -8,7 +8,7 @@ #include #include #include -#include "kernels.dp.hpp" +#include "kernels.h" #include #include #include From 45318d64c4d420fb9e76672a8d1ed0b0f5517308 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 22 Mar 2024 05:48:42 -0700 Subject: [PATCH 12/66] fix handler kernels --- csrc/sycl/kernels.cpp | 128 +++++++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 64 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 3614aad92..afe279646 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -551,7 +551,7 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons sycl::buffer buff_values(smem_max_values, sycl::range<1>(8*BLOCK_SIZE/32)); sycl::buffer buff_A(A,sycl::range<1>(8*BLOCK_SIZE/32)); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -677,7 +677,7 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset DPCT1007:81: Migration of cub::BlockLoad::Load is not supported. */ //LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -717,7 +717,7 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset DPCT1007:82: Migration of cub::BlockRadixSort::SortBlockedToStriped is not supported.
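Note: every hunk in this patch makes the same one-token change. The command-group lambdas declared sycl::handler &cgh while their bodies already used h (h.parallel_for, local_accessor(..., h)), so none of the submits compiled; the rename brings the declaration in line with the bodies. The resulting shape of each command group (bytes is an illustrative name):

  dpct::get_in_order_queue().submit([&](sycl::handler &h) {
    sycl::local_accessor<uint8_t, 1> tacc(sycl::range<1>(bytes), h);  // scratch for the group helper
    h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
                   [=](sycl::nd_item<3> item) {
                     // kernel body
                   });
  });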
*/ //BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_radix_sort = dpct::group::radix_sort; @@ -820,7 +820,7 @@ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, c */ item_ct1.barrier(sycl::access::fence_space::local_space); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -860,7 +860,7 @@ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, c item_ct1.barrier(sycl::access::fence_space::local_space); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_store; @@ -937,7 +937,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl */ //LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1003,7 +1003,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl //LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1071,7 +1071,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl DPCT1007:89: Migration of cub::BlockStore::Store is not supported. */ - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -1151,7 +1151,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * DPCT1007:93: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1218,7 +1218,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * */ //StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -1320,7 +1320,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, */ //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1353,7 +1353,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, DPCT1007:102: Migration of cub::BlockLoad::Load is not supported. */ //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1387,7 +1387,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, DPCT1007:103: Migration of cub::BlockLoad::Load is not supported.
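Note: cub::BlockLoad's (ptr, dst, valid_items, default) overload fills out-of-range items with the default value; whatever replaces it has to keep that contract or the last partial block reads garbage. In plain SYCL, assuming a blocked arrangement with NUM_VALS items per work-item (NUM_VALS is this kernel's per-thread count):

  const int tid = item_ct1.get_local_id(2);
  for (int j = 0; j < NUM_VALS; j++)
  {
    const int k = tid * NUM_VALS + j;
    s2_vals[j] = k < valid_items ? state2[i + k] : 0.0f;  // default-fill past valid_items
  }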
*/ //LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1516,7 +1516,7 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:111: Migration of cub::BlockLoad::Load is not supported. */ //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1549,7 +1549,7 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:112: Migration of cub::BlockLoad::Load is not supported. */ //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1583,7 +1583,7 @@ void kOptimizer32bit2State(T* g, T* p, */ //LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1616,7 +1616,7 @@ void kOptimizer32bit2State(T* g, T* p, */ //Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1672,7 +1672,7 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:115: Migration of cub::BlockStore::Store is not supported. */ //Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -1704,7 +1704,7 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:116: Migration of cub::BlockStore::Store is not supported. */ //StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -1738,7 +1738,7 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:117: Migration of cub::BlockStore::Store is not supported. */ //StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -1809,7 +1809,7 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, DPCT1007:121: Migration of cub::BlockLoad::Load is not supported. */ //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1842,7 +1842,7 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, DPCT1007:122: Migration of cub::BlockLoad::Load is not supported. 
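Note: the recurring DPCT1065 advice in this file can be taken literally; both forms below synchronize the work-group, and the second names the cheaper local-memory-only fence:

  sycl::group_barrier(item_ct1.get_group());
  item_ct1.barrier(sycl::access::fence_space::local_space);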
*/ //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -1978,7 +1978,7 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:128: Migration of cub::BlockLoad::Load is not supported. */ //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2011,7 +2011,7 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:129: Migration of cub::BlockLoad::Load is not supported. */ //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2044,7 +2044,7 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:130: Migration of cub::BlockLoad::Load is not supported. */ //Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2114,7 +2114,7 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:131: Migration of cub::BlockStore::Store is not supported. */ //Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -2146,7 +2146,7 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:132: Migration of cub::BlockStore::Store is not supported. */ //StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -2246,7 +2246,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c DPCT1007:156: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2277,7 +2277,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c DPCT1007:157: Migration of cub::BlockLoad::Load is not supported. */ //LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2310,7 +2310,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c DPCT1007:158: Migration of cub::BlockLoad::Load is not supported. */ //LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2498,7 +2498,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha DPCT1007:167: Migration of cub::BlockLoad::Load is not supported. 
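Note: every replacement in this file sizes the group helper's scratch on the host side of the command group and hands it to the kernel as a local_accessor; schematically (items_per_thread is an illustrative stand-in for the kernel's per-thread count, the alias names follow the surrounding hunks):

  size_t bytes = group_load::get_local_memory_size(items_per_thread);
  sycl::local_accessor<uint8_t, 1> tacc(sycl::range<1>(bytes), h);
  // in the kernel:
  // group_load(tacc.get_multi_ptr<sycl::access::decorated::no>().get()).load(...)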
*/ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2530,7 +2530,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha DPCT1007:168: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2562,7 +2562,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha DPCT1007:169: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2594,7 +2594,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha DPCT1007:170: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2659,7 +2659,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha DPCT1007:171: Migration of cub::BlockStore::Store is not supported. */ //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -2691,7 +2691,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha DPCT1007:172: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -2723,7 +2723,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha DPCT1007:173: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -2821,7 +2821,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c DPCT1007:137: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -2852,7 +2852,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c DPCT1007:138: Migration of cub::BlockLoad::Load is not supported. 
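Note: the 128 used as the out-of-range default below is the midpoint of the 256-entry code book, so for a roughly symmetric quantile map the padded lanes decode near zero instead of skewing the update. The decode itself is a table lookup against the quantiles cached in local memory; schematically (names as in this kernel):

  for (int j = 0; j < NUM_VALS; j++)
    s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0];  // dequantize the 8-bit state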
*/ //LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3000,7 +3000,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:145: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3033,7 +3033,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:146: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3065,7 +3065,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:147: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3146,7 +3146,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:148: Migration of cub::BlockStore::Store is not supported. */ //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -3178,7 +3178,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:149: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -3238,7 +3238,7 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int DPCT1007:203: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3389,7 +3389,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:183: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3421,7 +3421,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:184: Migration of cub::BlockLoad::Load is not supported. 
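Note: the blockwise variants re-estimate the per-block absmax before re-quantizing the updated state; with cub::BlockReduce gone, a core-SYCL work-group reduction covers that step (thread_max, block_absmax and the loop bound are illustrative names):

  float thread_max = 0.0f;
  for (int j = 0; j < N_PER_TH; j++)
    thread_max = sycl::fmax(thread_max, sycl::fabs((float)s1_vals[j]));
  float block_absmax = sycl::reduce_over_group(item_ct1.get_group(), thread_max,
                                               sycl::maximum<float>());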
*/ //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3453,7 +3453,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:185: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3542,7 +3542,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:186: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3590,7 +3590,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:187: Migration of cub::BlockStore::Store is not supported. */ //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -3639,7 +3639,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:188: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -3671,7 +3671,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:189: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -3784,7 +3784,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:197: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3815,7 +3815,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:198: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3846,7 +3846,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:199: Migration of cub::BlockLoad::Load is not supported. 
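Note: after the update, the state is re-quantized against the code book; dQuantize in this file performs that search. A compact sketch of the same nearest-code lookup over an ascending 256-entry table (illustrative, not the exact implementation):

  unsigned char quantize_lookup(const float *code256, float x)
  {
    int lo = 0, hi = 255;
    while (lo < hi)
    {
      const int mid = (lo + hi) >> 1;       // bisect the ascending code book
      if (x > code256[mid]) lo = mid + 1; else hi = mid;
    }
    if (lo > 0 && (x - code256[lo - 1]) < (code256[lo] - x))
      lo--;                                  // snap to the nearer neighbour
    return (unsigned char)lo;
  }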
*/ //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -3971,7 +3971,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:200: Migration of cub::BlockStore::Store is not supported. */ //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -4019,7 +4019,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char DPCT1007:201: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_store = dpct::group::workgroup_store; @@ -4126,7 +4126,7 @@ template(0.0f).convert()[0]); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -4222,7 +4222,7 @@ template; @@ -4376,7 +4376,7 @@ template void kdequant_mm_i DPCT1007:206: Migration of cub::BlockLoad::Load is not supported. */ //LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_load = dpct::group::workgroup_load; @@ -4403,7 +4403,7 @@ template void kdequant_mm_i DPCT1007:207: Migration of cub::BlockExchange::BlockedToWarpStriped is not supported. */ //ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); - dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { + dpct::get_in_order_queue().submit([&](sycl::handler &h) { using group_exchange = dpct::group::exchange; @@ -4522,7 +4522,7 @@ template ; @@ -4579,7 +4579,7 @@ template ; @@ -4619,7 +4619,7 @@ template ; From c9d606bb3c35057dd3ff040e059769cacaa6eee5 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Sat, 30 Mar 2024 00:54:26 -0700 Subject: [PATCH 13/66] integrate kenrels with ops on kquant --- csrc/sycl/kernels.cpp | 953 ++++++--------------- csrc/sycl/ops.cpp | 1860 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 2136 insertions(+), 677 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index afe279646..767ea5b03 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -632,11 +632,15 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons #define THREADS_ESTIMATE 512 #define NUM_ESTIMATE 8 #define BLOCK_ESTIMATE 4096 +typedef sycl::accessor sycl_la_float; +typedef sycl::accessor sycl_la_T; +typedef sycl::accessor sycl_la_unsigned_char; + template SYCL_EXTERNAL -void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1) +void kEstimateQuantiles(const T *buff_A, float *code, const float offset, const T max_val, const int n, + const sycl::nd_item<3> &item_ct1, sycl_la_T tacc) { const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? 
n - (item_ct1.get_group(2)*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; @@ -656,8 +660,7 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset type_ct1 &temp_storage = *(type_ct1 *)temp_storage_ct1; */ int smem_qidx[BLOCK_ESTIMATE]; - sycl::buffer buff_A(A,sycl::range<1>(NUM_ESTIMATE)); - sycl::buffer + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_ESTIMATE) { valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; @@ -673,33 +676,15 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset DPCT1065:76: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(); - /* - DPCT1007:81: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_ESTIMATE); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_A[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], vals); #pragma unroll 4 for(int j = 0; j < NUM_ESTIMATE; j++) @@ -713,32 +698,17 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset // sort into striped pattern to mitigate bank conflicts // striped pattern index for thread 0 [0, 1024, 2048, 3096] // striped pattern index for thread 1 [1, 1025, 2049, 3097] - /* - DPCT1007:82: Migration of cub::BlockRadixSort::SortBlockedToStriped is not supported. - */ + //BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_radix_sort = dpct::group::radix_sort; - size_t temp_storage_size = group_radix_sort::get_local_memory_size(NUM_ESTIMATE); - sycl::local_accessor tacc( - temp_storage_size, h); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *tmp = tacc.get_multi_ptr().get(); - group_radix_sort(tmp).sort_blocked_to_striped(item, vals); - }); - - }); - + + // 1. 
load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + group_radix_sort(tmp).sort_blocked_to_striped(item, vals); + /* DPCT1065:78: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -776,8 +746,8 @@ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset SYCL_EXTERNAL -void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_code) +void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *buff_out, const int n, + const sycl::nd_item<3> &item_ct1, float *smem_code, sycl_la_float ltacc, sycl_la_unsigned_char stacc) { const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*NUM_BLOCK) : NUM_BLOCK; @@ -789,12 +759,6 @@ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, c //typedef cub::BlockLoad LoadFloat; //typedef cub::BlockStore StoreChar; - - sycl::buffer buff_smem_code(smem_code,sycl::range<1>(257)); - sycl::buffer buff_vals(vals, sycl::range<1>(NUM)); - sycl::buffer buff_A(A,sycl::range<1>(NUM)); - sycl::buffer buff_out(out,sycl::range<1>(NUM)); - //__shared__ float smem_code[2][257]; if(item_ct1.get_local_id(2) < 256) @@ -819,30 +783,15 @@ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, c DPCT1065:224: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(sycl::access::fence_space::local_space); - - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_A[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6.
store with byte index + auto *tmp = ltacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], vals); //LoadFloat(loadf).Load(&(A[i]), vals, valid_items); @@ -860,39 +809,23 @@ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, c item_ct1.barrier(sycl::access::fence_space::local_space); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_store; - size_t temp_storage_size = group_store::get_local_memory_size(NUM); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_out[i], h, sycl::read_write); - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, qvals); - }); - - }); - - + //1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_out[0], qvals); //StoreChar(storec).Store(&(out[i]), qvals, valid_items); } } template //__launch_bounds__(TH, 4) -SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, +SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, - float *smem_absmax_value) + float *smem_absmax_value, sycl_la_T ltacc_T, sycl_la_float ltacc_float, sycl_la_unsigned_char stacc) { @@ -937,31 +870,14 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl */ //LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_A[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, vals); - }); - - }); - - + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. 
Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], vals); // 1. compute local max // 2. broadcast local max @@ -1003,31 +919,15 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl //LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_rand[local_rand_idx], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, rand_vals); - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_rand[0], rand_vals); - }); - - } unsigned char packed_4bit = 0; @@ -1071,37 +971,23 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl DPCT1007:89: Migration of cub::BlockStore::Store is not supported. */ - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_out[(DATA_TYPE > 0) ? i/2 : i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, qvals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_out[0], qvals); //StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? 
(valid_items+1)/2 : valid_items); } } template -SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1) +SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, float * absmax, T *buff_out, const int blocksize, const int n, + const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc) { const int n_load = (item_ct1.get_group_range(2) * TILE_SIZE); @@ -1112,16 +998,10 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; - - sycl::buffer buff_out(out, sycl::range<1>(NUM_PER_TH)); - sycl::buffer buff_A(A,sycl::range<1>(NUM_PER_TH)); - //typedef cub::BlockLoad LoadChar; //typedef cub::BlockStore<T, TILE_SIZE/NUM_PER_TH, NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; - - for (unsigned int i = base_idx; i < n_load; i += item_ct1.get_group_range(2)*TILE_SIZE) { @@ -1151,29 +1031,15 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * DPCT1007:93: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_A[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, qvals); - }); - - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], qvals); + @@ -1218,33 +1084,21 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * */ //StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_out[(DATA_TYPE > 0) ? i/2 : i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6.
store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + + auto *tmp = stacc.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_out[0], vals); + } } -SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, +SYCL_EXTERNAL void kDequantize(float *code, unsigned char *buff_A, float *buff_out, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code) { const unsigned int numThreads = item_ct1.get_local_range(2) * item_ct1.get_group_range(2); @@ -1260,7 +1114,7 @@ SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const for (int i = idx;i < n; i += numThreads) { - out[i] = smem_code[A[i]]; + buff_out[i] = smem_code[buff_A[i]]; } } @@ -1271,11 +1125,11 @@ template DPCT1110:1: The total declared local variable size in device function kPreconditionOptimizer32bit2State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. */ SYCL_EXTERNAL -void kPreconditionOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, +void kPreconditionOptimizer32bit2State(T* buff_g, T* p, + float* buff_state1, float* buff_state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1) + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); @@ -1290,9 +1144,6 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, const float correction1 = 1.0f/(1.0f - dpct::pow(beta1, step)); const float correction2 = 1.0f/(1.0f - dpct::pow(beta2, step)); - sycl::buffer buff_g(g, sycl::range<1>(NUM_VALS)); - sycl::buffer buff_state1(state1,sycl::range<1>(NUM_VALS)); - sycl::buffer buff_state2(state2,sycl::range<1>(NUM_VALS)); //typedef cub::BlockLoad Load; @@ -1320,30 +1171,15 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, */ //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. 
store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); - }); - - }); - + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); /* DPCT1065:98: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1353,30 +1189,15 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, DPCT1007:102: Migration of cub::BlockLoad::Load is not supported. */ //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, s1_vals); - }); - - }); - + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); /* @@ -1387,30 +1208,16 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, DPCT1007:103: Migration of cub::BlockLoad::Load is not supported. */ //LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state2[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, s2_vals); - }); - - }); -
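// The rewrite repeated for each of these loads and stores has one shape: the
// nested dpct::get_in_order_queue().submit(...) launches, which cannot run
// inside device code, are deleted, and the kernel instead consumes workgroup
// scratch handed in from the single real launch site. A minimal self-contained
// sketch of that shape in plain SYCL (illustrative names staged_copy/tacc,
// assuming n is a multiple of the work-group size; not the dpct group API):

#include <sycl/sycl.hpp>

void staged_copy(sycl::queue &q, const float *in, float *out, size_t n) {
    q.submit([&](sycl::handler &h) {
        // scratch allocated once at the launch site, not per load/store
        sycl::local_accessor<float, 1> tacc(sycl::range<1>(256), h);
        h.parallel_for(sycl::nd_range<1>(n, 256), [=](sycl::nd_item<1> item) {
            size_t lid = item.get_local_linear_id();
            size_t gid = item.get_global_linear_id();
            tacc[lid] = in[gid];                                  // stage through local memory
            item.barrier(sycl::access::fence_space::local_space); // as in the kernels above
            out[gid] = tacc[lid];
        });
    });
}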
+ + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float2.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], s2_vals); + # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) g_vals[j] = gnorm_scale*((float)g_vals[j]); @@ -1454,11 +1261,12 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, template SYCL_EXTERNAL -void kOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, +void kOptimizer32bit2State(T* buff_g, T* buff_p, + float* buff_state1, float* buff_state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1) + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, + sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -1483,12 +1291,6 @@ void kOptimizer32bit2State(T* g, T* p, } else{ update_scale = 1.0f; } - sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD)); - sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD)); - sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD)); - sycl::buffer buff_state2(state2,sycl::range<1>(NUM_PER_THREAD)); - - //typedef cub::BlockLoad Load; //typedef cub::BlockStore Store; @@ -1516,30 +1318,16 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:111: Migration of cub::BlockLoad::Load is not supported. */ //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6.
store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + /* DPCT1065:105: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1549,30 +1337,14 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:112: Migration of cub::BlockLoad::Load is not supported. */ //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, s1_vals); - }); - - }); - + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); /* DPCT1065:106: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1583,30 +1355,16 @@ void kOptimizer32bit2State(T* g, T* p, */ //LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state2[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, s2_vals); - }); - - }); - + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = ltacc_float2.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], s2_vals); + /* DPCT1065:107: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1615,31 +1373,17 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:114: Migration of cub::BlockLoad::Load is not supported. */ //Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); - }); - - }); - + # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) @@ -1672,29 +1416,15 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:115: Migration of cub::BlockStore::Store is not supported. */ //Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = stacc_T.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); + /* DPCT1065:109: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1704,32 +1434,16 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:116: Migration of cub::BlockStore::Store is not supported. */ //StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, s1_vals); - }); - - }); - - - + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_float1.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); + /* DPCT1065:110: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1738,40 +1452,26 @@ void kOptimizer32bit2State(T* g, T* p, DPCT1007:117: Migration of cub::BlockStore::Store is not supported. */ //StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state2[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, s2_vals); - }); - - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = stacc_float2.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_state2[0], s2_vals); + } } template SYCL_EXTERNAL -void kPreconditionOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, +void kPreconditionOptimizer32bit1State(T* buff_g, T* buff_p, + float* buff_state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1) + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); @@ -1785,11 +1485,8 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, //typedef cub::BlockLoad Load; //typedef cub::BlockLoad LoadFloat; //typedef sycl::group<3> BlockReduce; - sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD)); - sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD)); - sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD)); - sycl::buffer buff_state2(state2,sycl::range<1>(NUM_PER_THREAD)); - /* + + /* union type_ct4{ typename Load::TempStorage load; typename LoadFloat::TempStorage loadf; @@ -1809,30 +1506,16 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, DPCT1007:121: Migration of cub::BlockLoad::Load is not supported. */ //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); - }); - - }); - + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + /* DPCT1065:119: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1842,30 +1525,16 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, DPCT1007:122: Migration of cub::BlockLoad::Load is not supported. */ //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. 
compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, s1_vals); - }); - - }); - + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); + # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) @@ -1922,11 +1591,12 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, template SYCL_EXTERNAL -void kOptimizer32bit1State(T *g, T *p, - float *state1, float *unorm, const float max_unorm, const float param_norm, +void kOptimizer32bit1State(T *buff_g, T *buff_p, + float *buff_state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1) + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, + sycl_la_T stacc_T, sycl_la_float stacc_float1) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -1953,10 +1623,7 @@ void kOptimizer32bit1State(T *g, T *p, //typedef cub::BlockLoad LoadFloat; //typedef cub::BlockStore StoreFloat; - sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD)); - sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD)); - sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD)); - sycl::buffer buff_state2(state2,sycl::range<1>(NUM_PER_THREAD)); + /* union type_ct5{ typename Load::TempStorage load; @@ -1978,30 +1645,16 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:128: Migration of cub::BlockLoad::Load is not supported. */ //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4.
Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + /* DPCT1065:124: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2011,30 +1664,16 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:129: Migration of cub::BlockLoad::Load is not supported. */ //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, s1_vals); - }); - - }); - + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); + /* DPCT1065:125: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2044,29 +1683,16 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:130: Migration of cub::BlockLoad::Load is not supported. */ //Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. 
Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); + # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) { @@ -2114,29 +1740,15 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:131: Migration of cub::BlockStore::Store is not supported. */ //Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_T.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); /* DPCT1065:127: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2146,30 +1758,17 @@ void kOptimizer32bit1State(T *g, T *p, DPCT1007:132: Migration of cub::BlockStore::Store is not supported. */ //StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_store::get_local_memory_size(NUM_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, s1_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = stacc_float1.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); + } } diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index e69de29bb..a07a05ce9 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -0,0 +1,1860 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include "ops.h" +#include "kernels.h" +#include +#include +#include +#include +#include + + +#define ERR_NOT_IMPLEMENTED 100 + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 + + +using namespace BinSearch; +using std::cout; +using std::endl; + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) +{ + int threads = 512; + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + /* + DPCT1049:53: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + dpct::get_in_order_queue().parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kHistogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n, item_ct1); + }); + /* + DPCT1010:229: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + //CUDA_CHECK_RETURN(0); +} + +template void estimateQuantiles(T *A, float *code, float offset, int n) +{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(code, 0, 256*sizeof(float)).wait())); + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + *((T **)&buff_A) = sycl::malloc_device(size, A, ctx); + + q_ct1.memcpy((T*)(buff_A), (T*)(A), NUM_BLOCK); + //sycl::buffer buff_A(A,sycl::range<1>(num_blocks)); + /* + DPCT1049:54: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load; + using group_radix_sort = dpct::group::radix_sort; + size_t sort_temp_storage_size = group_radix_sort::get_local_memory_size(NUM_BLOCK); + sycl::local_accessor tacc(sort_temp_storage_size, cgh); + + /* + DPCT1054:293: The type of variable temp_storage is declared in device function with the name type_ct1. Adjust the code to make the type_ct1 declaration visible at the accessor declaration point. + */ + //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + auto std_numeric_limits_T_max_ct3 = std::numeric_limits::max(); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kEstimateQuantiles(buff_A, code, offset, std_numeric_limits_T_max_ct3, n, item_ct1, tacc); + }); + }); + } + //back memcpy + q_ct1.memcpy((T*)(A), (T*)(buff_A), NUM_BLOCK); + +} +
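// estimateQuantiles above, and every wrapper that follows, stages its host
// pointers through USM device memory: allocate, copy in, launch, copy back.
// A reduced sketch of that round trip (illustrative helper name; note that
// sycl::queue::memcpy counts bytes, so copying n elements takes
// n * sizeof(T) rather than a raw element count):

#include <sycl/sycl.hpp>

template <typename T>
void device_round_trip(sycl::queue &q, T *host, size_t n) {
    T *dev = sycl::malloc_device<T>(n, q);          // staging buffer on the device
    q.memcpy(dev, host, n * sizeof(T)).wait();      // host -> device
    // ... launch the kernel on dev ...
    q.memcpy(host, dev, n * sizeof(T)).wait();      // device -> host ("back memcpy")
    sycl::free(dev, q);
}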
+void quantize(float *code, float *A, unsigned char *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + *((float **)&buff_A) = sycl::malloc_device(size, A, ctx); + *((unsigned char **)&buff_out) = sycl::malloc_device(size, out, ctx); + q_ct1.memcpy((float*)(buff_A), (float*)(A), NUM_BLOCK); + q_ct1.memcpy((unsigned char*)(buff_out), (unsigned char*)(out), NUM_BLOCK); + + /* + DPCT1049:55: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + using group_load = dpct::group::workgroup_load; + size_t load_temp_storage_size = group_load::get_local_memory_size(NUM_BLOCK); + using group_store = dpct::group::workgroup_store; + size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc(load_temp_storage_size, cgh); + sycl::local_accessor stacc(store_temp_storage_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kQuantize(code, buff_A, buff_out, n, item_ct1, ltacc, stacc); + }); + }); + } + //back memcpy + q_ct1.memcpy((float*)(A), (float*)(buff_A), NUM_BLOCK); + q_ct1.memcpy((unsigned char*)(out), (unsigned char*)(buff_out), NUM_BLOCK); + /* + DPCT1010:232: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + //CUDA_CHECK_RETURN(0); +} + +void dequantize(float *code, unsigned char *A, float *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + *((unsigned char **)&buff_A) = sycl::malloc_device(size, A, ctx); + *((float **)&buff_out) = sycl::malloc_device(size, out, ctx); + q_ct1.memcpy((float*)(buff_out), (float*)(out), NUM_BLOCK); + q_ct1.memcpy((unsigned char*)(buff_A), (unsigned char*)(A), NUM_BLOCK); + + /* + DPCT1049:56: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kDequantize(code, buff_A, buff_out, n, item_ct1); + }); + }); + //q_ct1.wait(); + + } + //back memcpy + q_ct1.memcpy((float*)(out), (float*)(buff_out), NUM_BLOCK); + q_ct1.memcpy((unsigned char*)(A), (unsigned char*)(buff_A), NUM_BLOCK); + /* + DPCT1010:233: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + //CUDA_CHECK_RETURN(0); +} +
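// The launch sites above all derive the group-load/store scratch size inside
// the submit, back it with a local_accessor, and pass that accessor straight
// through to the kernel. The same skeleton in plain SYCL, with a stand-in
// sizing function where ops.cpp calls group_load::get_local_memory_size
// (names here are illustrative, not the dpct API):

#include <sycl/sycl.hpp>
#include <cstdint>

static size_t scratch_bytes(int elems_per_group) { return elems_per_group * sizeof(float); }

void launch_with_scratch(sycl::queue &q, const float *A, float *out, int n, int num_blocks) {
    q.submit([&](sycl::handler &cgh) {
        sycl::local_accessor<std::uint8_t, 1> ltacc(sycl::range<1>(scratch_bytes(1024)), cgh);
        cgh.parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024),
                              sycl::range<3>(1, 1, 1024)),
            [=](sycl::nd_item<3> item_ct1) {
                std::uint8_t *tmp = &ltacc[0];  // the kernel body treats this as its scratch space
                (void)tmp; (void)A; (void)out; (void)n;
            });
    });
}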
+  */
+  {
+    dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
+    q_ct1.submit(
+      [&](sycl::handler &cgh) {
+
+        cgh.parallel_for(
+          sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)),
+          [=](sycl::nd_item<3> item_ct1) {
+            kDequantize(code, buff_A, buff_out, n, item_ct1);
+          });
+      });
+    //q_ct1.wait();
+
+  }
+  //back memcpy
+  q_ct1.memcpy((float*)(out), (float*)(buff_out), NUM_BLOCK);
+  q_ct1.memcpy((unsigned char*)(A), (unsigned char*)(buff_A), NUM_BLOCK);
+  /*
+  DPCT1010:233: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
+  */
+  //CUDA_CHECK_RETURN(0);
+}
+
+template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
+{
+  dpct::device_ext &dev_ct1 = dpct::get_current_device();
+  sycl::queue &q_ct1 = dev_ct1.in_order_queue();
+  int num_blocks = n/blocksize;
+  num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
+  sycl::context ctx = q_ct1.get_context();
+  int size = NUM_BLOCK;
+
+  *((T **)&buff_A) = sycl::malloc_device(size, A, ctx);
+  *((unsigned char **)&buff_out) = sycl::malloc_device(size, out, ctx);
+  *((float **)&buff_rand) = sycl::malloc_device(size, rand, ctx);
+  q_ct1.memcpy((T*)(buff_A), (T*)(A), NUM_BLOCK);
+  q_ct1.memcpy((unsigned char*)(buff_out), (unsigned char*)(out), NUM_BLOCK);
+  q_ct1.memcpy((float*)(buff_rand), (float*)(rand), NUM_BLOCK);
+
+  for(int i=0; i< NUM_BLOCK; i++){ buff_out[i] = buff_out[(DATA_TYPE > 0) ? i/2 : i]; }
+
+  if(blocksize == 4096)
+  /*
+  DPCT1049:57: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed.
+  */
+  {
+    dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
+    q_ct1.submit(
+      [&](sycl::handler &cgh) {
+        //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh);
+        //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh);
+        using group_load_T = dpct::group::workgroup_load;
+        size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK);
+        using group_store = dpct::group::workgroup_store;
+        size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK);
+        using group_load_float = dpct::group::workgroup_load;
+        size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK);
+
+        sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh);
+        sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh);
+        sycl::local_accessor stacc(store_temp_storage_size, cgh);
+
+        cgh.parallel_for(
+          sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)),
+          [=](sycl::nd_item<3> item_ct1) {
+            kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc);
+          });
+      });
+  }
+  else if(blocksize == 2048)
+  /*
+  DPCT1049:58: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed.
+ */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + using group_store = dpct::group::workgroup_store; + size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); + using group_load_float = dpct::group::workgroup_load; + size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); + sycl::local_accessor stacc(store_temp_storage_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + }); + }); + } + else if(blocksize == 1024) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + using group_store = dpct::group::workgroup_store; + size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); + using group_load_float = dpct::group::workgroup_load; + size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); + sycl::local_accessor stacc(store_temp_storage_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + }); + }); + } + else if(blocksize == 512) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + using group_store = dpct::group::workgroup_store; + size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); + using group_load_float = dpct::group::workgroup_load; + size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); + sycl::local_accessor stacc(store_temp_storage_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, 
rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + }); + }); + } + else if(blocksize == 256) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + using group_store = dpct::group::workgroup_store; + size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); + using group_load_float = dpct::group::workgroup_load; + size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); + sycl::local_accessor stacc(store_temp_storage_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + }); + }); + } + else if(blocksize == 128) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + using group_store = dpct::group::workgroup_store; + size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); + using group_load_float = dpct::group::workgroup_load; + size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); + sycl::local_accessor stacc(store_temp_storage_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + }); + }); + } + else if(blocksize == 64) + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + using group_store = dpct::group::workgroup_store; + size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); + using group_load_float = dpct::group::workgroup_load; + size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); + sycl::local_accessor stacc(store_temp_storage_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), + 
[=](sycl::nd_item<3> item_ct1) {
+            kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc);
+          });
+      });
+  }
+
+
+  /*
+  DPCT1010:234: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
+  */
+  //back memcpy
+  q_ct1.memcpy((T*)(A), (T*)(buff_A), NUM_BLOCK);
+  q_ct1.memcpy((unsigned char*)(out), (unsigned char*)(buff_out), NUM_BLOCK);
+  q_ct1.memcpy((float*)(rand), (float*)(buff_rand), NUM_BLOCK);
+  //CUDA_CHECK_RETURN(0);
+}
+
+template <typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
+{
+  dpct::device_ext &dev_ct1 = dpct::get_current_device();
+  sycl::queue &q_ct1 = dev_ct1.in_order_queue();
+  int num_blocks = n/blocksize;
+  num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
+  int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
+  sycl::context ctx = q_ct1.get_context();
+
+  *((unsigned char **)&buff_A) = sycl::malloc_device(tile_size, A, ctx);
+  *((T **)&buff_out) = sycl::malloc_device(tile_size, out, ctx);
+  q_ct1.memcpy((unsigned char*)(buff_A), (unsigned char*)(A), tile_size);
+  q_ct1.memcpy((T*)(buff_out), (T*)(out), tile_size);
+
+  if(DATA_TYPE > 0)
+  {
+    dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
+    q_ct1.submit(
+      [&](sycl::handler &cgh){
+
+        using group_load = dpct::group::workgroup_load;
+        using group_store = dpct::group::workgroup_store;
+        size_t store_temp_storage_size = group_store::get_local_memory_size(tile_size);
+        size_t load_temp_storage_size = group_load::get_local_memory_size(tile_size);
+        sycl::local_accessor ltacc(load_temp_storage_size, cgh);
+        sycl::local_accessor stacc(store_temp_storage_size, cgh);
+
+        cgh.parallel_for(
+          sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
+          [=](sycl::nd_item<3> item_ct1) {
+            kDequantizeBlockwise(code, buff_A, absmax, buff_out, blocksize/2, n, item_ct1, ltacc, stacc);
+          });
+      });
+  }
+  else
+  {
+    dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
+    q_ct1.submit(
+      [&](sycl::handler &cgh){
+
+        using group_load = dpct::group::workgroup_load;
+        using group_store = dpct::group::workgroup_store;
+        size_t store_temp_storage_size = group_store::get_local_memory_size(tile_size);
+        size_t load_temp_storage_size = group_load::get_local_memory_size(tile_size);
+        sycl::local_accessor ltacc(load_temp_storage_size, cgh);
+        sycl::local_accessor stacc(store_temp_storage_size, cgh);
+
+        cgh.parallel_for(
+          sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
+          [=](sycl::nd_item<3> item_ct1) {
+            kDequantizeBlockwise(code, buff_A, absmax, buff_out, blocksize, n, item_ct1, ltacc, stacc);
+          });
+      });
+  }
+
+  /*
+  DPCT1010:235: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
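+
+  A hedged sketch of the exception-based handling this warning asks for,
+  mirroring the function-level try/catch blocks used elsewhere in this file:
+
+    try {
+      q_ct1.wait_and_throw(); // surfaces asynchronous errors as sycl::exception
+    } catch (const sycl::exception &e) {
+      std::cerr << e.what() << std::endl;
+    }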
+  */
+  //back memcpy
+  q_ct1.memcpy((unsigned char*)(A), (unsigned char*)(buff_A), tile_size);
+  q_ct1.memcpy((T*)(out), (T*)(buff_out), tile_size);
+
+  //CUDA_CHECK_RETURN(0);
+}
+
+
+//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
+//{
+//  int num_blocks = (colsB+32-1)/32;
+//  kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
+//  CUDA_CHECK_RETURN(cudaPeekAtLastError());
+//}
+
+
+template <typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
+                float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
+                const float beta1, const float beta2, const float eps, const float weight_decay,
+                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
+ try {
+  dpct::device_ext &dev_ct1 = dpct::get_current_device();
+  sycl::queue &q_ct1 = dev_ct1.in_order_queue();
+  sycl::context ctx = q_ct1.get_context();
+  int num_blocks = n/4096;
+  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  int size = NUM_BLOCK;
+
+  *((T **)&buff_g) = sycl::malloc_device(size, g, ctx);
+  *((T **)&buff_p) = sycl::malloc_device(size, p, ctx);
+  *((float **)&buff_state1) = sycl::malloc_device(size, state1, ctx);
+  *((float **)&buff_state2) = sycl::malloc_device(size, state2, ctx);
+  q_ct1.memcpy((T*)(buff_g), (T*)(g), size);
+  q_ct1.memcpy((T*)(buff_p), (T*)(p), size);
+  q_ct1.memcpy((float*)(buff_state1), (float*)(state1), size);
+  q_ct1.memcpy((float*)(buff_state2), (float*)(state2), size);
+
+
+  switch(OPTIMIZER)
+  {
+    case ADAM:
+      if(max_unorm > 0.0f)
+      {
+        CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()));
+        /*
+        DPCT1049:61: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+          dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
+          q_ct1.submit(
+            [&](sycl::handler &cgh) {
+
+              using group_load_T = dpct::group::workgroup_load;
+              using group_load_float1 = dpct::group::workgroup_load;
+              using group_load_float2 = dpct::group::workgroup_load;
+              size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK);
+              size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK);
+              size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK);
+              sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh);
+              sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh);
+              sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh);
+
+
+              /*
+              DPCT1054:294: The type of variable temp_storage is declared in device function with the name type_ct2. Adjust the code to make the type_ct2 declaration visible at the accessor declaration point.
+              */
+              cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
+                [=](sycl::nd_item<3> item_ct1) {
+                  kPreconditionOptimizer32bit2State(buff_g, buff_p, buff_state1, buff_state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1, ltacc_float2);
+                });
+            });
+        }
+        /*
+        DPCT1010:236: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
+        */
+        //CUDA_CHECK_RETURN(0);
+      }
+      /*
+      DPCT1049:59: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size.
Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + /* + DPCT1054:295: The type of variable temp_storage is declared in device function with the name type_ct3. Adjust the code to make the type_ct3 declaration visible at the accessor declaration point. + */ + + using group_load_T = dpct::group::workgroup_load; + using group_load_T1 = dpct::group::workgroup_load; + using group_load_float1 = dpct::group::workgroup_load; + using group_load_float2 = dpct::group::workgroup_load; + size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); + sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); + + + using group_store_T = dpct::group::workgroup_store; + using group_store_float1 = dpct::group::workgroup_store; + using group_store_float2 = dpct::group::workgroup_store; + size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_float2 = group_store_float2::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); + sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + sycl::local_accessor stacc_float2(store_temp_storage_size_float2, cgh); + + + + //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizer32bit2State(buff_g, buff_p, buff_state1, buff_state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2,stacc_T, stacc_float1, stacc_float2); + }); + }); + } + /* + DPCT1010:237: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + //CUDA_CHECK_RETURN(0); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); + /* + DPCT1049:62: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + /* + DPCT1054:296: The type of variable temp_storage is declared in device function with the name type_ct4. Adjust the code to make the type_ct4 declaration visible at the accessor declaration point. 
+              */
+
+              using group_load_T = dpct::group::workgroup_load;
+              using group_load_float1 = dpct::group::workgroup_load;
+              size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK);
+              size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK);
+
+              sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh);
+              sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh);
+
+
+              //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh);
+
+              cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
+                [=](sycl::nd_item<3> item_ct1) {
+                  kPreconditionOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1);
+                });
+            });
+        }
+        /*
+        DPCT1010:238: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
+        */
+        //CUDA_CHECK_RETURN(0);
+      }
+
+      /*
+      DPCT1049:60: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed.
+      */
+      {
+        dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
+        q_ct1.submit(
+          [&](sycl::handler &cgh) {
+            /*
+            DPCT1054:297: The type of variable temp_storage is declared in device function with the name type_ct5. Adjust the code to make the type_ct5 declaration visible at the accessor declaration point.
+            */
+            using group_load_T = dpct::group::workgroup_load;
+            using group_load_T1 = dpct::group::workgroup_load;
+            using group_load_float1 = dpct::group::workgroup_load;
+            size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK);
+            size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK);
+            size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK);
+
+            sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh);
+            sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh);
+            sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh);
+
+
+            using group_store_T = dpct::group::workgroup_store;
+            using group_store_float1 = dpct::group::workgroup_store;
+            size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK);
+            size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK);
+
+            sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh);
+            sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh);
+
+
+
+            //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh);
+
+            cgh.parallel_for(
+              sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)),
+              [=](sycl::nd_item<3> item_ct1) {
+                kOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1);
+              });
+          });
+      }
+      /*
+      DPCT1010:239: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
+      */
+      CUDA_CHECK_RETURN(0);
+      break;
+    case LION:
+      // in lion, the momentum update happens after the parameter update
+      /*
+      DPCT1049:63: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed.
+ */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + /* + DPCT1054:298: The type of variable temp_storage is declared in device function with the name type_ct5. Adjust the code to make the type_ct5 declaration visible at the accessor declaration point. + */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:240: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); + /* + DPCT1049:64: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + /* + DPCT1054:299: The type of variable temp_storage is declared in device function with the name type_ct4. Adjust the code to make the type_ct4 declaration visible at the accessor declaration point. + */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:241: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + } + break; + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n) + try { + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); } + + switch(OPTIMIZER) + { + case ADAM: + CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); + CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait())); + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + /* + DPCT1054:300: The type of variable temp_storage is declared in device function with the name type_ct6. 
Adjust the code to make the type_ct6 declaration visible at the accessor declaration point. + */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizerStatic8bit2State(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer(), smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:242: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + /* + DPCT1049:65: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + /* + DPCT1054:301: The type of variable temp_storage is declared in device function with the name type_ct7. Adjust the code to make the type_ct7 declaration visible at the accessor declaration point. + */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit2State(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:243: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + /* + DPCT1054:302: The type of variable temp_storage is declared in device function with the name type_ct8. Adjust the code to make the type_ct8 declaration visible at the accessor declaration point. + */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer(), smem_quantiles1_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:244: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. 
You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + /* + DPCT1049:66: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + /* + DPCT1054:303: The type of variable temp_storage is declared in device function with the name type_ct9. Adjust the code to make the type_ct9 declaration visible at the accessor declaration point. + */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:245: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + break; + case LION: + // in lion, the momentum update happens after the parameter update + /* + DPCT1049:67: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + /* + DPCT1054:304: The type of variable temp_storage is declared in device function with the name type_ct9. Adjust the code to make the type_ct9 declaration visible at the accessor declaration point. + */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:246: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + + CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + /* + DPCT1054:305: The type of variable temp_storage is declared in device function with the name type_ct8. Adjust the code to make the type_ct8 declaration visible at the accessor declaration point. 
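+
+  One hedged alternative, following the pattern the 32-bit paths in this file
+  already use: compute the byte size with the workgroup helper and declare a
+  byte-typed accessor, so the device-side type never has to be visible here
+  (the <uint8_t, 1> element type and the helper's applicability to this kernel
+  are assumptions):
+
+    size_t temp_bytes = dpct::group::workgroup_load::get_local_memory_size(NUM_BLOCK);
+    sycl::local_accessor<uint8_t, 1> temp_storage_acc(sycl::range<1>(temp_bytes), cgh);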
+ */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer(), smem_quantiles1_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:247: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + break; + default: + break; + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +#define BLOCKSIZE_2STATE 2048 +#define NUM_2STATE 8 +#define BLOCKSIZE_1STATE 2048 +#define NUM_1STATE 8 + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) + try { + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + + int num_blocks = 0; + switch(OPTIMIZER) + { + case ADAM: + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + /* + DPCT1101:306: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. + */ + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + /* + DPCT1101:307: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. + */ + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); + sycl::local_accessor smem_exchange2_acc_ct1(sycl::range<1>(1), cgh); + /* + DPCT1054:308: The type of variable temp_storage is declared in device function with the name type_ct10. Adjust the code to make the type_ct10 declaration visible at the accessor declaration point. + */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE), sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit2StateBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_quantiles2_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), smem_exchange2_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:248: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
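+
+  A hedged sketch of one rewrite, reusing the DPCT_CHECK_ERROR/CUDA_CHECK_RETURN
+  pattern already used in this file (cgf stands for the command-group lambda
+  submitted above and is named only for illustration):
+
+    CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.submit(cgf).wait()));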
+ */ + CUDA_CHECK_RETURN(0); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + /* + DPCT1101:309: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. + */ + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); + /* + DPCT1054:310: The type of variable temp_storage is declared in device function with the name type_ct11. Adjust the code to make the type_ct11 declaration visible at the accessor declaration point. + */ + sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), + [=](sycl::nd_item<3> item_ct1) { + kOptimizerStatic8bit1StateBlockwise(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:249: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + break; + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + + + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait())); + /* + DPCT1049:68: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kPercentileClipping(g, gnorm_vec, step, n, item_ct1); + }); + } + /* + DPCT1010:250: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); +} + +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) + try { + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + int status; + + status = DPCT_CHECK_ERROR(dpct::gemm(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? 
oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, B, dpct::library_data_t::real_int8, ldb, beta, C, dpct::library_data_t::real_int32, ldc, dpct::library_data_t::real_int32)); + + if (status != 0) + { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } + +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) + try { + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + int status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + + status = DPCT_CHECK_ERROR(dpct::gemm_batch(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, (long long int)strideA, B, dpct::library_data_t::real_int8, ldb, (long long int)strideB, beta, C, dpct::library_data_t::real_int32, ldc, (long long int)strideC, batchCount, dpct::library_data_t::real_int32)); + + if (status != 0) + { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } + +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + + +#ifdef NO_CUBLASLT +#else +template cublasLtOrder_t get_order() +{ + switch(ORDER) + { + case ROW: + return CUBLASLT_ORDER_ROW; + break; + case COL: + return CUBLASLT_ORDER_COL; + break; + case COL32: + return CUBLASLT_ORDER_COL32; + break; + case COL_TURING: + return CUBLASLT_ORDER_COL4_4R2_8C; + break; + case COL_AMPERE: + return CUBLASLT_ORDER_COL32_2R_4R4; + break; + default: + break; + } + + return CUBLASLT_ORDER_ROW; +} + +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +#endif + + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + default: + return 0; + break; + } +} + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); + +template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) +{ +#ifdef NO_CUBLASLT +#else + cublasLtOrder_t orderA = get_order(); + cublasLtOrder_t orderOut = get_order(); + int ldA = get_leading_dim(dim1, dim2); + int ldOut = get_leading_dim(dim1, dim2); + + cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL; + cublasLtMatrixTransformDesc_t A2Out_desc = NULL; + oneapi::mkl::transpose opTranspose = oneapi::mkl::transpose::trans; + float 
transformAlpha = 1.0f, transformBeta = 0.0f; + + + if(DTYPE == 8) + { + /* + DPCT1007:251: Migration of cublasLtMatrixLayoutCreate is not supported. + */ + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, dpct::library_data_t::real_int8, dim1, dim2, ldA)); + /* + DPCT1007:252: Migration of cublasLtMatrixLayoutCreate is not supported. + */ + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, dpct::library_data_t::real_int8, dim1, dim2, ldOut)); + } + else if(DTYPE == 32) + { + /* + DPCT1007:253: Migration of cublasLtMatrixLayoutCreate is not supported. + */ + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, dpct::library_data_t::real_int32, dim1, dim2, ldA)); + /* + DPCT1007:254: Migration of cublasLtMatrixLayoutCreate is not supported. + */ + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, dpct::library_data_t::real_int32, dim1, dim2, ldOut)); + } + else + { + printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); + } + + /* + DPCT1007:255: Migration of cublasLtMatrixLayoutSetAttribute is not supported. + */ + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); + /* + DPCT1007:256: Migration of cublasLtMatrixLayoutSetAttribute is not supported. + */ + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); + + /* + DPCT1007:257: Migration of cublasLtMatrixTransformDescCreate is not supported. + */ + checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, dpct::library_data_t::real_float)); + + /* + DPCT1007:258: Migration of cublasLtMatrixTransformDescSetAttribute is not supported. + */ + if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } + + checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0)); + + /* + DPCT1007:259: Migration of cublasLtMatrixLayoutDestroy is not supported. + */ + if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); + /* + DPCT1007:260: Migration of cublasLtMatrixLayoutDestroy is not supported. + */ + if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); + /* + DPCT1007:261: Migration of cublasLtMatrixTransformDescDestroy is not supported. 
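+
+  These cublasLt* calls have no automatic SYCL migration. A hedged structural
+  sketch (not the original design): isolate the unmigrated calls behind one
+  small stub so the SYCL build fails loudly in a single place; the wrapper
+  name is hypothetical:
+
+    static inline int ltTransformUnsupported() {
+      printf("ERROR: cublasLt transform path not available in the SYCL build\n");
+      return ERR_NOT_IMPLEMENTED; // defined at the top of this file
+    }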
+ */ + if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif +} + +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + try { + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); +#ifdef NO_CUBLASLT + return ERR_NOT_IMPLEMENTED; +#else + int has_error = 0; + cublasLtMatmulDesc_t matmulDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + oneapi::mkl::transpose opT = oneapi::mkl::transpose::trans; + cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; + cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; + + /* + DPCT1007:262: Migration of cublasLtMatrixLayoutCreate is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, dpct::library_data_t::real_int8, m, k, lda)); + /* + DPCT1007:263: Migration of cublasLtMatrixLayoutCreate is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, dpct::library_data_t::real_int8, n, k, ldb)); + + /* + DPCT1007:264: Migration of cublasLtMatrixLayoutSetAttribute is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(FORMATB == COL_TURING) + /* + DPCT1007:265: Migration of cublasLtMatrixLayoutSetAttribute is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); + else + /* + DPCT1007:266: Migration of cublasLtMatrixLayoutSetAttribute is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); + + if(DTYPE_OUT == 32) + { + /* + DPCT1007:267: Migration of cublasLtMatmulDescCreate is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, dpct::library_data_t::real_int32)); + /* + DPCT1007:268: Migration of cublasLtMatmulDescSetAttribute is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + /* + DPCT1007:269: Migration of cublasLtMatrixLayoutCreate is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, dpct::library_data_t::real_int32, m, n, ldc)); + /* + DPCT1007:270: Migration of cublasLtMatrixLayoutSetAttribute is not supported. 
+ */ + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + int alpha = 1, beta = 0; + /* + DPCT1007:271: Migration of cublasLtMatmul is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); + } + else + { + /* + DPCT1007:272: Migration of cublasLtMatmulDescCreate is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, dpct::library_data_t::real_float)); + /* + DPCT1007:273: Migration of cublasLtMatmulDescSetAttribute is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + /* + DPCT1007:274: Migration of cublasLtMatrixLayoutCreate is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, dpct::library_data_t::real_int8, m, n, ldc)); + /* + DPCT1007:275: Migration of cublasLtMatrixLayoutSetAttribute is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(!SCALE_ROWS) + { + float alpha = 1.0f, beta = 0.0f; + /* + DPCT1007:276: Migration of cublasLtMatmul is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); + } + else + { + /* + DPCT1007:277: Migration of cublasLtMatmulDescSetAttribute is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); + /* + DPCT1007:278: Migration of cublasLtMatmul is not supported. + */ + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); + } + } + + + /* + DPCT1007:279: Migration of cublasLtMatrixLayoutDestroy is not supported. + */ + if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); + /* + DPCT1007:280: Migration of cublasLtMatrixLayoutDestroy is not supported. + */ + if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); + /* + DPCT1007:281: Migration of cublasLtMatrixLayoutDestroy is not supported. + */ + if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); + /* + DPCT1007:282: Migration of cublasLtMatmulDescDestroy is not supported. + */ + if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); + if(has_error == 1) + printf("error detected"); + + return has_error; +#endif // NO_CUBLASLT +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 
0 : (multiple - (value % multiple))); +} + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *bias, int numRows, int numCols) +{ + int threads = 512; + int tileCols = fill_up_to_nearest_multiple(numCols, 32); + int n = numRows*tileCols; + int subtile_rows = 128; + int tilesize = 32*subtile_rows; + int num_blocks = numRows/subtile_rows; + num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; + num_blocks = num_blocks*(tileCols/32); + assert(threads <= tilesize); + + kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); + /* + DPCT1010:283: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); +} + +#define STATS_THREADS 64 +#define STATS_ITEMS 4 +#define STATS_ROWS 16 +void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) +{ + int tile_cols = STATS_THREADS*STATS_ITEMS; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + if(nnz_threshold == 0.0) + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + else if(nnz_threshold != 0.0) + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + /* + DPCT1010:284: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); + +} + +void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int *nnz_block_ptr, float threshold, int rows, int cols) +{ + int threads = 64; + int items_per_thread = 4; + int tile_cols = threads*items_per_thread; + int tile_rows = 16; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + + if(threshold > 0.0f) + kDoubleRowColQuant<64, 4, 16, 64*4, 1><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + else + kDoubleRowColQuant<64, 4, 16, 64*4, 0><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + + /* + DPCT1010:285: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
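+
+  Note that kdequant_mm_int32_fp16, kgetColRowStats and kDoubleRowColQuant
+  above are still written as CUDA <<<grid, block>>> launches, which will not
+  compile as SYCL. A hedged sketch of the equivalent submission, following the
+  pattern used throughout this file (the trailing item_ct1 parameter is an
+  assumption about the migrated kernel signature):
+
+    dpct::get_in_order_queue().parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads),
+                          sycl::range<3>(1, 1, threads)),
+        [=](sycl::nd_item<3> item_ct1) {
+          kDoubleRowColQuant<64, 4, 16, 64*4, 1>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, item_ct1);
+        });
+
+  Any shared memory the kernel needs would additionally have to be passed as a
+  sycl::local_accessor, as the other submits in this file do.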
+ */ + CUDA_CHECK_RETURN(0); +} + +template void transformRowToFormat(char * A, char *out, int rows, int cols) +{ + int threads = 256; + int items_per_thread = 8; + // we load 128 column values per warp + int tile_cols = 32*items_per_thread; + int tile_rows = 32; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + int outCols = fill_up_to_nearest_multiple(cols, 32); + int outRows = fill_up_to_nearest_multiple(rows, 32); + if(FORMAT == COL_TURING) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 8); + else + outRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 32); + else + outRows = fill_up_to_nearest_multiple(rows, 32); + } + else + { + if(TRANSPOSE) + { + outCols = fill_up_to_nearest_multiple(rows, 32); + outRows = cols; + } + } + + /* + DPCT1049:69: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + sycl::local_accessor smem_data_acc_ct1(sycl::range<1>(32*33*8), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), + [=](sycl::nd_item<3> item_ct1) { + kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT>(A, out, rows, cols, tiledCols, outRows, outCols, item_ct1, smem_data_acc_ct1.get_pointer()); + }); + }); + /* + DPCT1010:286: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); +} + +void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, sycl::half *B, int ldc, sycl::half* C, bool transposed_B) +{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + +#ifdef NO_CUBLASLT +#else + + dpct::sparse::sparse_matrix_desc_t descA; + std::shared_ptr descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + /* + DPCT1007:287: Migration of cusparseCreateCoo is not supported. 
+ */
+    CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
+                                      A_rowidx, A_colidx, A_vals,
+                                      dpct::library_data_t::real_int32,
+                                      oneapi::mkl::index_base::zero, dpct::library_data_t::real_half) );
+    // Create dense matrix C
+    CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) );
+    // Create dense matrix B
+    if(transposed_B)
+    {
+      int tmp = A_cols;
+      A_cols = B_cols;
+      B_cols = tmp;
+    }
+
+    CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) );
+    // allocate an external buffer if needed
+    CHECK_CUSPARSE( DPCT_CHECK_ERROR(bufferSize = 0) );
+    CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1)) );
+
+    // execute SpMM
+    CHECK_CUSPARSE( DPCT_CHECK_ERROR(dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float)));
+
+    // destroy matrix/vector descriptors
+    CHECK_CUSPARSE( DPCT_CHECK_ERROR(descA.reset()) );
+    CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB.reset()) );
+    CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC.reset()) );
+    CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(sycl::free(dBuffer, q_ct1)) );
+#endif
+}
+
+template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{
+
+  {
+    dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16});
+    dpct::get_in_order_queue().submit(
+      [&](sycl::handler &cgh) {
+        /*
+        DPCT1101:311: 'SMEM_SIZE' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct.
+        */
+        sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh);
+
+        cgh.parallel_for(
+          sycl::nd_range<3>(sycl::range<3>(1, 1, nnz_rows) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
+          [=](sycl::nd_item<3> item_ct1) {
+            kspmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB, item_ct1, smem_dequant_stats_acc_ct1.get_pointer());
+          });
+      });
+  }
+  /*
+  DPCT1010:289: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
+  */
+  CUDA_CHECK_RETURN(0);
+}
+
+
+template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
+{
+  int threads = 256;
+  // we load 128 column values per warp
+  int tiledCols = fill_up_to_nearest_multiple(cols, 32);
+  int tiledRows = 0;
+
+  int num_blocks = idx_size;
+
+  if(FORMAT == COL_TURING)
+  {
+      tiledRows = fill_up_to_nearest_multiple(rows, 8);
+  }
+  else if(FORMAT == COL_AMPERE)
+  {
+      tiledRows = fill_up_to_nearest_multiple(rows, 32);
+  }
+
+  /*
+  DPCT1049:70: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
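+  A minimal sketch of the suggested guard; the clamp itself is an assumption
+  and not part of the original launch logic:
+
+      size_t max_wg = dpct::get_in_order_queue().get_device()
+                          .get_info<sycl::info::device::max_work_group_size>();
+      if ((size_t)threads > max_wg) threads = (int)max_wg;
+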
+ */ + dpct::get_in_order_queue().parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), + [=](sycl::nd_item<3> item_ct1) { + kExtractOutliers(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols, item_ct1); + }); + /* + DPCT1010:290: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); +} + + + + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + //if(bits == 32) + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + /* + DPCT1101:312: '8*16 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. + */ + sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); + /* + DPCT1101:313: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. + */ + sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 160), sycl::range<3>(1, 1, 160)), + [=](sycl::nd_item<3> item_ct1) { + gemm_device(m, n, k, A, B, out, lda, ldb, ldc, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer()); + }); + }); + } + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + /* + DPCT1101:314: '8*16 + (16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. + */ + sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(176/*8*16 + (16*(batch_size_warps-1))*/), cgh); + /* + DPCT1101:315: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. 
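+  A hedged sketch of restoring the symbolic sizes: batch_size_warps == 4 is
+  the value that reproduces the replaced constants 176 and 4192, so the
+  extents could be recomputed instead of hard-coded:
+
+      constexpr int batch_size_warps = 4;  // assumed from the kernel configuration
+      constexpr int smem_A_len = 8*16 + (16*(batch_size_warps-1));                       // = 176
+      constexpr int smem_B_len = 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)); // = 4192
+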
+ */ + sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); + sycl::local_accessor smem_C_acc_ct1(sycl::range<1>(8*32), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 96), sycl::range<3>(1, 1, 96)), + [=](sycl::nd_item<3> item_ct1) { + kgemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), smem_C_acc_ct1.get_pointer()); + }); + }); + } + //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+3)/4; + + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + sycl::local_accessor quant_map_acc_ct1(sycl::range<1>(16), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + kgemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, item_ct1, quant_map_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:291: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + CUDA_CHECK_RETURN(0); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + /* + DPCT1049:71: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + dpct::get_in_order_queue().parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kfunc(A, B, value, n, item_ct1); + }); + /* + DPCT1010:292: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
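+  Worked example of the launch arithmetic above (illustrative only): for
+  n = 2^25 and threads = 512, n/threads is 65536, which the clamp reduces
+  to 65535 work-groups, so kfunc has to cover the tail with a grid-sized
+  stride, e.g.:
+
+      for (long i = item_ct1.get_global_linear_id(); i < n;
+           i += (long)item_ct1.get_global_range(2))
+      {
+          // apply the elementwise op to A[i], B[i] and value here
+      }
+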
+ */ + CUDA_CHECK_RETURN(0); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +template void 
estimateQuantiles(half *A, float *code, float offset, int n); +template void estimateQuantiles(float *A, float *code, float offset, int n); + +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); + +#define MAKE_optimizer32bit(name, gtype) \ +template void optimizer32bit(gtype* g, gtype* p, \ + float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const 
int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +MAKE_optimizer32bit(ADAM, half) +MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, __nv_bfloat16) +MAKE_optimizer32bit(MOMENTUM, half) +MAKE_optimizer32bit(MOMENTUM, float) +MAKE_optimizer32bit(RMSPROP, half) +MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, __nv_bfloat16) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ +template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, int n); \ + +MAKE_optimizerStatic8bit(ADAM, half) +MAKE_optimizerStatic8bit(ADAM, float) +MAKE_optimizerStatic8bit(MOMENTUM, half) +MAKE_optimizerStatic8bit(MOMENTUM, float) +MAKE_optimizerStatic8bit(RMSPROP, half) +MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ +template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + +MAKE_optimizerStatic8bitBlockwise(half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); + +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); From adda60de573e0fb3209296a1118ddce1408d9be4 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Sun, 31 Mar 2024 21:25:03 -0700 Subject: [PATCH 14/66] add 8 bit opt with ops --- csrc/sycl/kernels.cpp | 110 ++++++++++++------------------------------ csrc/sycl/ops.cpp | 97 ++++++++++++++++++++++++++++--------- 2 files changed, 106 insertions(+), 101 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 767ea5b03..3030b062d 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -1783,15 +1783,15 @@ DPCT1110:6: The total declared local variable size in device function kPrecondit */ SYCL_EXTERNAL void -kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, +kPreconditionOptimizerStatic8bit2State(T* buff_p, T* __restrict__ const buff_g, unsigned char*__restrict__ const buff_state1, unsigned char* __restrict__ const buff_state2, float *unorm, const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const 
quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, - float *smem_quantiles1, float *smem_quantiles2) + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, + sycl_la_float ltacc_float1, sycl_la_float ltacc_float2) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); @@ -1810,11 +1810,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c //typedef cub::BlockLoad LoadT; //typedef cub::BlockLoad LoadUInt8; //typedef sycl::group<3> BlockReduce; - sycl::buffer buff_g(g, sycl::range<1>(NUM_THREADS)); - sycl::buffer buff_p(p, sycl::range<1>(NUM_THREADS)); - sycl::buffer buff_state1(state1,sycl::range<1>(NUM_THREADS)); - sycl::buffer buff_state2(state2,sycl::range<1>(NUM_THREADS)); - + /* union type_ct6{ typename LoadT::TempStorage loadh; @@ -1845,29 +1841,16 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c DPCT1007:156: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + /* DPCT1065:153: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -1876,30 +1859,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c DPCT1007:157: Migration of cub::BlockLoad::Load is not supported. */ //LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. 
Repeat (3) 8 times for top 8 values in 256
-      // 6. store with byte index
-      h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
-         [=](sycl::nd_item<3> item) {
-           auto *d = dacc.get_multi_ptr().get();
-           auto *tmp = tacc.get_multi_ptr().get();
-           group_load(tmp).load(item,item.get_local_linear_id(), d, m_c1);
-         });
-
-      });
-
+
+      // 1. load 8 values per thread
+      // 2. compute 2-max in registers (64 max per warp)
+      // 3. do warp reduction + broadcast back
+      // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+      // 5. Repeat (3) 8 times for top 8 values in 256
+      // 6. store with byte index
+      auto *tmp = ltacc_float1.get_multi_ptr().get();
+      group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], m_c1);
 
    /*
    DPCT1065:154: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
    */
@@ -1909,29 +1877,15 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
    DPCT1007:158: Migration of cub::BlockLoad::Load is not supported.
    */
    //LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128);
-    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
-
-
-      using group_load = dpct::group::workgroup_load;
-      size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS);
-      sycl::local_accessor tacc(
-        temp_storage_size, h);
-      sycl::accessor dacc(buff_state2[i], h, sycl::read_write);
-
-      // 1. load 8 values per thread
-      // 2. compute 2-max in registers (64 max per warp)
-      // 3. do warp reduction + broadcast back
-      // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
-      // 5. Repeat (3) 8 times for top 8 values in 256
-      // 6. store with byte index
-      h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
-         [=](sycl::nd_item<3> item) {
-           auto *d = dacc.get_multi_ptr().get();
-           auto *tmp = tacc.get_multi_ptr().get();
-           group_load(tmp).load(item,item.get_local_linear_id(), d, r_c2);
-         });
-
-      });
+
+      // 1. load 8 values per thread
+      // 2. compute 2-max in registers (64 max per warp)
+      // 3. do warp reduction + broadcast back
+      // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+      // 5. Repeat (3) 8 times for top 8 values in 256
+      // 6. store with byte index
+      auto *tmp = ltacc_float2.get_multi_ptr().get();
+      group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], r_c2);
 
    /*
    DPCT1065:155: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
@@ -2894,7 +2848,7 @@ DPCT1110:10: The total declared local variable size in device function kOptimize
 */
 SYCL_EXTERNAL 
 void
-kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
+kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, unsigned char* buff_state1, unsigned char* buff_state2,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp
index a07a05ce9..f361257fa 100644
--- a/csrc/sycl/ops.cpp
+++ b/csrc/sycl/ops.cpp
@@ -671,7 +671,7 @@ template void optimizer32bit(T* g, T* p,
 /*
 DPCT1010:239: SYCL uses exceptions to report errors and does not use the error codes. 
The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); break; case LION: // in lion, the momentum update after the parameter update @@ -685,19 +685,38 @@ template void optimizer32bit(T* g, T* p, /* DPCT1054:298: The type of variable temp_storage is declared in device function with the name type_ct5. Adjust the code to make the type_ct5 declaration visible at the accessor declaration point. */ - sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - + //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + using group_load_T = dpct::group::workgroup_load; + using group_load_T1 = dpct::group::workgroup_load; + using group_load_float1 = dpct::group::workgroup_load; + size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); + sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + + + using group_store_T = dpct::group::workgroup_store; + using group_store_float1 = dpct::group::workgroup_store; + size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); + sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer()); + kOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, stacc_T,stacc_float1); }); }); } /* DPCT1010:240: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); if(max_unorm > 0.0f) { @@ -712,19 +731,26 @@ template void optimizer32bit(T* g, T* p, /* DPCT1054:299: The type of variable temp_storage is declared in device function with the name type_ct4. Adjust the code to make the type_ct4 declaration visible at the accessor declaration point. 
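 One hedged alternative, had the temp_storage accessor been kept, is to hoist
 the storage type out of the device function so the host side can name it
 (the type and member below are assumptions for illustration):
 
     struct type_ct4 { uint8_t bytes[512]; };   // moved to namespace scope
     sycl::local_accessor<type_ct4, 0> temp_storage_acc(cgh);
 
 The series instead sidesteps the issue by passing sized workgroup_load
 scratch accessors, as the surrounding added lines show.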
 */
-        sycl::local_accessor temp_storage_ct1_acc_ct1(cgh);
-
+        //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh);
+        using group_load_T = dpct::group::workgroup_load;
+        using group_load_float1 = dpct::group::workgroup_load;
+        size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK);
+        size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK);
+
+        sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh);
+        sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh);
+
         cgh.parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
            [=](sycl::nd_item<3> item_ct1) {
-              kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer());
+              kPreconditionOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1);
            });
         });
      }
      /*
      DPCT1010:241: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
      */
-      CUDA_CHECK_RETURN(0);
+      //CUDA_CHECK_RETURN(0);
    }
    break;
  }
@@ -734,8 +760,8 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-template void optimizerStatic8bit(T* p, T* g,
-                unsigned char* state1, unsigned char* state2,
+template void optimizerStatic8bit(T* buff_p, T* buff_g,
+                unsigned char* buff_state1, unsigned char* buff_state2,
                 float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2,
                 float eps, int step, float lr,
@@ -744,10 +770,21 @@ template void optimizerStatic8bit(T* p, T* g,
                 float weight_decay,
                 const float gnorm_scale, int n)
 try {
-  dpct::device_ext &dev_ct1 = dpct::get_current_device();
-  sycl::queue &q_ct1 = dev_ct1.in_order_queue();
+  dpct::device_ext &dev_ct1 = dpct::get_current_device();
+  sycl::queue &q_ct1 = dev_ct1.in_order_queue();
   int num_blocks = n/4096;
   num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  sycl::context ctx = q_ct1.get_context();
+
+  *((T **)&buff_g) = sycl::malloc_device(size, g, ctx);
+  *((T **)&buff_p) = sycl::malloc_device(size, p, ctx);
+  *((float **)&buff_state1) = sycl::malloc_device(size, state1, ctx);
+  *((float **)&buff_state2) = sycl::malloc_device(size, state2, ctx);
+  q_ct1.memcpy((T*)(buff_g), (T*)(g), size);
+  q_ct1.memcpy((T*)(buff_p), (T*)(p), size);
+  q_ct1.memcpy((float*)(buff_state1), (float*)(state1), size);
+  q_ct1.memcpy((float*)(buff_state2), (float*)(state2), size);
+
   if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); }
 
@@ -763,21 +800,33 @@ template void optimizerStatic8bit(T* p, T* g,
        /*
        DPCT1054:300: The type of variable temp_storage is declared in device function with the name type_ct6. Adjust the code to make the type_ct6 declaration visible at the accessor declaration point. 
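 A hedged aside on the staging buffers introduced above: this work-in-progress
 patch allocates and fills buff_g/buff_p/buff_state1/buff_state2 but never
 copies the results back or frees them; a later revision would need roughly
 (sketch only, names taken from the surrounding code):
 
     q_ct1.memcpy((T*)p, (T*)buff_p, size).wait();
     q_ct1.memcpy(state1, buff_state1, size).wait();
     q_ct1.memcpy(state2, buff_state2, size).wait();
     sycl::free(buff_g, ctx); sycl::free(buff_p, ctx);
     sycl::free(buff_state1, ctx); sycl::free(buff_state2, ctx);
 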
*/ - sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); - sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + + using group_load_T = dpct::group::workgroup_load; + using group_load_float1 = dpct::group::workgroup_load; + using group_load_float2 = dpct::group::workgroup_load; + size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit2State(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer(), smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer()); + kPreconditionOptimizerStatic8bit2State(buff_p, buff_g, buff_state1, buff_state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1, ltacc_float2); }); }); } /* DPCT1010:242: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); /* DPCT1049:65: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ @@ -802,7 +851,7 @@ template void optimizerStatic8bit(T* p, T* g, /* DPCT1010:243: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); break; case MOMENTUM: case RMSPROP: @@ -828,7 +877,7 @@ template void optimizerStatic8bit(T* p, T* g, /* DPCT1010:244: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); /* DPCT1049:66: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ @@ -852,7 +901,7 @@ template void optimizerStatic8bit(T* p, T* g, /* DPCT1010:245: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); break; case LION: // in lion, the momentum update happens after the parameter update @@ -879,7 +928,7 @@ template void optimizerStatic8bit(T* p, T* g, /* DPCT1010:246: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
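 Since optimizerStatic8bit is already wrapped in a function-level try/catch
 on sycl::exception (see the catch block above), one hedged replacement for
 the dropped error-code check is simply draining the queue so asynchronous
 failures surface there:
 
     q_ct1.wait_and_throw();
 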
*/ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); { @@ -902,7 +951,7 @@ template void optimizerStatic8bit(T* p, T* g, /* DPCT1010:247: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); break; default: break; @@ -942,6 +991,8 @@ template void optimizerStatic8bitBlockwise(T* p, T* g /* DPCT1101:307: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ + + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); sycl::local_accessor smem_exchange2_acc_ct1(sycl::range<1>(1), cgh); From beffe78c67d54f5ecf658552746e44488e4b8c80 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Sun, 31 Mar 2024 23:35:23 -0700 Subject: [PATCH 15/66] add 8 bit 2 state adam kernel --- csrc/sycl/kernels.cpp | 246 ++++++++++++------------------------------ csrc/sycl/ops.cpp | 36 ++++++- 2 files changed, 102 insertions(+), 180 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 3030b062d..4a3959d3c 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -1973,7 +1973,7 @@ kPreconditionOptimizerStatic8bit2State(T* buff_p, T* __restrict__ const buff_g, template SYCL_EXTERNAL void -kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, +kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state1, unsigned char* buff_state2, const float *unorm, const float max_unorm, const float param_norm, \ const float beta1, const float beta2, const float eps, const int step, const float lr, @@ -2012,10 +2012,6 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha T p_vals[NUM_PER_THREAD2]; T g_vals[NUM_PER_THREAD2]; - sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD2)); - sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD2)); - sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD2)); - sycl::buffer buff_state2(state2,sycl::range<1>(NUM_PER_THREAD2)); //typedef cub::BlockLoad LoadT; //typedef cub::BlockLoad LoadChar; @@ -2314,7 +2310,7 @@ DPCT1110:3: The total declared local variable size in device function kPrecondit */ SYCL_EXTERNAL void -kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, +kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, unsigned char*__restrict__ const buff_state1, float *unorm, const float beta1, const float beta2, const float eps, const int step, @@ -2856,9 +2852,9 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, - sycl::local_accessor smem_quantiles1, - sycl::local_accessor smem_quantiles2, - float *smem_exchange1, float *smem_exchange2) + sycl_la_T ltacc_T, sycl_la_T ltacc_T1, + sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, + sycl_la_T stacc_T, sycl_la_float1 stacc_float1, sycl_la_float stacc_float2) { //const int n_full = n + (n%BLOCK_SIZE); @@ -2883,12 +2879,6 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; - sycl::buffer buff_g(g, 
sycl::range<1>(N_PER_TH)); - sycl::buffer buff_p(p, sycl::range<1>(N_PER_TH)); - sycl::buffer buff_state1(state1,sycl::range<1>(N_PER_TH)); - sycl::buffer buff_state2(state2,sycl::range<1>(N_PER_TH)); - - //typedef cub::BlockLoad LoadT; //typedef cub::BlockLoad LoadChar; @@ -2942,29 +2932,15 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:183: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); /* DPCT1065:176: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2974,29 +2950,15 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:184: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, c1s); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = ltacc_float1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], c1s); /* DPCT1065:177: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -3006,29 +2968,16 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:185: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state2[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, c2s); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float2.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], c2s); + new_local_abs_max1 = -FLT_MAX; new_local_abs_max2 = -FLT_MAX; @@ -3095,31 +3044,15 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:186: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); - }); - - }); - - + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = ltacc_T1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); // reduce: 2.67/1.69 -> 2.67/1.70 # pragma unroll N_PER_TH @@ -3143,29 +3076,16 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:187: Migration of cub::BlockStore::Store is not supported. */ //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_T.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); + // quantizaztion: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) @@ -3192,29 +3112,16 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:188: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, c1s); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = stacc_float1.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], c1s); + /* DPCT1065:182: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -3224,29 +3131,16 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:189: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state2[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, c2s); - }); - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_float2.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_state2[0], c2s); + } } diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index f361257fa..2dd6ffae2 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -775,6 +775,7 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ int num_blocks = n/4096; num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; *((T **)&buff_g) = sycl::malloc_device(size, g, ctx); *((T **)&buff_p) = sycl::malloc_device(size, p, ctx); @@ -834,17 +835,44 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); - sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); /* DPCT1054:301: The type of variable temp_storage is declared in device function with the name type_ct7. Adjust the code to make the type_ct7 declaration visible at the accessor declaration point. 
*/ - sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + using group_load_T = dpct::group::workgroup_load; + using group_load_T1 = dpct::group::workgroup_load; + using group_load_float1 = dpct::group::workgroup_load; + using group_load_float2 = dpct::group::workgroup_load; + size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); + sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); + + + using group_store_T = dpct::group::workgroup_store; + using group_store_float1 = dpct::group::workgroup_store; + using group_store_float2 = dpct::group::workgroup_store; + size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_float2 = group_store_float2::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); + sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + sycl::local_accessor stacc_float2(store_temp_storage_size_float2, cgh); + + //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit2State(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer()); + kOptimizerStatic8bit2State(buff_p, buff_g, buff_state1, buff_state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2, stacc_T, stacc_float1, stacc_float2); }); }); } From ed431fefb89e7c099c5ccb5f0809cb6b7457ae89 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 1 Apr 2024 00:06:12 -0700 Subject: [PATCH 16/66] 8 bit optimizer ops integrate --- csrc/sycl/kernels.cpp | 245 +++++++++++++----------------------------- csrc/sycl/ops.cpp | 97 ++++++++++++++--- 2 files changed, 156 insertions(+), 186 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 4a3959d3c..963f70ceb 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -2319,7 +2319,7 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, const float weight_decay, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, - float *smem_quantiles1) + sycl_la_T ltacc_T, sycl_la_float ltacc_float1) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); @@ -2332,9 +2332,6 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, T g_vals[NUM8BIT]; unsigned char m_c1[NUM8BIT]; - sycl::buffer buff_g(g, 
sycl::range<1>(NUM_THREADS)); - sycl::buffer buff_p(p, sycl::range<1>(NUM_THREADS)); - sycl::buffer buff_state1(state1,sycl::range<1>(NUM_THREADS)); //typedef cub::BlockLoad LoadT; //typedef cub::BlockLoad LoadUInt8; @@ -2370,29 +2367,16 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, DPCT1007:137: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); - }); - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + /* DPCT1065:136: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2401,30 +2385,16 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, DPCT1007:138: Migration of cub::BlockLoad::Load is not supported. */ //LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_THREADS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, m_c1); - }); - - }); - + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = ltacc_float1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], m_c1); + #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) { @@ -2482,7 +2452,7 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, template SYCL_EXTERNAL void -kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, +kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state1, const float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, @@ -2490,7 +2460,9 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_quantiles1) + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, + sycl_la_float ltacc_float1, sycl_la_T stacc_T, + sycl_la_float stacc_float1) { const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2; @@ -2513,10 +2485,6 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, T p_vals[NUM_PER_THREAD2]; T g_vals[NUM_PER_THREAD2]; - sycl::buffer buff_g(g, sycl::range<1>(NUM_PER_THREAD2)); - sycl::buffer buff_p(p, sycl::range<1>(NUM_PER_THREAD2)); - sycl::buffer buff_state1(state1,sycl::range<1>(NUM_PER_THREAD2)); - //typedef cub::BlockLoad LoadT; //typedef cub::BlockLoad LoadChar; @@ -2549,29 +2517,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:145: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); /* @@ -2582,30 +2536,17 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:146: Migration of cub::BlockLoad::Load is not supported. 
*/ //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, c1s); - }); - - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], c1s); + + /* DPCT1065:142: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2614,29 +2555,16 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:147: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); + if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } # pragma unroll 4 @@ -2695,29 +2623,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:148: Migration of cub::BlockStore::Store is not supported. 
*/ //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_T.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); /* DPCT1065:143: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2727,29 +2641,16 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, DPCT1007:149: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, c1s); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_float1.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], c1s); + /* DPCT1065:144: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
*/ @@ -3153,7 +3054,7 @@ DPCT1110:11: The total declared local variable size in device function kOptimize */ SYCL_EXTERNAL void -kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, +kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, unsigned char* buff_state1, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 2dd6ffae2..e47012454 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -754,6 +754,13 @@ template void optimizer32bit(T* g, T* p, } break; } + + //back memcpy + q_ct1.memcpy((T*)(g), (T*)(buff_g), size); + q_ct1.memcpy((T*)(p), (T*)(buff_p), size); + q_ct1.memcpy((float*)(state1), (float*)(buff_state1), size); + q_ct1.memcpy((float*)(state2), (float*)(buff_state2), size); + } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; @@ -892,13 +899,20 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ /* DPCT1054:302: The type of variable temp_storage is declared in device function with the name type_ct8. Adjust the code to make the type_ct8 declaration visible at the accessor declaration point. */ - sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); - + //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + using group_load_T = dpct::group::workgroup_load; + using group_load_float1 = dpct::group::workgroup_load; + size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer(), smem_quantiles1_acc_ct1.get_pointer()); + kPreconditionOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1); }); }); } @@ -913,16 +927,37 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); /* DPCT1054:303: The type of variable temp_storage is declared in device function with the name type_ct9. Adjust the code to make the type_ct9 declaration visible at the accessor declaration point. 
*/ - sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + + using group_load_T = dpct::group::workgroup_load; + using group_load_T1 = dpct::group::workgroup_load; + using group_load_float1 = dpct::group::workgroup_load; + size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); + sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + + + using group_store_T = dpct::group::workgroup_store; + using group_store_float1 = dpct::group::workgroup_store; + size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); + sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer()); + kOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1,ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); }); }); } @@ -940,16 +975,36 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); /* DPCT1054:304: The type of variable temp_storage is declared in device function with the name type_ct9. Adjust the code to make the type_ct9 declaration visible at the accessor declaration point. 
*/
-        sycl::local_accessor temp_storage_ct1_acc_ct1(cgh);
+        //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh);
+
+        using group_load_T = dpct::group::workgroup_load;
+        using group_load_T1 = dpct::group::workgroup_load;
+        using group_load_float1 = dpct::group::workgroup_load;
+        size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK);
+        size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK);
+        size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK);
+        sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh);
+        sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh);
+        sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh);
+
+
+        using group_store_T = dpct::group::workgroup_store;
+        using group_store_float1 = dpct::group::workgroup_store;
+        size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK);
+        size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK);
+
+        sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh);
+        sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh);
+
         cgh.parallel_for(
           sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)),
           [=](sycl::nd_item<3> item_ct1) {
-              kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer());
+              kOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1);
           });
       });
     }
@@ -966,13 +1021,20 @@ template void optimizerStatic8bit(T* buff_p, T* buff_
 /*
 DPCT1054:305: The type of variable temp_storage is declared in device function with the name type_ct8. Adjust the code to make the type_ct8 declaration visible at the accessor declaration point.
*/
-        sycl::local_accessor temp_storage_ct1_acc_ct1(cgh);
-        sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh);
-
+        //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh);
+        //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh);
+        using group_load_T = dpct::group::workgroup_load;
+        using group_load_float1 = dpct::group::workgroup_load;
+        size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK);
+        size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK);
+
+        sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh);
+        sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh);
+
         cgh.parallel_for(
           sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
           [=](sycl::nd_item<3> item_ct1) {
-              kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, temp_storage_ct1_acc_ct1.get_pointer(), smem_quantiles1_acc_ct1.get_pointer());
+              kPreconditionOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1);
           });
       });
     }
@@ -984,6 +1046,13 @@ template void optimizerStatic8bit(T* buff_p, T* buff_
       default:
         break;
     }
+
+  //back memcpy
+  q_ct1.memcpy((T*)(g), (T*)(buff_g), size);
+  q_ct1.memcpy((T*)(p), (T*)(buff_p), size);
+  q_ct1.memcpy((float*)(state1), (float*)(buff_state1), size);
+  q_ct1.memcpy((float*)(state2), (float*)(buff_state2), size);
+
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;

From 7ab327341a97c665c9005193e8072a8f616849d8 Mon Sep 17 00:00:00 2001
From: abhilash1910
Date: Mon, 1 Apr 2024 01:36:32 -0700
Subject: [PATCH 17/66] add all opt & quant kernels

---
 csrc/sycl/kernels.cpp | 285 +++++++++++++-----------------------------
 csrc/sycl/ops.cpp     | 164 ++++++++++++++++++++----
 2 files changed, 232 insertions(+), 217 deletions(-)

diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp
index 963f70ceb..cd7e684c1 100644
--- a/csrc/sycl/kernels.cpp
+++ b/csrc/sycl/kernels.cpp
@@ -635,6 +635,8 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons
 typedef sycl::accessor sycl_la_float;
 typedef sycl::accessor sycl_la_T;
 typedef sycl::accessor sycl_la_unsigned_char;
+typedef sycl::accessor sycl_la_half;
+typedef sycl::accessor sycl_la_unsigned;

 template
@@ -2660,8 +2662,8 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state

 template
-SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,
-                                       const sycl::nd_item<3> &item_ct1)
+SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ buff_g, float *gnorm_vec, int step, const int n,
+                                       const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T)
 {
   const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
   int valid_items = 0;

   //typedef cub::BlockLoad LoadT;
-  sycl::buffer buff_g(g, sycl::range<1>(NUM_VALS));
+  //sycl::buffer buff_g(g, sycl::range<1>(NUM_VALS));

   T vals[NUM_VALS];

@@ -2688,29 +2690,16 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int
 /*
 DPCT1007:203: Migration of cub::BlockLoad::Load is not supported.
*/ //LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_VALS); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], vals); + #pragma unroll NUM_VALS for(int j = 0; j < NUM_VALS; j++) local_sum += ((float)vals[j])*((float)vals[j]); @@ -3062,8 +3051,10 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, - sycl::local_accessor smem_quantiles1, - float *smem_exchange1) + sycl_la_T ltacc_T, sycl_la_T ltacc_T1, + sycl_la_float ltacc_float1,sycl_la_T stacc_T, + sycl_la_float stacc_float1 + ) { //const int n_full = n + (n%BLOCK_SIZE); @@ -3086,13 +3077,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns //typedef cub::BlockStore StoreChar; //typedef cub::BlockStore StoreT; - - - sycl::buffer buff_g(g, sycl::range<1>(N_PER_TH)); - sycl::buffer buff_p(p, sycl::range<1>(N_PER_TH)); - sycl::buffer buff_state1(state1,sycl::range<1>(N_PER_TH)); - sycl::buffer buff_state2(state2,sycl::range<1>(N_PER_TH)); - + /* union type_ct11{ typename LoadT::TempStorage loadh; @@ -3132,29 +3117,16 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:197: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_g[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); - }); - }); + // 1. load 8 values per thread + // 2. 
compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + /* DPCT1065:192: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -3163,29 +3135,16 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:198: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, c1s); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], c1s); + /* DPCT1065:193: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -3194,29 +3153,16 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:199: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. 
do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); + new_local_abs_max1 = -FLT_MAX; // update: 2.48/1.57 -> 2.51/1.60 @@ -3319,29 +3265,16 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:200: Migration of cub::BlockStore::Store is not supported. */ //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_T.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); + // quantizaztion: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) @@ -3367,34 +3300,20 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns DPCT1007:201: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(N_PER_TH); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, c1s); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = stacc_float1.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], c1s); } } -template void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values) +template void kgetColRowStats(T * __restrict__ buff_A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, sycl_la_half ltacc_half, sycl_la_unsigned exacc) { // 0. reset stats to -FLT_MAX // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) @@ -3425,9 +3344,7 @@ template buff_A(A, sycl::range<1>(ITEMS_PER_THREAD)); - sycl::half local_data[ITEMS_PER_THREAD]; float local_data_fp32[ITEMS_PER_THREAD]; float local_col_absmax_values[ITEMS_PER_THREAD]; @@ -3474,29 +3391,16 @@ template(0.0f).convert()[0]); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(ITEMS_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_A[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, local_data); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_half.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], local_data); + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_data[j] = sycl::fabs(local_data[j]); @@ -3570,27 +3474,16 @@ template; - size_t temp_storage_size = group_exchange::get_local_memory_size(ITEMS_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *tmp = tacc.get_multi_ptr().get(); - group_exchange(tmp).blocked_to_striped(item, local_col_absmax_values); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index
+    auto *tmp = exacc.get_multi_ptr().get();
+    group_exchange(tmp).blocked_to_striped(item, local_col_absmax_values);
+
     #pragma unroll ITEMS_PER_THREAD
     for(int j = 0; j < ITEMS_PER_THREAD; j++)
       if(base_col+item_ct1.get_local_id(2)+(j*THREADS) < cols)
diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp
index e47012454..24b671a5b 100644
--- a/csrc/sycl/ops.cpp
+++ b/csrc/sycl/ops.cpp
@@ -1064,14 +1064,27 @@ catch (sycl::exception const &exc) {
 #define BLOCKSIZE_1STATE 2048
 #define NUM_1STATE 8

-template void optimizerStatic8bitBlockwise(T* p, T* g,
-                unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
+template void optimizerStatic8bitBlockwise(T* buff_p, T* buff_g,
+                unsigned char* buff_state1, unsigned char* buff_state2, float beta1, float beta2, float eps, int step, float lr,
                 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) try {
-  dpct::device_ext &dev_ct1 = dpct::get_current_device();
-  sycl::queue &q_ct1 = dev_ct1.in_order_queue();
-
+
+  dpct::device_ext &dev_ct1 = dpct::get_current_device();
+  sycl::queue &q_ct1 = dev_ct1.in_order_queue();
+  sycl::context ctx = q_ct1.get_context();
   int num_blocks = 0;
+  int size = NUM_BLOCK;
+
+  *((T **)&buff_g) = sycl::malloc_device(size, g, ctx);
+  *((T **)&buff_p) = sycl::malloc_device(size, p, ctx);
+  *((float **)&buff_state1) = sycl::malloc_device(size, state1, ctx);
+  *((float **)&buff_state2) = sycl::malloc_device(size, state2, ctx);
+  q_ct1.memcpy((T*)(buff_g), (T*)(g), size);
+  q_ct1.memcpy((T*)(buff_p), (T*)(p), size);
+  q_ct1.memcpy((float*)(buff_state1), (float*)(state1), size);
+  q_ct1.memcpy((float*)(buff_state2), (float*)(state2), size);
+
+
   switch(OPTIMIZER)
   {
     case ADAM:
@@ -1090,25 +1103,50 @@ template void optimizerStatic8bitBlockwise(T* p, T* g
 */
-        sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh);
-        sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh);
-        sycl::local_accessor smem_exchange2_acc_ct1(sycl::range<1>(1), cgh);
+        //sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh);
+        //sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh);
+        //sycl::local_accessor smem_exchange2_acc_ct1(sycl::range<1>(1), cgh);
 /*
 DPCT1054:308: The type of variable temp_storage is declared in device function with the name type_ct10. Adjust the code to make the type_ct10 declaration visible at the accessor declaration point.
*/ - sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + using group_load_T = dpct::group::workgroup_load; + using group_load_T1 = dpct::group::workgroup_load; + using group_load_float1 = dpct::group::workgroup_load; + using group_load_float2 = dpct::group::workgroup_load; + size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); + sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); + + + using group_store_T = dpct::group::workgroup_store; + using group_store_float1 = dpct::group::workgroup_store; + using group_store_float2 = dpct::group::workgroup_store; + size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_float2 = group_store_float2::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); + sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + sycl::local_accessor stacc_float2(store_temp_storage_size_float2, cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE), sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit2StateBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_quantiles2_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), smem_exchange2_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer()); + kOptimizerStatic8bit2StateBlockwise(buff_p, buff_g, buff_state1, buff_state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2, stacc_T, stacc_float1, stacc_float2); }); }); } /* DPCT1010:248: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); break; case MOMENTUM: case RMSPROP: @@ -1123,26 +1161,50 @@ template void optimizerStatic8bitBlockwise(T* p, T* g /* DPCT1101:309: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ - sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); - sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); + //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + //sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); /* DPCT1054:310: The type of variable temp_storage is declared in device function with the name type_ct11. Adjust the code to make the type_ct11 declaration visible at the accessor declaration point. 
*/ - sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + using group_load_T = dpct::group::workgroup_load; + using group_load_T1 = dpct::group::workgroup_load; + using group_load_float1 = dpct::group::workgroup_load; + size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); + sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + + + using group_store_T = dpct::group::workgroup_store; + using group_store_float1 = dpct::group::workgroup_store; + size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); + sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1StateBlockwise(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), temp_storage_ct1_acc_ct1.get_pointer()); + kOptimizerStatic8bit1StateBlockwise(buff_p, buff_g, buff_state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); }); }); } /* DPCT1010:249: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); break; } + q_ct1.memcpy((T*)(g), (T*)(buff_g), size); + q_ct1.memcpy((T*)(p), (T*)(buff_p), size); + q_ct1.memcpy((float*)(state1), (float*)(buff_state1), size); + q_ct1.memcpy((float*)(state2), (float*)(buff_state2), size); + } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; @@ -1155,8 +1217,15 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, { dpct::device_ext &dev_ct1 = dpct::get_current_device(); sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int num_blocks = n/2048; num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + int size = NUM_BLOCK; + *((T **)&buff_g) = sycl::malloc_device(size, g, ctx); + q_ct1.memcpy((T*)(buff_g), (T*)(g), size); + + CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait())); /* DPCT1049:68: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
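The two hunks below repeat the staging recipe used for the optimizers above: copy the caller's pointer into a device-side shadow buffer, let a single q_ct1.submit scope own the group scratch as a sycl::local_accessor, launch the kernel with the shadow buffer plus the accessors, and memcpy the result back afterwards. The following is a minimal, self-contained sketch of that recipe in plain SYCL; the doubling kernel, GROUP_SIZE, and the float element type are illustrative assumptions, the byte-based copy sizing is part of the sketch rather than taken from the patch, and the real code sizes its scratch through dpct::group::workgroup_load::get_local_memory_size instead of a fixed range.

#include <sycl/sycl.hpp>
#include <vector>
#include <cstdio>

constexpr int GROUP_SIZE = 512; // illustrative work-group size

int main()
{
  sycl::queue q{sycl::property::queue::in_order()};
  const int n = 4096;
  std::vector<float> host(n, 1.0f);

  // 1. stage the caller's data in a device-side shadow buffer
  float *buff_g = sycl::malloc_device<float>(n, q);
  q.memcpy(buff_g, host.data(), n * sizeof(float));

  // 2. one handler owns the local scratch; the kernel only consumes it
  q.submit([&](sycl::handler &cgh) {
    sycl::local_accessor<float, 1> scratch(sycl::range<1>(GROUP_SIZE), cgh);
    cgh.parallel_for(
        sycl::nd_range<1>(n, GROUP_SIZE),
        [=](sycl::nd_item<1> it) {
          size_t lid = it.get_local_linear_id();
          scratch[lid] = buff_g[it.get_global_linear_id()]; // cooperative staging
          it.barrier(sycl::access::fence_space::local_space);
          buff_g[it.get_global_linear_id()] = scratch[lid] * 2.0f; // toy workload
        });
  });

  // 3. "back memcpy": return the results to the caller's memory
  q.memcpy(host.data(), buff_g, n * sizeof(float)).wait();
  sycl::free(buff_g, q);
  printf("host[0] = %f\n", host[0]); // 2.0
  return 0;
}

Because the port uses dev_ct1.in_order_queue() throughout, the memcpy-in, kernel, memcpy-out sequence needs no explicit event dependencies.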
@@ -1164,15 +1233,20 @@ template void percentileClipping(T * g, float *gnorm_vec, int step,
   {
     dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
-    q_ct1.parallel_for(
-      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
-      [=](sycl::nd_item<3> item_ct1) {
-        kPercentileClipping(g, gnorm_vec, step, n, item_ct1);
-      });
+    q_ct1.submit(
+      [&](sycl::handler &cgh) {
+        using group_load_T = dpct::group::workgroup_load;
+        size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK);
+        sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh);
+
+        cgh.parallel_for(
+          sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
+          [=](sycl::nd_item<3> item_ct1) {
+            kPercentileClipping(buff_g, gnorm_vec, step, n, item_ct1, ltacc_T);
+          });
+      });
   }
 /*
 DPCT1010:250: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
 */
-  CUDA_CHECK_RETURN(0);
+  //CUDA_CHECK_RETURN(0);
 }

 void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
@@ -1543,6 +1617,12 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half
 #define STATS_ROWS 16
 void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
 {
+
+  dpct::device_ext &dev_ct1 = dpct::get_current_device();
+  sycl::queue &q_ct1 = dev_ct1.in_order_queue();
+  sycl::context ctx = q_ct1.get_context();
+
+
   int tile_cols = STATS_THREADS*STATS_ITEMS;
   int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
   int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
@@ -1551,15 +1631,57 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c
   row_tiles = row_tiles > 0 ? row_tiles : 1;
   col_tiles = col_tiles > 0 ? col_tiles : 1;
   int num_blocks = row_tiles * col_tiles;
+
+  int size = NUM_BLOCK;
+  *((sycl::half **)&buff_A) = sycl::malloc_device(size, A, ctx);
+  q_ct1.memcpy((sycl::half*)(buff_A), (sycl::half*)(A), size);
+
   if(nnz_threshold == 0.0)
-    kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
+  {
+    dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
+    q_ct1.submit(
+      [&](sycl::handler &cgh) {
+        using group_load_half = dpct::group::workgroup_load;
+        using group_exchange = dpct::group::exchange;
+
+        size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK);
+        size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK);
+
+        sycl::local_accessor exacc(exchange_temp_storage_size, cgh);
+        sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh);
+
+        cgh.parallel_for(
+          sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
+          [=](sycl::nd_item<3> item_ct1) {
+            kgetColRowStats(buff_A, rowStats, colStats,
+              nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols, item_ct1, ltacc_half, exacc);
+          });
+      });
+  }
   else if(nnz_threshold != 0.0)
-    kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
+  {
+    dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
+    q_ct1.submit(
+      [&](sycl::handler &cgh) {
+        using group_load_half = dpct::group::workgroup_load;
+        using group_exchange = dpct::group::exchange;
+
+        size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK);
+        size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK);
+
+        sycl::local_accessor exacc(exchange_temp_storage_size, cgh);
+        sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh);
+
+        cgh.parallel_for(
+          sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
+          [=](sycl::nd_item<3> item_ct1) {
+            kgetColRowStats(buff_A, rowStats, colStats,
+              nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols, item_ct1, ltacc_half, exacc);
+          });
+      });
+  }
+
 /*
 DPCT1010:284: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
 */
-  CUDA_CHECK_RETURN(0);
+  //CUDA_CHECK_RETURN(0);
 }

From 43a18999bcab1a59b8993f01aa9911749766b653 Mon Sep 17 00:00:00 2001
From: abhilash1910
Date: Mon, 1 Apr 2024 04:47:28 -0700
Subject: [PATCH 18/66] add shared vars

---
 csrc/sycl/kernels.cpp | 252 ++++++++++++++----------------------------
 csrc/sycl/ops.cpp     | 158 +++++++++++++++++---------
 2 files changed, 193 insertions(+), 217 deletions(-)

diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp
index cd7e684c1..a356ae439 100644
--- a/csrc/sycl/kernels.cpp
+++ b/csrc/sycl/kernels.cpp
@@ -749,7 +749,7 @@ void kEstimateQuantiles(const T *buff_A, float *code, const float offset, const

 SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *buff_out, const int n,
-                             const sycl::nd_item<3> &item_ct1, float *smem_code, sycl_la_float ltacc, sycl_la_unsigned_char stacc)
+                             const sycl::nd_item<3> &item_ct1, float* smem_code, sycl_la_float ltacc, sycl_la_unsigned_char stacc)
 {
   const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
   int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*NUM_BLOCK) : NUM_BLOCK;
@@ -1792,7 +1792,9 @@ kPreconditionOptimizerStatic8bit2State(T* buff_p, T* __restrict__ const buff_g,
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                 float* max1, float* max2, float* new_max1, float* new_max2,
                 const float gnorm_scale, const int n,
-                const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T,
+                const sycl::nd_item<3> &item_ct1,
+                float* smem_quantiles1, float* smem_quantiles2,
+                sycl_la_T ltacc_T,
                 sycl_la_float ltacc_float1, sycl_la_float ltacc_float2)
 {
   const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK;
@@ -1983,8 +1985,11 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state
                 float* max1, float* max2, float* new_max1, float* new_max2,
                 float weight_decay,
                 const float gnorm_scale, const int n,
-                const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,
-                float *smem_quantiles2)
+                const sycl::nd_item<3> &item_ct1, float* smem_quantiles1, float* smem_quantiles2,
+                sycl_la_T ltacc_T, sycl_la_T ltacc_T1,
+                sycl_la_float ltacc_float1, sycl_la_float ltacc_float2,
+                sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2
+                )
 {
   const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2;
@@ -2049,29 +2054,16 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state
 DPCT1007:167: Migration of cub::BlockLoad::Load is not supported.
 */
 //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
-    dpct::get_in_order_queue().submit([&](sycl::handler &h) {
-
-
-      using group_load = dpct::group::workgroup_load;
-      size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2);
-      sycl::local_accessor tacc(
-        temp_storage_size, h);
-      sycl::accessor dacc(buff_g[i], h, sycl::read_write);
-
-      // 1. load 8 values per thread
-      // 2. compute 2-max in registers (64 max per warp)
-      // 3. 
do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, g_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + /* DPCT1065:161: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2081,29 +2073,16 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state DPCT1007:168: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, c1s); - }); - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], c1s); + /* DPCT1065:162: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2113,29 +2092,16 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state DPCT1007:169: Migration of cub::BlockLoad::Load is not supported. */ //LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state2[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. 
Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, c2s); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_float2.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], c2s); + /* DPCT1065:163: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2145,29 +2111,16 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state DPCT1007:170: Migration of cub::BlockLoad::Load is not supported. */ //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_T1.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); + if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } # pragma unroll 4 @@ -2210,30 +2163,17 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state DPCT1007:171: Migration of cub::BlockStore::Store is not supported. */ //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_p[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. 
store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, p_vals); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_T.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); + /* DPCT1065:164: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2242,29 +2182,16 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state DPCT1007:172: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state1[i], h, sycl::read_write); - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, c1s); - }); - - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_float1.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], c1s); + /* DPCT1065:165: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2274,30 +2201,17 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state DPCT1007:173: Migration of cub::BlockStore::Store is not supported. */ //StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_store = dpct::group::workgroup_store; - size_t temp_storage_size = group_load::get_local_memory_size(NUM_PER_THREAD2); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_state2[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. 
store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, c2s); - }); - - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_float2.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_state2[0], c2s); + + /* DPCT1065:166: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ @@ -2321,6 +2235,7 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, const float weight_decay, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, + float* smem_quantiles1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; @@ -2742,6 +2657,9 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + sycl::local_accessor smem_quantiles2, + float *smem_exchange1, float *smem_exchange2, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, sycl_la_T stacc_T, sycl_la_float1 stacc_float1, sycl_la_float stacc_float2) @@ -3051,6 +2969,8 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, + sycl::local_accessor smem_quantiles1, + float *smem_exchange1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1,sycl_la_T stacc_T, sycl_la_float stacc_float1 @@ -3509,12 +3429,10 @@ template(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, - float *smem_row_absmax_values, - int *smem_row_nnz_values); + sycl_la_half ltacc_half, sycl_la_unsigned exacc); template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, - float *smem_row_absmax_values, - int *smem_row_nnz_values); + sycl_la_half ltacc_half, sycl_la_unsigned exacc); #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 24b671a5b..33d46dac9 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -127,10 +127,13 @@ void quantize(float *code, float *A, unsigned char *out, int n) sycl::local_accessor ltacc(load_temp_storage_size, cgh); sycl::local_accessor stacc(store_temp_storage_size, cgh); + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kQuantize(code, buff_A, buff_out, n, item_ct1, 
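// A sketch of the refactor pattern applied in the hunks above: the migrated
// cub::BlockLoad/BlockStore call sites no longer open a nested
// dpct::get_in_order_queue().submit(...) per load or store. Instead, the temp-storage
// local_accessor is allocated once in the enclosing command group and the
// dpct::group::workgroup_load helper runs inside the already-launched kernel. A
// minimal, hypothetical illustration (the names `q` and `src` and the template
// arguments are assumptions; this patch text elides all template argument lists):
//
//   q.submit([&](sycl::handler &cgh) {
//     using group_load = dpct::group::workgroup_load;   // template args elided, as in this patch
//     sycl::local_accessor tmp(group_load::get_local_memory_size(256), cgh);
//     cgh.parallel_for(
//       sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)),
//       [=](sycl::nd_item<3> item) {
//         float vals[NUM_PER_THREAD2];
//         auto *d = tmp.get_multi_ptr().get();
//         group_load(d).load(item, item.get_local_linear_id(), src, vals);  // one cooperative load, no nested submit
//       });
//   });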
ltacc, stacc); + kQuantize(code, buff_A, buff_out, n, item_ct1, smem_code_acc_ct1.get_pointer(), ltacc, stacc); }); }); } @@ -165,10 +168,13 @@ void dequantize(float *code, unsigned char *A, float *out, int n) q_ct1.submit( [&](sycl::handler &cgh) { + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kDequantize(code, buff_A, buff_out, n, item_ct1); + kDequantize(code, buff_A, buff_out, n, item_ct1, smem_code_acc_ct1.get_pointer()); }); }); //q_ct1.wait(); @@ -209,8 +215,6 @@ template void quantizeBlockwise(floa dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); - //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); using group_load_T = dpct::group::workgroup_load; size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); using group_store = dpct::group::workgroup_store; @@ -221,11 +225,16 @@ template void quantizeBlockwise(floa sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); sycl::local_accessor stacc(store_temp_storage_size, cgh); - + + //__shared__ vars for functions + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); }); }); } @@ -238,8 +247,7 @@ template void quantizeBlockwise(floa q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor 
smem_code_acc_ct1(sycl::range<1>(256), cgh); - //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); using group_store = dpct::group::workgroup_store; @@ -276,10 +287,14 @@ template void quantizeBlockwise(floa sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); sycl::local_accessor stacc(store_temp_storage_size, cgh); + //__shared__vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); }); }); } @@ -288,8 +303,7 @@ template void quantizeBlockwise(floa dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); - //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); using group_store = dpct::group::workgroup_store; @@ -301,10 +315,14 @@ template void quantizeBlockwise(floa sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); sycl::local_accessor stacc(store_temp_storage_size, cgh); + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); }); }); } @@ -313,8 +331,7 @@ template void quantizeBlockwise(floa dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); - //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); using group_store = dpct::group::workgroup_store; @@ -326,10 +343,15 @@ template void quantizeBlockwise(floa sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); sycl::local_accessor stacc(store_temp_storage_size, cgh); + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, 
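// The "+ //__shared__ vars" blocks being added in each of these submissions restore
// what was CUDA __shared__ storage: each __shared__ array becomes a
// sycl::local_accessor created in the command group, and its raw pointer is handed to
// the kernel as an extra parameter. Rough correspondence, assuming a float element
// type for smem_code (template argument lists are elided throughout this patch):
//
//   CUDA (device):  __shared__ float smem_code[256];
//   SYCL (host):    sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh);
//   SYCL (launch):  the kernel receives smem_code_acc_ct1.get_pointer() as `float *smem_code`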
rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); }); }); } @@ -338,8 +360,7 @@ template void quantizeBlockwise(floa dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); - //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); using group_store = dpct::group::workgroup_store; @@ -351,10 +372,13 @@ template void quantizeBlockwise(floa sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); sycl::local_accessor stacc(store_temp_storage_size, cgh); + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); }); }); } @@ -363,8 +387,7 @@ template void quantizeBlockwise(floa dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); - //sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + using group_load_T = dpct::group::workgroup_load; size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); using group_store = dpct::group::workgroup_store; @@ -375,11 +398,16 @@ template void quantizeBlockwise(floa sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); sycl::local_accessor stacc(store_temp_storage_size, cgh); + + //__shared__ vars + sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); }); }); } @@ -822,12 +850,16 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 
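// Each submission above keeps the original CUDA launch geometry: the nd_range is
// built as global = num_blocks * local, with the thread count in the innermost
// (index 2) dimension, so a launch such as `<<<num_blocks, 512>>>` becomes,
// equivalently:
//
//   cgh.parallel_for(
//     sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512),
//                       sycl::range<3>(1, 1, 512)),
//     [=](sycl::nd_item<3> item_ct1) { /* kernel body */ });
//
// which is why the work-group size shrinks (1024, 512, 256, 128, 64, 32) in step with
// the blocksize each branch handles.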
256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit2State(buff_p, buff_g, buff_state1, buff_state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1, ltacc_float2); + kPreconditionOptimizerStatic8bit2State(buff_p, buff_g, buff_state1, buff_state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), ltacc_T, ltacc_float1, ltacc_float2); }); }); } @@ -842,8 +874,8 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); - //sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); /* DPCT1054:301: The type of variable temp_storage is declared in device function with the name type_ct7. Adjust the code to make the type_ct7 declaration visible at the accessor declaration point. */ @@ -876,10 +908,14 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit2State(buff_p, buff_g, buff_state1, buff_state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2, stacc_T, stacc_float1, stacc_float2); + kOptimizerStatic8bit2State(buff_p, buff_g, buff_state1, buff_state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2, stacc_T, stacc_float1, stacc_float2); }); }); } @@ -900,7 +936,7 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ DPCT1054:302: The type of variable temp_storage is declared in device function with the name type_ct8. Adjust the code to make the type_ct8 declaration visible at the accessor declaration point. 
*/ //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + ; using group_load_T = dpct::group::workgroup_load; using group_load_float1 = dpct::group::workgroup_load; size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); @@ -909,10 +945,13 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - cgh.parallel_for( + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1); + kPreconditionOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_float1); }); }); } @@ -927,7 +966,7 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + /* DPCT1054:303: The type of variable temp_storage is declared in device function with the name type_ct9. Adjust the code to make the type_ct9 declaration visible at the accessor declaration point. */ @@ -953,11 +992,12 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); - + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1,ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); + kOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1,smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); }); }); } @@ -975,7 +1015,7 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + /* DPCT1054:304: The type of variable temp_storage is declared in device function with the name type_ct9. Adjust the code to make the type_ct9 declaration visible at the accessor declaration point. 
*/ @@ -1001,10 +1041,13 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); + kOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); }); }); } @@ -1022,8 +1065,6 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ DPCT1054:305: The type of variable temp_storage is declared in device function with the name type_ct8. Adjust the code to make the type_ct8 declaration visible at the accessor declaration point. */ //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); - using group_load_T = dpct::group::workgroup_load; using group_load_float1 = dpct::group::workgroup_load; size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); @@ -1031,10 +1072,13 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1); + kPreconditionOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_float1); }); }); } @@ -1097,16 +1141,11 @@ template void optimizerStatic8bitBlockwise(T* buff_p, /* DPCT1101:306: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ - sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); /* DPCT1101:307: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ - - //sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); - //sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); - //sycl::local_accessor smem_exchange2_acc_ct1(sycl::range<1>(1), cgh); - /* + /* DPCT1054:308: The type of variable temp_storage is declared in device function with the name type_ct10. Adjust the code to make the type_ct10 declaration visible at the accessor declaration point. 
*/ //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); @@ -1135,11 +1174,19 @@ template void optimizerStatic8bitBlockwise(T* buff_p, sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); sycl::local_accessor stacc_float2(store_temp_storage_size_float2, cgh); + + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); + sycl::local_accessor smem_exchange2_acc_ct1(sycl::range<1>(1), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE), sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit2StateBlockwise(buff_p, buff_g, buff_state1, buff_state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2, stacc_T, stacc_float1, stacc_float2); + kOptimizerStatic8bit2StateBlockwise(buff_p, buff_g, buff_state1, buff_state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n,item_ct1, smem_quantiles1_acc_ct1, smem_quantiles2_acc_ct1,smem_exchange1_acc_ct1.get_pointer(), smem_exchange2_acc_ct1.get_pointer(),ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2, stacc_T, stacc_float1, stacc_float2); }); }); } @@ -1161,8 +1208,7 @@ template void optimizerStatic8bitBlockwise(T* buff_p, /* DPCT1101:309: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ - //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); - //sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); + /* DPCT1054:310: The type of variable temp_storage is declared in device function with the name type_ct11. Adjust the code to make the type_ct11 declaration visible at the accessor declaration point. */ @@ -1187,10 +1233,15 @@ template void optimizerStatic8bitBlockwise(T* buff_p, sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + + //__shared__ vars + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); + sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1StateBlockwise(buff_p, buff_g, buff_state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); + kOptimizerStatic8bit1StateBlockwise(buff_p, buff_g, buff_state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); }); }); } @@ -1247,8 +1298,15 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, DPCT1010:250: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. 
You need to rewrite this code. */ //CUDA_CHECK_RETURN(0); + //back memcpy + q_ct1.memcpy((T*)(g), (T*)(buff_g), size); } + + + +//=======================GEMM================================= + void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) try { const int falpha = 1; @@ -1829,7 +1887,7 @@ template void spmm_coo_very_sparse_naive(int *max_count, /* DPCT1101:311: 'SMEM_SIZE' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ - sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh); + //sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, nnz_rows) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), From 38c2f8c991520d203473b869d896ea847c38a13c Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 1 Apr 2024 04:54:08 -0700 Subject: [PATCH 19/66] add shared for spm --- csrc/sycl/ops.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 33d46dac9..f32ff6494 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1305,7 +1305,7 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, -//=======================GEMM================================= +//========================GEMM============================ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) try { @@ -1670,6 +1670,9 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half CUDA_CHECK_RETURN(0); } + +//===========================Row col stats================================= + #define STATS_THREADS 64 #define STATS_ITEMS 4 #define STATS_ROWS 16 @@ -1887,7 +1890,7 @@ template void spmm_coo_very_sparse_naive(int *max_count, /* DPCT1101:311: 'SMEM_SIZE' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. 
*/ - //sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh); + sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, nnz_rows) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), From 04a3ba12f44b97c41185c4a97e90869a60b7fb9a Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 1 Apr 2024 06:46:19 -0700 Subject: [PATCH 20/66] add gemm ops + kernel fix --- csrc/sycl/kernels.cpp | 15 ++------- csrc/sycl/ops.cpp | 75 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index a356ae439..c6791d794 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -843,9 +843,6 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ int local_rand_idx = 0; - sycl::buffer buff_A(A,sycl::range<1>(NUM_PER_TH)); - sycl::buffer buff_out(out,sycl::range<1>(NUM_PER_TH)); - sycl::buffer buff_rand(rand,sycl::range<1>(NUM_PER_TH)); //typedef cub::BlockLoad LoadT; @@ -4677,7 +4674,7 @@ DPCT1110:32: The total declared local variable size in device function kgemm_4bi template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *smem_A, - T *smem_B, + unsigned char *smem_B, T *smem_C) { @@ -4705,10 +4702,6 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); - - - - /* DPCT1082:33: Migration of nvcuda::wmma::fragment type is not supported. */ @@ -4946,17 +4939,13 @@ wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_o /* DPCT1110:49: The total declared local variable size in device function kgemm_4bit_inference_naive exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. */ -template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, - const sycl::nd_item<3> &item_ct1, - T *quant_map) +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map) { // per threadblock: // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1x32 * 32x4 -> 1x4 outputs per thread block - - const int warp_idx = item_ct1.get_local_id(2) / 32; const int warp_lane = item_ct1.get_local_id(2) % 32; diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index f32ff6494..72ee233b6 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1902,7 +1902,7 @@ template void spmm_coo_very_sparse_naive(int *max_count, /* DPCT1010:289: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
*/ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); } @@ -1935,7 +1935,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id /* DPCT1010:290: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); } @@ -1945,6 +1945,19 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out { int num_blocks = (m+31)/32; + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int size = NUM_BLOCK; + *((T **)&buff_A) = sycl::malloc_device(size, A, ctx); + q_ct1.memcpy((T*)(buff_A), (T*)(A), size); + *((T **)&buff_B) = sycl::malloc_device(size, B, ctx); + q_ct1.memcpy((T*)(buff_B), (T*)(B), size); + *((T **)&buff_out) = sycl::malloc_device(size, out, ctx); + q_ct1.memcpy((T*)(buff_out), (T*)(out), size); + //cout << num_blocks << endl; //cout << lda << endl; @@ -1966,16 +1979,18 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out /* DPCT1101:312: '8*16 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ - sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); + //sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); /* DPCT1101:313: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ + //__shared__ vars + sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 160), sycl::range<3>(1, 1, 160)), [=](sycl::nd_item<3> item_ct1) { - gemm_device(m, n, k, A, B, out, lda, ldb, ldc, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer()); + gemm_device(m, n, k, buff_A, buff_B, buff_out, lda, ldb, ldc, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer()); }); }); } @@ -1983,6 +1998,11 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //back memcpy + q_ct1.memcpy((T*)(A), (T*)(buff_A), size); + q_ct1.memcpy((T*)(B), (T*)(buff_B), size); + q_ct1.memcpy((T*)(out), (T*)(buff_out), size); + } template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) @@ -1998,6 +2018,19 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //cout << m << endl; //cout << n << endl; //cout << k << endl; + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int size = NUM_BLOCK; + *((T **)&buff_A) = sycl::malloc_device(size, A, ctx); + q_ct1.memcpy((T*)(buff_A), (T*)(A), size); + *(( unsigned char**)&buff_B) = sycl::malloc_device(size, B, ctx); + q_ct1.memcpy((unsigned char*)(buff_B), (unsigned 
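// gemm_host, gemm_4bit_inference and gemm_4bit_inference_naive all adopt the same
// staging discipline in this patch: allocate device-side buff_* mirrors, copy the
// caller's pointers in, run the kernel against the mirrors, and memcpy the results
// back at the end. A sketch of that round trip using plain SYCL 2020 USM calls
// (`count` is a hypothetical element count; the patch's own malloc_device wrapper is
// invoked with different arguments):
//
//   T *buff_A = sycl::malloc_device<T>(count, q_ct1);     // device mirror
//   q_ct1.memcpy(buff_A, A, count * sizeof(T));           // stage input
//   /* ... launch the kernel reading/writing buff_A ... */
//   q_ct1.memcpy(A, buff_A, count * sizeof(T)).wait();    // copy results back
//   sycl::free(buff_A, q_ct1);                            // release the mirror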
char*)(B), size); + *((T **)&buff_out) = sycl::malloc_device(size, out, ctx); + q_ct1.memcpy((T*)(buff_out), (T*)(out), size); + { dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); dpct::get_in_order_queue().submit( @@ -2005,29 +2038,48 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi /* DPCT1101:314: '8*16 + (16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ - sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(176/*8*16 + (16*(batch_size_warps-1))*/), cgh); + /* DPCT1101:315: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. */ + + //__shared__ vars + sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(176/*8*16 + (16*(batch_size_warps-1))*/), cgh); sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); sycl::local_accessor smem_C_acc_ct1(sycl::range<1>(8*32), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 96), sycl::range<3>(1, 1, 96)), [=](sycl::nd_item<3> item_ct1) { - kgemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), smem_C_acc_ct1.get_pointer()); + kgemm_4bit_inference(m, n, k, buff_A, buff_B, absmax, buff_out, lda, ldb, ldc, blocksize, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), smem_C_acc_ct1.get_pointer()); }); }); } //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //back memcpy + q_ct1.memcpy((T*)(A), (T*)(buff_A), size); + q_ct1.memcpy((unsigned char*)(B), (unsigned char*)(buff_B), size); + q_ct1.memcpy((T*)(out), (T*)(buff_out), size); + } template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { int num_blocks = (m+3)/4; + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int size = NUM_BLOCK; + *((T **)&buff_A) = sycl::malloc_device(size, A, ctx); + q_ct1.memcpy((T*)(buff_A), (T*)(A), size); + *(( unsigned char**)&buff_B) = sycl::malloc_device(size, B, ctx); + q_ct1.memcpy((unsigned char*)(buff_B), (unsigned char*)(B), size); + *((T **)&buff_out) = sycl::malloc_device(size, out, ctx); + q_ct1.memcpy((T*)(buff_out), (T*)(out), size); { dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); @@ -2038,14 +2090,19 @@ template void gemm_4bit_inference_naive(int m, int n, int cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - kgemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, item_ct1, quant_map_acc_ct1.get_pointer()); + kgemm_4bit_inference_naive(m, n, k, buff_A, buff_B, absmax, datatype, buff_out, lda, ldb, ldc, blocksize, item_ct1, 
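// The kgemm_4bit_inference_naive launch is the one place here that pins the
// sub-group width: [[intel::reqd_sub_group_size(32)]] forces 32-wide sub-groups so
// the kernel's warp-style indexing (warp_idx = tid / 32, warp_lane = tid % 32 in
// kernels.cpp above) keeps its CUDA meaning. The idiom in isolation:
//
//   cgh.parallel_for(ndr, [=](sycl::nd_item<3> it) [[intel::reqd_sub_group_size(32)]] {
//     auto sg = it.get_sub_group();          // guaranteed 32 work-items wide
//     int lane = sg.get_local_linear_id();   // equivalent of a CUDA lane id
//     // ... warp-synchronous work ...
//   });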
quant_map_acc_ct1.get_pointer()); }); }); } /* DPCT1010:291: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); + q_ct1.memcpy((T*)(A), (T*)(buff_A), size); + q_ct1.memcpy((unsigned char*)(B), (unsigned char*)(buff_B), size); + q_ct1.memcpy((T*)(out), (T*)(buff_out), size); + + } template void func(T *A, T *B, T value, long n) @@ -2065,7 +2122,7 @@ template void func(T *A, T *B, T value, long n) /* DPCT1010:292: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); } //============================================================== From 06ba2217bd5cebae102036b38e173169f296ae7f Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 2 Apr 2024 00:35:13 -0700 Subject: [PATCH 21/66] update kernel headers --- csrc/sycl/kernels.cpp | 52 +++++++++++++++++++++------------------ csrc/sycl/kernels.h | 57 ++++++++++++++++++++++++++++++------------- 2 files changed, 68 insertions(+), 41 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index c6791d794..f2aa0a34c 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -5345,15 +5345,15 @@ template unsigned char dQuantize<0>(float* smem_code, const float rand, float x) template unsigned char dQuantize<1>(float* smem_code, const float rand, float x); template SYCL_EXTERNAL void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc); template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc); #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1; \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1); \ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) @@ -5367,7 +5367,7 @@ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1); \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_T stacc_T, sycl_la_float, stacc_float1); \ SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL 
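// These MAKE_* macros are the explicit-instantiation mechanism of kernels.cpp: each
// invocation stamps out one `template void kernel<OPTIMIZER, gtype>(...)` definition
// so device code for every optimizer/dtype pair is actually emitted. The shape of the
// expansion, reduced to a toy example (note that this patch text elides all
// angle-bracket template argument lists, which is why the macros appear to
// instantiate without `<oname, gtype>`):
//
//   #define MAKE_Opt(oname, gtype) \
//     template void kOpt<oname, gtype>(gtype *g, gtype *p, const int n, \
//                                      const sycl::nd_item<3> &item_ct1);
//   MAKE_Opt(MOMENTUM, sycl::half)   // emits kOpt<MOMENTUM, sycl::half>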
MAKE_Optimizer32bit1State(MOMENTUM, float) @@ -5383,7 +5383,7 @@ SYCL_EXTERNAL MAKE_Optimizer32bit1State(ADAGRAD, float) template void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ float* state1, float* state2, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1); \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,sycl_la_T ltacc_T, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2); \ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, float) SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half) @@ -5391,13 +5391,16 @@ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::ext::oneapi::bfl template SYCL_EXTERNAL void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, + sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::half* g, sycl::half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, + sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::ext::oneapi::bfloat16* g, sycl::ext::oneapi::bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, + sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ @@ -5409,7 +5412,7 @@ template void kPreconditionOptimizerStatic8bit1State(gtype* p, gty float* max1, float* new_max1, \ const float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> 
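// Every accessor parameter added in these instantiations must appear identically in
// three places: the SYCL_EXTERNAL declaration in kernels.h, the definition in
// kernels.cpp, and the explicit instantiation itself; any drift between the three
// (a missing comma, a swapped sycl_la_T/sycl_la_float) surfaces as an unresolved
// symbol at device link time rather than as a compile error. A toy illustration of
// the invariant, with hypothetical names:
//
//   kernels.h:      template <typename T> extern SYCL_EXTERNAL void k(T *g, sycl_la_T ltacc);
//   kernels.cpp:    template <typename T> SYCL_EXTERNAL void k(T *g, sycl_la_T ltacc) { /*...*/ }
//   instantiation:  template SYCL_EXTERNAL void k<float>(float *g, sycl_la_T ltacc);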
&item_ct1, float *smem_quantiles1); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,sycl_la_T ltacc_T, sycl_la_T ltacc_T1,sycl_la_float ltacc_float1, sycl_la_T stacc_T, \ + sycl_la_float stacc_float1); \ SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, float) @@ -5445,7 +5449,7 @@ template void kPreconditionOptimizerStatic8bit2State(gtype* p, gty float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2,sycl_la_T ltacc_T, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2); \ SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, sycl::half) SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, float) @@ -5459,18 +5463,18 @@ template void kOptimizerStatic8bit2State(gtype* p, gtype* const g, float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2,sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); \ SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, sycl::half) SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, float) template SYCL_EXTERNAL void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T); template SYCL_EXTERNAL void kPercentileClipping(sycl::half * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T); #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ -template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value); \ +template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,sycl_la_T ltacc_T, sycl_la_float ltacc_float,sycl_la_unsigned_char stacc); \ SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit) SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit) @@ -5541,23 +5545,23 @@ SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, N SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, NF4) template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int 
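// The MAKE_kQuantizeBlockwise list above pre-instantiates every
// (dtype, blocksize, num_per_thread, stochastic, data type) combination the host
// dispatcher can reach, because quantizeBlockwise in ops.cpp selects the template at
// runtime from the requested blocksize (the 1024/512/256/128/64/32-thread submissions
// earlier in this patch). Condensed sketch of that host-side dispatch:
//
//   if      (blocksize == 4096) { /* submit kQuantizeBlockwise<T, 4096, 4, ...> */ }
//   else if (blocksize == 2048) { /* submit kQuantizeBlockwise<T, 2048, 4, ...> */ }
//   // ... 1024, 512, 256, 128 ...
//   else if (blocksize ==   64) { /* submit kQuantizeBlockwise<T,   64, 2, ...> */ }
//
// A blocksize with no matching explicit instantiation would fail at device link time.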
blocksize, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ template void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ @@ -5566,7 +5570,7 @@ template void kOptimizerStatic8bit2StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2); \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2,sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, sycl_la_T stacc_T, sycl_la_float1 stacc_float1, sycl_la_float stacc_float2); \ SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8) @@ -5581,7 +5585,7 @@ template void kOptimizerStatic8bit1StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1); \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1,sycl_la_T ltacc_T, sycl_la_T ltacc_T1,sycl_la_float ltacc_float1,sycl_la_T stacc_T, sycl_la_float stacc_float1); \ SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index e334d1f6c..5342f115d 100644 --- 
a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -13,49 +13,58 @@ #pragma once +typedef sycl::accessor sycl_la_float; +typedef sycl::accessor sycl_la_T; +typedef sycl::accessor sycl_la_unsigned_char; +typedef sycl::accessor sycl_la_half; +typedef sycl::accessor sycl_la_unsigned; + //template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc); extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_float ltacc, sycl_la_unsigned_char stacc); extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, float *smem_code); template extern SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, - float *smem_absmax_value); -template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1); + float *smem_absmax_value, + sycl_la_T ltacc_T, sycl_la_float ltacc_float,sycl_la_unsigned_char stacc); +template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2); template extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2 + sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1); template extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, 
const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, + sycl_la_T stacc_T, sycl_la_float stacc_float1); template extern SYCL_EXTERNAL void @@ -68,7 +77,8 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c const float weight_decay, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, - float *smem_quantiles1); + float *smem_quantiles1, + sycl_la_T ltacc_T, sycl_la_float ltacc_float1); template @@ -80,7 +90,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_quantiles1); + const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, + sycl_la_T ltacc_T, sycl_la_T ltacc_T1, + sycl_la_float ltacc_float1, sycl_la_T stacc_T, + sycl_la_float stacc_float1); @@ -94,7 +107,9 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c float* max1, float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, - float *smem_quantiles1, float *smem_quantiles2); + float *smem_quantiles1, float *smem_quantiles2, + sycl_la_T ltacc_T, + sycl_la_float ltacc_float1, sycl_la_float ltacc_float2); template @@ -107,7 +122,9 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, - float *smem_quantiles2); + float *smem_quantiles2, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, + sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, + sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, @@ -117,7 +134,10 @@ template extern SYCL_EX const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, - float *smem_exchange1, float *smem_exchange2); + float *smem_exchange1, float *smem_exchange2, + sycl_la_T ltacc_T, sycl_la_T ltacc_T1, + sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, + sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, @@ -129,10 +149,13 @@ template extern SYCL_EX const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, - float *smem_exchange1); + float *smem_exchange1, + sycl_la_T ltacc_T, sycl_la_T ltacc_T1, + sycl_la_float ltacc_float1,sycl_la_T stacc_T, + sycl_la_float stacc_float1); -template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1); +template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T); void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, @@ -156,7 +179,7 @@ extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( const sycl::nd_item<3> &item_ct1, float 
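
All of these signatures follow one convention: temp storage that the CUDA kernels kept in __shared__ CUB TempStorage now arrives as a work-group local accessor argument. A minimal sketch of what a kernel does with such an argument (kToy and the scratch_t alias are illustrative; the sycl:: types are standard SYCL 2020 API):

#include <sycl/sycl.hpp>
#include <cstdint>

using scratch_t = sycl::local_accessor<uint8_t, 1>;  // role of the sycl_la_* typedefs

void kToy(const float *in, float *out, int n,
          const sycl::nd_item<3> &item, scratch_t ltacc)
{
  // Group-local scratch comes in through the accessor; the kernel no longer
  // declares its own shared temp storage the way the CUDA original did.
  // Assumes ltacc was sized to at least the work-group size by the caller.
  auto *tmp = ltacc.get_multi_ptr<sycl::access::decorated::no>().get();
  size_t lid = item.get_local_linear_id();
  tmp[lid] = static_cast<uint8_t>(lid);         // stage something in local memory
  sycl::group_barrier(item.get_group());
  size_t gid = item.get_global_linear_id();
  if (gid < static_cast<size_t>(n))
    out[gid] = in[gid] + static_cast<float>(tmp[0]);
}
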
*smem_rowStats); template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, - const sycl::nd_item<3> &item_ct1,float *smem_row_absmax_values,int *smem_row_nnz_values); + const sycl::nd_item<3> &item_ct1,float *smem_row_absmax_values,int *smem_row_nnz_values, sycl_la_half ltacc_half, sycl_la_unsigned exacc); template extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, From ee7997dfbbf7c6d140baf4bb1946a961cfbe708b Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 3 Apr 2024 02:36:33 -0700 Subject: [PATCH 22/66] port transform kernels --- csrc/sycl/kernels.cpp | 86 +++++++------------- csrc/sycl/kernels.h | 4 +- csrc/sycl/ops.cpp | 179 ++++++++++++++++++++++++++++++++---------- 3 files changed, 166 insertions(+), 103 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index f2aa0a34c..80a5ebad6 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -637,7 +637,7 @@ typedef sycl::accessor sycl_la_unsigned_char; typedef sycl::accessor sycl_la_half; typedef sycl::accessor sycl_la_unsigned; - +typedef sycl::accessor sycl_la_char; template SYCL_EXTERNAL @@ -3608,7 +3608,7 @@ template void kdequant_mm_i } -template void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx) +template void kDoubleRowColQuant(sycl::half *__restrict__ const buff_A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *buff_out_col_normed, char *buff_out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, sycl_la_half ltacc_half, sycl_la_char stacc_char1, sycl_la_char stacc_char2 ) { // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD // Each thread reads the same column but multiple rows @@ -3633,11 +3633,6 @@ template StoreInt8; - sycl::buffer buff_A(A,sycl::range<1>(ITEMS_PER_THREAD)); - sycl::buffer buff_out_row_normed(out_row_normed,sycl::range<1>(ITEMS_PER_THREAD)); - sycl::buffer buff_out_col_normed(out_col_normed,sycl::range<1>(ITEMS_PER_THREAD)); - - sycl::half local_data[ITEMS_PER_THREAD]; float local_col_stats[ITEMS_PER_THREAD]; @@ -3678,30 +3673,16 @@ template ; - size_t temp_storage_size = group_load::get_local_memory_size(ITEMS_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_A[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, local_data); - }); - - }); + // 1. load 8 values per thread + // 2. 
compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = ltacc_half.get_multi_ptr().get(); + group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], local_data); + float row_stat = 127.0f / smem_row_stats[row]; // 2. quantize data with row/col stats @@ -3735,29 +3716,16 @@ template ; - size_t temp_storage_size = group_store::get_local_memory_size(ITEMS_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_out_row_normed[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, local_quantized_data); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = stacc_char1.get_multi_ptr().get(); + group_store(tmp).store(item,item.get_local_linear_id(), &buff_out_row_normed[0], local_quantized_data); + // 2. quantize data with row/col stats #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) @@ -3882,7 +3850,7 @@ template sycl_la_unsigned_char; typedef sycl::accessor sycl_la_half; typedef sycl::accessor sycl_la_unsigned; +typedef sycl::accessor sycl_la_char; //template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); @@ -190,7 +191,8 @@ extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, int *__restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, - float *smem_row_stats, unsigned int *smem_nnz_row_idx); + float *smem_row_stats, unsigned int *smem_nnz_row_idx, + sycl_la_half ltacc_half, sycl_la_char stacc_char1, sycl_la_char stacc_char2); template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 72ee233b6..8b435ea71 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -472,7 +472,7 @@ template void dequantizeBlockwise(float *code, unsign sycl::local_accessor ltacc(load_temp_storage_size, cgh); sycl::local_accessor stacc(store_temp_storage_size, cgh); - q_ct1.parallel_for( + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { kDequantizeBlockwise(code, buff_A, absmax, buff_out, blocksize, n, item_ct1); @@ -1283,16 +1283,18 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, */ { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); - 
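
The converted kernel body above replaces cub::BlockLoad with dpct::group::workgroup_load over a caller-provided scratch accessor. Stripped of the helper, the blocked arrangement it performs is just (sketch; direct blocked layout, no partial-tile handling):

#include <sycl/sycl.hpp>

template <int ITEMS_PER_THREAD>
void blocked_load(const sycl::nd_item<3> &item, const sycl::half *src,
                  sycl::half (&regs)[ITEMS_PER_THREAD])
{
  // Thread t owns elements [t*ITEMS, (t+1)*ITEMS), matching BLOCK_LOAD_DIRECT.
  size_t base = item.get_local_linear_id() * ITEMS_PER_THREAD;
  for (int j = 0; j < ITEMS_PER_THREAD; j++)
    regs[j] = src[base + j];
}
// workgroup_store mirrors this with the copy direction reversed.
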
q_ct1.parallel_for( - using group_load_T = dpct::group::workgroup_load; - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + q_ct1.submit( + [&](sycl::handler &cgh) { + using group_load_T = dpct::group::workgroup_load; + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { kPercentileClipping(g, gnorm_vec, step, n, item_ct1, ltacc); }); + }); } /* DPCT1010:250: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. @@ -1700,43 +1702,48 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c if(nnz_threshold == 0.0) { - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); - q_ct1.parallel_for( - using group_load_half = dpct::group::workgroup_load; - using group_exchange = dpct::group::exchange; - - size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); - size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor exacc(exchange_temp_storage_size, cgh); - sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); - + + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + using group_load_half = dpct::group::workgroup_load; + using group_exchange = dpct::group::exchange; - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), - [=](sycl::nd_item<3> item_ct1) { - kgetColRowStats(buff_A, rowStats, colStats, - nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols, ltacc_half, exacc); - }); + size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); + size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor exacc(exchange_temp_storage_size, cgh); + sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kgetColRowStats(buff_A, rowStats, colStats, + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols, ltacc_half, exacc); + }); + }); } else if(nnz_threshold != 0.0) { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); - q_ct1.parallel_for( - using group_load_half = dpct::group::workgroup_load; - using group_exchange = dpct::group::exchange; - - size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); - size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor exacc(exchange_temp_storage_size, cgh); - sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); - + q_ct1.submit( + [&](sycl::handler &cgh) { + using group_load_half = dpct::group::workgroup_load; + using group_exchange = dpct::group::exchange; - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), - 
[=](sycl::nd_item<3> item_ct1) { - kgetColRowStats(buff_A, rowStats, colStats, - nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols, ltacc_half, exacc); + size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); + size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor exacc(exchange_temp_storage_size, cgh); + sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + kgetColRowStats(buff_A, rowStats, colStats, + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols, ltacc_half, exacc); }); + }); } /* @@ -1748,6 +1755,19 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int *nnz_block_ptr, float threshold, int rows, int cols) { + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int num_blocks = 0; + int size = NUM_BLOCK; + + *((sycl::half **)&buff_A) = sycl::malloc_device(size, A, ctx); + *((char **)&buff_out_row_normed) = sycl::malloc_device(size, out_row_normed, ctx); + *((char **)&buff_out_col_normed) = sycl::malloc_device(size, out_col_normed, ctx); + q_ct1.memcpy((sycl::half*)(buff_A), (sycl::half*)(A), size); + q_ct1.memcpy((char*)(buff_out_row_normed), (char*)(out_row_normed), size); + q_ct1.memcpy((char*)(buff_out_col_normed), (char*)(out_col_normed), size); + int threads = 64; int items_per_thread = 4; int tile_cols = threads*items_per_thread; @@ -1762,18 +1782,82 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o if(threshold > 0.0f) - kDoubleRowColQuant<64, 4, 16, 64*4, 1><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + using group_load_half = dpct::group::workgroup_load; + using group_store_char1 = dpct::group::workgroup_store; + using group_store_char2 = dpct::group::workgroup_store; + + size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_char1 = group_store_char1::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_char2 = group_store_char2::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); + sycl::local_accessor stacc_char1(store_temp_storage_size_char1, cgh); + sycl::local_accessor stacc_char2(store_temp_storage_size_char2, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + + kDoubleRowColQuant(buff_A, rowStats, colStats, buff_out_col_normed, buff_out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, ltacc_half, stacc_char1, stacc_char2); + }); + }); + } else - kDoubleRowColQuant<64, 4, 16, 64*4, 0><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); - + { + + 
dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + using group_load_half = dpct::group::workgroup_load; + using group_store_char1 = dpct::group::workgroup_store; + using group_store_char2 = dpct::group::workgroup_store; + + size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_char1 = group_store_char1::get_local_memory_size(NUM_BLOCK); + size_t store_temp_storage_size_char2 = group_store_char2::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); + sycl::local_accessor stacc_char1(store_temp_storage_size_char1, cgh); + sycl::local_accessor stacc_char2(store_temp_storage_size_char2, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), + [=](sycl::nd_item<3> item_ct1) { + + kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, ltacc_half, stacc_char1, stacc_char2); + }); + }); + + } /* DPCT1010:285: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); + q_ct1.memcpy((sycl::half*)(A), (sycl::half*)(buff_A), size); + q_ct1.memcpy((char*)(out_row_normed), (char*)(buff_out_row_normed), size); + q_ct1.memcpy((char*)(out_col_normed), (char*)(buff_out_col_normed), size); + } template void transformRowToFormat(char * A, char *out, int rows, int cols) { + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int num_blocks = 0; + int size = NUM_BLOCK; + + *((char **)&buff_A) = sycl::malloc_device(size, A, ctx); + *((char **)&buff_out) = sycl::malloc_device(size, out, ctx); + q_ct1.memcpy((char*)(buff_A), (sycl::half*)(A), size); + q_ct1.memcpy((char*)(buff_out), (char*)(out), size); + + int threads = 256; int items_per_thread = 8; // we load 128 column values per warp @@ -1817,18 +1901,27 @@ template void transformRowToFormat(char * A, char *o */ dpct::get_in_order_queue().submit( [&](sycl::handler &cgh) { + + + + //__shared__ vars sycl::local_accessor smem_data_acc_ct1(sycl::range<1>(32*33*8), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), [=](sycl::nd_item<3> item_ct1) { - kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT>(A, out, rows, cols, tiledCols, outRows, outCols, item_ct1, smem_data_acc_ct1.get_pointer()); + kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT>(buff_A, buff_out, rows, cols, tiledCols, outRows, outCols, item_ct1, smem_data_acc_ct1.get_pointer()); }); }); /* DPCT1010:286: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
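
Host side, every launch in this patch has the same shape: stage the raw pointers into device USM, size the helpers' scratch inside the command group, run, and copy back. A compilable sketch of that shape (launch_like, kToyCopy, and the scratch size are illustrative stand-ins for the kDoubleRowColQuant plumbing above; assumes an in-order queue, as the patch's dev_ct1.in_order_queue() provides):

#include <sycl/sycl.hpp>
#include <cstdint>

inline void kToyCopy(const sycl::half *in, sycl::half *out, size_t n,
                     const sycl::nd_item<3> &it,
                     const sycl::local_accessor<uint8_t, 1> &scratch)
{
  (void)scratch;                        // real kernels run group loads/stores here
  size_t i = it.get_global_linear_id();
  if (i < n) out[i] = in[i];
}

void launch_like(sycl::queue &q, const sycl::half *A, sycl::half *out,
                 size_t n, int num_blocks)
{
  // malloc_device takes an element count and a queue (or device + context).
  sycl::half *buff_A   = sycl::malloc_device<sycl::half>(n, q);
  sycl::half *buff_out = sycl::malloc_device<sycl::half>(n, q);
  q.memcpy(buff_A, A, n * sizeof(sycl::half));

  q.submit([&](sycl::handler &cgh) {
    // Stand-in for group_load::get_local_memory_size(NUM_BLOCK).
    sycl::local_accessor<uint8_t, 1> ltacc(512 * 8, cgh);
    cgh.parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512),
                          sycl::range<3>(1, 1, 512)),
        [=](sycl::nd_item<3> it) { kToyCopy(buff_A, buff_out, n, it, ltacc); });
  });

  q.memcpy(out, buff_out, n * sizeof(sycl::half)).wait();
  sycl::free(buff_A, q);
  sycl::free(buff_out, q);
}
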
*/ - CUDA_CHECK_RETURN(0); + + //CUDA_CHECK_RETURN(0); + + q_ct1.memcpy((char*)(A), (sycl::half*)(buff_A), size); + q_ct1.memcpy((char*)(out), (char*)(buff_out), size); + } void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, sycl::half *B, int ldc, sycl::half* C, bool transposed_B) From 26863eadc66fee44d28ada3677e8d3ab8c5289ec Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 4 Apr 2024 04:55:07 -0700 Subject: [PATCH 23/66] use ldg & sparse csr --- csrc/sycl/kernels.cpp | 8 +++---- csrc/sycl/ops.cpp | 53 ++++++++++++++++++++++++++++--------------- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 80a5ebad6..ac6c6d383 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -1020,7 +1020,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, flo /* DPCT1064:96: Migrated __ldg call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code. */ - local_abs_max = absmax[(i+item_ct1.get_local_id(2)*NUM_PER_TH)/(blocksize)]; + local_abs_max = sycl::ext::oneapi::experimental::cuda::ldg(&absmax[(i+item_ct1.get_local_id(2)*NUM_PER_TH)/(blocksize)]); /* DPCT1065:90: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1054,7 +1054,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, flo /* DPCT1064:228: Migrated __ldg call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code. */ - vals[j] = code[qvals[j]]*local_abs_max; + vals[j] = sycl::ext::oneapi::experimental::cuda::ldg(&code[qvals[j]]*local_abs_max); break; case FP4: #pragma unroll NUM_PER_TH @@ -4790,7 +4790,7 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M /* DPCT1098:222: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues. */ - sycl::half local_absmax = absmax[absidx]; + sycl::half local_absmax = sycl::ext::oneapi::experimental::cuda::ldg(&absmax[absidx]); #pragma unroll 64 for(int col = 0; col < 64; col+=2) @@ -4941,7 +4941,7 @@ template SYCL_EXTERNAL void kgemm_4bit_infer /* DPCT1098:223: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues. 
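
The __ldg conversions above use the sycl::ext::oneapi::experimental::cuda::ldg extension, which is only meaningful on the CUDA backend (read-only data cache); on other targets a plain load is the equivalent. One detail worth keeping straight: ldg takes the address of the loaded element only, so the scale factor multiplies the result, not the argument of the call. Sketch:

// Dequantize one element with a cache-hinted lookup-table read.
// The __NVPTX__ guard is an assumption; check your compiler's predefines.
inline float dequant_one(const float *code, const float *absmax,
                         unsigned char q, int block_idx)
{
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
  namespace cu = sycl::ext::oneapi::experimental::cuda;
  return cu::ldg(&code[q]) * cu::ldg(&absmax[block_idx]);
#else
  return code[q] * absmax[block_idx];   // portable fallback
#endif
}
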
*/ - local_absmax = absmax[absidx]; + local_absmax = sycl::ext::oneapi::experimental::cuda::ldg(&absmax[absidx]); if(row_B < M) { diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 8b435ea71..3625d592e 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1925,13 +1925,16 @@ template void transformRowToFormat(char * A, char *o } void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, sycl::half *B, int ldc, sycl::half* C, bool transposed_B) -{ +{ + + try{ dpct::device_ext &dev_ct1 = dpct::get_current_device(); sycl::queue &q_ct1 = dev_ct1.in_order_queue(); -#ifdef NO_CUBLASLT -#else +//#ifdef NO_CUBLASLT +//#else + dpct::sparse::sparse_matrix_desc_t descA; std::shared_ptr descB, descC; @@ -1943,12 +1946,13 @@ void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_v /* DPCT1007:287: Migration of cusparseCreateCoo is not supported. */ - CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, - A_rowidx, A_colidx, A_vals, - dpct::library_data_t::real_int32, - oneapi::mkl::index_base::zero, dpct::library_data_t::real_half) ); + //CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + // A_rowidx, A_colidx, A_vals, + // dpct::library_data_t::real_int32, + // oneapi::mkl::index_base::zero, dpct::library_data_t::real_half) ); // Create dense matrix C - CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); + descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); // Create dense matrix B if(transposed_B) { @@ -1957,20 +1961,33 @@ void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_v B_cols = tmp; } - CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); + descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); // allocate an external buffer if needed - CHECK_CUSPARSE( DPCT_CHECK_ERROR(bufferSize = 0) ); - CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1)) ); + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(bufferSize = 0) ); + bufferSize = 0 + //CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1)) ); + dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1); // execute SpMM - CHECK_CUSPARSE( DPCT_CHECK_ERROR(dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float))); - + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float))); + dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? 
oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float); // destroy matrix/vector descriptors - CHECK_CUSPARSE( DPCT_CHECK_ERROR(descA.reset()) ); - CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB.reset()) ); - CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC.reset()) ); - CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(sycl::free(dBuffer, q_ct1)) ); -#endif + descA.reset(); + descB.reset(); + descC.reset(); + sycl::free(dBuffer, q_ct1); + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descA.reset()) ); + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB.reset()) ); + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC.reset()) ); + //CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(sycl::free(dBuffer, q_ct1)) ); +//#endif + } + catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); + } + } template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) From caa72f63e95c1e4522f3a6d1d9797212f2f6f053 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 8 Apr 2024 01:05:08 -0700 Subject: [PATCH 24/66] add dnn prototype --- csrc/sycl/ops.cpp | 155 +++++++++++++++++++++++++++++----------------- 1 file changed, 98 insertions(+), 57 deletions(-) diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 3625d592e..290cbe0c4 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -16,6 +16,7 @@ #include #include +#include "oneapi/dnnl/dnnl.hpp" #define ERR_NOT_IMPLEMENTED 100 @@ -1519,129 +1520,170 @@ template void transform(cublasLtHandle_t ltHandl template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) try { - dpct::device_ext &dev_ct1 = dpct::get_current_device(); - sycl::queue &q_ct1 = dev_ct1.in_order_queue(); -#ifdef NO_CUBLASLT - return ERR_NOT_IMPLEMENTED; -#else - int has_error = 0; - cublasLtMatmulDesc_t matmulDesc = NULL; - cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; - oneapi::mkl::transpose opT = oneapi::mkl::transpose::trans; - cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; - cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; - cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; - cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; + using namespace dnnl; + using tag = memory::format_tag; + using dt = memory::data_type; + auto dev = sycl::device(sycl::gpu_selector_v); + auto ctx = sycl::context(dev); + + dnnl::engine engine = sycL_interop::make_engine(dev, ctx); + // column major + const memory::dims a_strides = memory::dims {1, lda}; + const auto a_md = memory::desc({m, k}, dt::s8, a_strides); + const memory::dims b_strides = memory::dims {ldb, 1}; + const auto b_md = memory::desc({k, n}, dt::s8, b_strides); + const memory::dims c_strides = memory::dims {ldc, 1}; + const auto c_md = DTYPE_OUT == 32 ? 
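
For reference, what the spmm call above computes is C = A * B with A sparse. In plain SYCL over CSR inputs, a naive version looks like this (sketch; assumes row-major B and C, fp32 accumulation, and in-bounds leading dimensions):

#include <sycl/sycl.hpp>

void spmm_csr_naive(sycl::queue &q, const int *rowptr, const int *colind,
                    const sycl::half *vals, int rows, int cols_B,
                    const sycl::half *B, int ldb, sycl::half *C, int ldc)
{
  q.parallel_for(sycl::range<2>(rows, cols_B), [=](sycl::id<2> id) {
    int r = static_cast<int>(id[0]), c = static_cast<int>(id[1]);
    float acc = 0.0f;
    for (int k = rowptr[r]; k < rowptr[r + 1]; k++)        // nonzeros of row r
      acc += static_cast<float>(vals[k]) * static_cast<float>(B[colind[k] * ldb + c]);
    C[r * ldc + c] = sycl::half(acc);
  }).wait();
}
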
memory::desc({m, n}, dt::s32 c_strides) : memory::desc({m, n}, dt::s8 c_strides); + + //memory align + memory a_mem(a_md, engine A); + memory b_mem(b_md, engine, B); + memory c_mem(c_md, engine, C); + memory scales_C_mem({{1}, dt::f32, {1}}, engine, row_scale); + + //create dnnl stream + auto q_ct1 = sycl::queue(ctx, dev); + dnnl::stream stream = sycl_interop::make_stream(q_ct1); + + primitive_attr attr; + if (SCALE_ROWS) { + attr.set_scales_mask(DNNL_ARG_DST, /* mask */ 1 << 1); + } + + auto matmul_pd = matmul::primitive_desc(engine, a_md, b_md, c_md, attr); + auto matmul_prim = matmul(matmul_pd); + std::unordered_map matmul_args; + matmul_args.insert({DNNL_ARG_SRC, a_mem}); + matmul_args.insert({DNNL_ARG_WEIGHTS, b_mem}); + matmul_args.insert({DNNL_ARG_DST, c_mem}); + + if (SCALE_ROWS) { + matmul_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, scales_C_mem}); + } + matmul_prim.execute(stream, matmul_args); + stream.wait(); + +//#ifdef NO_CUBLASLT +// return ERR_NOT_IMPLEMENTED; +//#else + //int has_error = 0; + //cublasLtMatmulDesc_t matmulDesc = NULL; + //cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + //oneapi::mkl::transpose opT = oneapi::mkl::transpose::trans; + //cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + //cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; + //cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; + //cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; /* DPCT1007:262: Migration of cublasLtMatrixLayoutCreate is not supported. */ - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, dpct::library_data_t::real_int8, m, k, lda)); + //has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, dpct::library_data_t::real_int8, m, k, lda)); /* DPCT1007:263: Migration of cublasLtMatrixLayoutCreate is not supported. */ - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, dpct::library_data_t::real_int8, n, k, ldb)); + //has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, dpct::library_data_t::real_int8, n, k, ldb)); /* DPCT1007:264: Migration of cublasLtMatrixLayoutSetAttribute is not supported. */ - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - if(FORMATB == COL_TURING) + //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + //if(FORMATB == COL_TURING) /* DPCT1007:265: Migration of cublasLtMatrixLayoutSetAttribute is not supported. */ - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); - else + //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); + //else /* DPCT1007:266: Migration of cublasLtMatrixLayoutSetAttribute is not supported. */ - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); - - if(DTYPE_OUT == 32) - { + //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); + + //if(DTYPE_OUT == 32) + //{ /* DPCT1007:267: Migration of cublasLtMatmulDescCreate is not supported. 
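
The oneDNN prototype above has the right structure but a few slips (sycL_interop, missing commas in the memory constructors, make_stream called without its engine argument). A compilable sketch of the same int32-output path, assuming the oneDNN v3.x C++ API:

#include <sycl/sycl.hpp>
#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_sycl.hpp"

void igemm_s8_sketch(sycl::queue &q, int m, int n, int k,
                     const int8_t *A, const int8_t *B, int32_t *C,
                     int lda, int ldb, int ldc)
{
  using namespace dnnl;
  using dt = memory::data_type;

  engine eng = sycl_interop::make_engine(q.get_device(), q.get_context());
  stream st  = sycl_interop::make_stream(eng, q);   // stream binds engine + queue

  // Column-major A expressed through strides {1, lda}; row-major B and C.
  memory::desc a_md({m, k}, dt::s8,  memory::dims{1, lda});
  memory::desc b_md({k, n}, dt::s8,  memory::dims{ldb, 1});
  memory::desc c_md({m, n}, dt::s32, memory::dims{ldc, 1});

  memory a_mem(a_md, eng, const_cast<int8_t *>(A));  // (desc, engine, handle)
  memory b_mem(b_md, eng, const_cast<int8_t *>(B));
  memory c_mem(c_md, eng, C);

  auto pd = matmul::primitive_desc(eng, a_md, b_md, c_md);
  matmul(pd).execute(st, {{DNNL_ARG_SRC, a_mem},
                          {DNNL_ARG_WEIGHTS, b_mem},
                          {DNNL_ARG_DST, c_mem}});
  st.wait();
}

The row_scale path would then attach a primitive_attr with set_scales_mask on DNNL_ARG_DST and bind an f32 scales memory under DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, as the prototype above starts to do.
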
*/ - has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, dpct::library_data_t::real_int32)); + //has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, dpct::library_data_t::real_int32)); /* DPCT1007:268: Migration of cublasLtMatmulDescSetAttribute is not supported. */ - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + //has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); /* DPCT1007:269: Migration of cublasLtMatrixLayoutCreate is not supported. */ - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, dpct::library_data_t::real_int32, m, n, ldc)); + //has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, dpct::library_data_t::real_int32, m, n, ldc)); /* DPCT1007:270: Migration of cublasLtMatrixLayoutSetAttribute is not supported. */ - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - int alpha = 1, beta = 0; + //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + //int alpha = 1, beta = 0; /* DPCT1007:271: Migration of cublasLtMatmul is not supported. */ - has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); - } - else - { + //has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); + //} + //else + //{ /* DPCT1007:272: Migration of cublasLtMatmulDescCreate is not supported. */ - has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, dpct::library_data_t::real_float)); + //has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, dpct::library_data_t::real_float)); /* DPCT1007:273: Migration of cublasLtMatmulDescSetAttribute is not supported. */ - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + //has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); /* DPCT1007:274: Migration of cublasLtMatrixLayoutCreate is not supported. */ - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, dpct::library_data_t::real_int8, m, n, ldc)); + //has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, dpct::library_data_t::real_int8, m, n, ldc)); /* DPCT1007:275: Migration of cublasLtMatrixLayoutSetAttribute is not supported. */ - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - if(!SCALE_ROWS) - { - float alpha = 1.0f, beta = 0.0f; + //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + //if(!SCALE_ROWS) + //{ + //float alpha = 1.0f, beta = 0.0f; /* DPCT1007:276: Migration of cublasLtMatmul is not supported. 
*/ - has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); - } - else - { + //has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); + //} + //else + //{ /* DPCT1007:277: Migration of cublasLtMatmulDescSetAttribute is not supported. */ - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); + //has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); /* DPCT1007:278: Migration of cublasLtMatmul is not supported. */ - has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); - } - } + //has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); + //} + //} /* DPCT1007:279: Migration of cublasLtMatrixLayoutDestroy is not supported. */ - if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); + //if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); /* DPCT1007:280: Migration of cublasLtMatrixLayoutDestroy is not supported. */ - if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); + //if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); /* DPCT1007:281: Migration of cublasLtMatrixLayoutDestroy is not supported. */ - if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); + //if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); /* DPCT1007:282: Migration of cublasLtMatmulDescDestroy is not supported. 
*/ - if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); - if(has_error == 1) - printf("error detected"); + //if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); + //if(has_error == 1) + //printf("error detected"); - return has_error; -#endif // NO_CUBLASLT + //return has_error; +//#endif // NO_CUBLASLT } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; @@ -2211,8 +2253,7 @@ template void gemm_4bit_inference_naive(int m, int n, int q_ct1.memcpy((T*)(A), (T*)(buff_A), size); q_ct1.memcpy((unsigned char*)(B), (unsigned char*)(buff_B), size); q_ct1.memcpy((T*)(out), (T*)(buff_out), size); - - + } template void func(T *A, T *B, T value, long n) From 46f1f85fd79bd06b6c2b60eac978ef17bb977a6d Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 8 Apr 2024 05:36:44 -0700 Subject: [PATCH 25/66] dnn kernel --- csrc/sycl/ops.cpp | 90 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 67 insertions(+), 23 deletions(-) diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 290cbe0c4..6403ed841 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1433,6 +1433,47 @@ template int get_leading_dim(int dim1, int dim2); template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { + + using namespace dnnl; + using tag = memory::format_tag; + using dt = memory::data_type; + auto dev = sycl::device(sycl::gpu_selector_v); + auto ctx = sycl::context(dev); + int ldA = get_leading_dim(dim1, dim2); + int ldOut = get_leading_dim(dim1, dim2); + int ldAOut = get_leading_dim(dim1, dim2); + + dnnl::engine engine = sycL_interop::make_engine(dev, ctx); + // column major + const memory::dims a_strides = memory::dims {1, ldA}; + const auto a_md = DTYPE_OUT ==32 ? memory::desc({dim1, dim2}, dt::s32, a_strides) : memory::desc({dim1, dim2}, dt::s8, a_strides); + const memory::dims out_strides = memory::dims {ldOut, 1}; + const auto out_md = DTYPE_OUT ==32 ? memory::desc({dim1, dim2}, dt::s32, out_strides) : memory::desc({dim1, dim2}, dt::s8, out_strides); + const memory::dims Aout_strides = memory::dims {ldAOut, 1}; + const auto aout_md = DTYPE_OUT == 32 ? memory::desc({dim1, dim2}, dt::s32) : memory::desc({dim1, dim2}, dt::s8); + + //memory align + memory a_mem(a_md, engine A); + memory out_mem(out_md, engine, Out); + memory aout_mem(aout_md, engine, AOut); + + //create dnnl stream + auto q_ct1 = sycl::queue(ctx, dev); + dnnl::stream stream = sycl_interop::make_stream(q_ct1); + + primitive_attr attr; + + auto matmul_pd = matmul::primitive_desc(engine, a_md, out_md, aout_md, attr); + auto matmul_prim = matmul(matmul_pd); + std::unordered_map matmul_args; + matmul_args.insert({DNNL_ARG_SRC, a_mem}); + matmul_args.insert({DNNL_ARG_WEIGHTS, out_mem}); + matmul_args.insert({DNNL_ARG_DST, aout_mem}); + + matmul_prim.execute(stream, matmul_args); + stream.wait(); + +/* #ifdef NO_CUBLASLT #else cublasLtOrder_t orderA = get_order(); @@ -1444,28 +1485,30 @@ template void trans cublasLtMatrixTransformDesc_t A2Out_desc = NULL; oneapi::mkl::transpose opTranspose = oneapi::mkl::transpose::trans; float transformAlpha = 1.0f, transformBeta = 0.0f; + + if(DTYPE == 8) { - /* + DPCT1007:251: Migration of cublasLtMatrixLayoutCreate is not supported. - */ + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, dpct::library_data_t::real_int8, dim1, dim2, ldA)); - /* + DPCT1007:252: Migration of cublasLtMatrixLayoutCreate is not supported. 
- */ + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, dpct::library_data_t::real_int8, dim1, dim2, ldOut)); } else if(DTYPE == 32) { - /* + DPCT1007:253: Migration of cublasLtMatrixLayoutCreate is not supported. - */ + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, dpct::library_data_t::real_int32, dim1, dim2, ldA)); - /* + DPCT1007:254: Migration of cublasLtMatrixLayoutCreate is not supported. - */ + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, dpct::library_data_t::real_int32, dim1, dim2, ldOut)); } else @@ -1473,40 +1516,41 @@ template void trans printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); } - /* + DPCT1007:255: Migration of cublasLtMatrixLayoutSetAttribute is not supported. - */ + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); - /* + DPCT1007:256: Migration of cublasLtMatrixLayoutSetAttribute is not supported. - */ + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); - /* + DPCT1007:257: Migration of cublasLtMatrixTransformDescCreate is not supported. - */ + checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, dpct::library_data_t::real_float)); - /* + DPCT1007:258: Migration of cublasLtMatrixTransformDescSetAttribute is not supported. - */ + if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0)); - /* + DPCT1007:259: Migration of cublasLtMatrixLayoutDestroy is not supported. - */ + if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); - /* + DPCT1007:260: Migration of cublasLtMatrixLayoutDestroy is not supported. - */ + if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); - /* + DPCT1007:261: Migration of cublasLtMatrixTransformDescDestroy is not supported. - */ + if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); #endif +*/ } template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); @@ -1533,7 +1577,7 @@ template int igemmlt(cublasLtHandle const memory::dims b_strides = memory::dims {ldb, 1}; const auto b_md = memory::desc({k, n}, dt::s8, b_strides); const memory::dims c_strides = memory::dims {ldc, 1}; - const auto c_md = DTYPE_OUT == 32 ? memory::desc({m, n}, dt::s32 c_strides) : memory::desc({m, n}, dt::s8 c_strides); + const auto c_md = DTYPE_OUT == 32 ? memory::desc({m, n}, dt::s32, c_strides) : memory::desc({m, n}, dt::s8, c_strides); //memory align memory a_mem(a_md, engine A); From a94c25372905c959d159762748e77cccfc880d64 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 8 Apr 2024 21:22:26 -0700 Subject: [PATCH 26/66] add cmake initial --- CMakeLists.txt | 33 ++++++++++++++++++++++++++++----- csrc/sycl/ops.cpp | 4 ++-- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index be0d3555f..c5c74bf53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ # For GCC: `cmake -B build . && cmake --build build` # For MSVC: `cmake -B build . 
&& cmake --build build --config Release` # You can also use the following options and variables -# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend +# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `mps`, or `sycl` to select the backend # - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version # is whatever CMake finds on your path. @@ -28,11 +28,13 @@ set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) +set(SYCL_FILES csrc/sycl/ops.cpp csrc/sycl/kernels.cpp) + # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps, sycl)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps sycl) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -50,6 +52,7 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") option(NO_CUBLASLT "Disable CUBLAS" OFF) set(BUILD_CUDA ON) set(BUILD_MPS OFF) + set(BUILD_SYCL OFF) message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}") elseif(${COMPUTE_BACKEND} STREQUAL "mps") if(NOT APPLE) @@ -57,9 +60,15 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps") endif() set(BUILD_CUDA OFF) set(BUILD_MPS ON) + set(BUILD_SYCL OFF) +elseif(${COMPUTE_BACKEND} STREQUAL "sycl") + set(BUILD_CUDA OFF) + set(BUILD_SYCL ON) + set(BUILD_MPS OFF) else() set(BUILD_CUDA OFF) set(BUILD_MPS OFF) + set(BUILD_SYCL OFF) endif() @@ -177,12 +186,26 @@ elseif(BUILD_MPS) COMMENT "Compiling Metal kernels" VERBATIM) add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +elseif(BUILD_SYCL) + if ( NOT DEFINED ENV{ONEAPI_ROOT}) + message(FATAL_ERROR "Not detect ENV {ONEAPI_ROOT}, please install oneAPI & source it, like: source /opt/intel/oneapi/setvars.sh") + endif() + find_package(IntelSYCL REQUIRED) + set(CMAKE_CXX_STANDARD 17) + add_compile_options(-I./) #include DPCT + add_compile_options(-I/${SYCL_INCLUDE_DIR}) + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") + set(SRC_FILES ${SYCL_FILES}) + target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread oneDNN) else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) endif() - - + if(WIN32) # Export all symbols set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 6403ed841..8e424273c 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1364,8 +1364,8 @@ int roundoff(int v, int d) { } -#ifdef NO_CUBLASLT -#else +//#ifdef NO_CUBLASLT +//#else template cublasLtOrder_t get_order() { switch(ORDER) From 89fb73b59212710c47d23a63cf652ec89edd4a4a Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 8 Apr 2024 23:50:55 -0700 Subject: [PATCH 27/66] cmake fix --- CMakeLists.txt | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c5c74bf53..3984ebd34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ # Separate by semicolons, i.e. 
`-DCOMPUTE_CAPABILITY=89;90` # Check your compute capability here: https://developer.nvidia.com/cuda-gpus # - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler -cmake_minimum_required(VERSION 3.22.1) +cmake_minimum_required(VERSION 3.20.4) project(bitsandbytes LANGUAGES CXX) @@ -200,7 +200,7 @@ elseif(BUILD_SYCL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") set(SRC_FILES ${SYCL_FILES}) - target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread oneDNN) + else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) @@ -218,10 +218,13 @@ endif() set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) add_library(bitsandbytes SHARED ${SRC_FILES}) -target_compile_features(bitsandbytes PUBLIC cxx_std_14) +if(BUILD_SYCL) + target_compile_features(bitsandbytes PUBLIC cxx_std_17) +else() + target_compile_features(bitsandbytes PUBLIC cxx_std_14) +endif() target_include_directories(bitsandbytes PUBLIC csrc include) - if(BUILD_CUDA) target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse) @@ -241,6 +244,9 @@ if(BUILD_MPS) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") endif() +if(BUILD_SYCL) + target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread oneDNN) +endif() if(WIN32) set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") endif() From 2644374f7bcbbd0223df8e70714c49d21411afcd Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 9 Apr 2024 04:04:11 -0700 Subject: [PATCH 28/66] fix build --- CMakeLists.txt | 4 +- csrc/sycl/kernels.cpp | 132 ++++++++++++++++++------------------ csrc/sycl/kernels.h | 22 +++--- csrc/sycl/ops.cpp | 151 +++++++++++++++++++++++------------------- 4 files changed, 163 insertions(+), 146 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3984ebd34..e5487a6c5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,7 +199,7 @@ elseif(BUILD_SYCL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") - set(SRC_FILES ${SYCL_FILES}) + list(APPEND SRC_FILES ${SYCL_FILES}) else() string(APPEND BNB_OUTPUT_NAME "_cpu") @@ -223,7 +223,7 @@ if(BUILD_SYCL) else() target_compile_features(bitsandbytes PUBLIC cxx_std_14) endif() -target_include_directories(bitsandbytes PUBLIC csrc include) +target_include_directories(bitsandbytes PUBLIC csrc csrc/sycl include) if(BUILD_CUDA) target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index ac6c6d383..1f226ac5a 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -632,12 +632,12 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons #define THREADS_ESTIMATE 512 #define NUM_ESTIMATE 8 #define BLOCK_ESTIMATE 4096 -typedef sycl::accessor sycl_la_float; -typedef sycl::accessor sycl_la_T; -typedef sycl::accessor sycl_la_unsigned_char; -typedef sycl::accessor sycl_la_half; -typedef sycl::accessor sycl_la_unsigned; -typedef sycl::accessor sycl_la_char; +typedef sycl::local_accessor sycl_la_float; 
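
The typedef swap here is the heart of the build fix: the dpct group helpers want their scratch in work-group local memory, and a sycl::local_accessor is constructed from an element count plus the handler, with no buffer behind it, unlike a plain sycl::accessor. Sketch:

// Inside a command group scope:
void reserve_scratch(sycl::handler &cgh, size_t scratch_bytes)
{
  // Local scratch: count + handler. A buffer-backed sycl::accessor would
  // instead point at global memory and need a sycl::buffer argument.
  sycl::local_accessor<uint8_t, 1> ltacc(scratch_bytes, cgh);
  (void)ltacc;   // handed into the kernel lambda in the real code
}
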
+typedef sycl::local_accessor sycl_la_T; +typedef sycl::local_accessor sycl_la_unsigned_char; +typedef sycl::local_accessor sycl_la_half; +typedef sycl::local_accessor sycl_la_unsigned; +typedef sycl::local_accessor sycl_la_char; template SYCL_EXTERNAL @@ -686,7 +686,7 @@ void kEstimateQuantiles(const T *buff_A, float *code, const float offset, const // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], vals); + group_load(tmp).load(item, &buff_A[0], vals); #pragma unroll 4 for(int j = 0; j < NUM_ESTIMATE; j++) @@ -793,7 +793,7 @@ void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *b // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], vals); + group_load(tmp).load(item, &buff_A[0], vals); //LoadFloat(loadf).Load(&(A[i]), vals, valid_items); @@ -818,7 +818,7 @@ void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *b // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_out[0], qvals); + group_store(tmp).store(item, &buff_out[0], qvals); //StoreChar(storec).Store(&(out[i]), qvals, valid_items); } } @@ -876,7 +876,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], vals); + group_load(tmp).load(item, &buff_A[0], vals); // 1. compute local max // 2. broadcast local max @@ -925,7 +925,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_rand[0], rand_vals); + group_load(tmp).load(item, &buff_rand[0], rand_vals); } @@ -978,7 +978,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_out[0], qvals); + group_store(tmp).store(item, &buff_out[0], qvals); //StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); } @@ -1037,7 +1037,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, flo // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], qvals); + group_load(tmp).load(item, &buff_A[0], qvals); @@ -1092,7 +1092,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, flo // 6. store with byte index auto *tmp = stacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_out[0], vals); + group_store(tmp).store(item, &buff_out[0], vals); } } @@ -1178,7 +1178,7 @@ void kPreconditionOptimizer32bit2State(T* buff_g, T* p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index auto *tmp = ltacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], g_vals); + group_load(tmp).load(item, &buff_A[0], g_vals); /* DPCT1065:98: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1196,7 +1196,7 @@ void kPreconditionOptimizer32bit2State(T* buff_g, T* p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state[0], s1_vals); + group_load(tmp).load(item, &buff_state[0], s1_vals); /* @@ -1215,7 +1215,7 @@ void kPreconditionOptimizer32bit2State(T* buff_g, T* p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float2.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], s2_vals); + group_load(tmp).load(item, &buff_state2[0], s2_vals); # pragma unroll NUM_VALS for(unsigned int j = 0; j < NUM_VALS; j++) @@ -1325,7 +1325,7 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + group_load(tmp).load(item, &buff_g[0], g_vals); /* @@ -1343,7 +1343,7 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); + group_load(tmp).load(item, &buff_state1[0], s1_vals); /* DPCT1065:106: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1362,7 +1362,7 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float2.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], s2_vals); + group_load(tmp).load(item, &buff_state2[0], s2_vals); /* DPCT1065:107: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1380,7 +1380,7 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_load(tmp).load(item, &buff_p[0], p_vals); @@ -1422,7 +1422,7 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc_T.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_store(tmp).store(item, &buff_p[0], p_vals); /* @@ -1441,7 +1441,7 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); + group_store(tmp).store(item, &buff_state1[0], s1_vals); /* DPCT1065:110: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1458,7 +1458,7 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc_float2.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_state2[0], s2_vals); + group_store(tmp).store(item, &buff_state2[0], s2_vals); } @@ -1513,7 +1513,7 @@ void kPreconditionOptimizer32bit1State(T* buff_g, T* buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + group_load(tmp).load(item, &buff_g[0], g_vals); /* @@ -1532,7 +1532,7 @@ void kPreconditionOptimizer32bit1State(T* buff_g, T* buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_stat1[0], s1_vals); + group_load(tmp).load(item, &buff_stat1[0], s1_vals); # pragma unroll NUM_VALS @@ -1652,7 +1652,7 @@ void kOptimizer32bit1State(T *buff_g, T *buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + group_load(tmp).load(item, &buff_g[0], g_vals); /* @@ -1671,7 +1671,7 @@ void kOptimizer32bit1State(T *buff_g, T *buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); + group_load(tmp).load(item, &buff_state1[0], s1_vals); /* @@ -1690,7 +1690,7 @@ void kOptimizer32bit1State(T *buff_g, T *buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_load(tmp).load(item, &buff_p[0], p_vals); # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) @@ -1747,7 +1747,7 @@ void kOptimizer32bit1State(T *buff_g, T *buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc_T.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_store(tmp).store(item, &buff_p[0], p_vals); /* DPCT1065:127: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -1765,7 +1765,7 @@ void kOptimizer32bit1State(T *buff_g, T *buff_p, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], s1_vals); + group_store(tmp).store(item, &buff_state1[0], s1_vals); } @@ -1850,7 +1850,7 @@ kPreconditionOptimizerStatic8bit2State(T* buff_p, T* __restrict__ const buff_g, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index
auto *tmp = ltacc_T.get_multi_ptr().get();
- group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals);
+ group_load(tmp).load(item, &buff_g[0], g_vals);
/*
DPCT1065:153: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
@@ -1868,7 +1868,7 @@ kPreconditionOptimizerStatic8bit2State(T* buff_p, T* __restrict__ const buff_g,
// 5. Repeat (3) 8 times for top 8 values in 256
// 6. store with byte index
auto *tmp = ltacc_float1.get_multi_ptr().get();
- group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], m_c1);
+ group_load(tmp).load(item, &buff_state1[0], m_c1);
/*
DPCT1065:154: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
@@ -1886,7 +1886,7 @@ kPreconditionOptimizerStatic8bit2State(T* buff_p, T* __restrict__ const buff_g,
// 5. Repeat (3) 8 times for top 8 values in 256
// 6. store with byte index
auto *tmp = ltacc_float2.get_multi_ptr().get();
- group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], r_c2);
+ group_load(tmp).load(item, &buff_state2[0], r_c2);
/*
DPCT1065:155: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
@@ -2059,7 +2059,7 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state
// 5. Repeat (3) 8 times for top 8 values in 256
// 6. store with byte index
auto *tmp = ltacc_T.get_multi_ptr().get();
- group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals);
+ group_load(tmp).load(item, &buff_g[0], g_vals);
/*
@@ -2078,7 +2078,7 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state
// 5. Repeat (3) 8 times for top 8 values in 256
// 6. store with byte index
auto *tmp = ltacc_float1.get_multi_ptr().get();
- group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], c1s);
+ group_load(tmp).load(item, &buff_state1[0], c1s);
/*
@@ -2097,7 +2097,7 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state
// 5. Repeat (3) 8 times for top 8 values in 256
// 6. store with byte index
auto *tmp = ltacc_float2.get_multi_ptr().get();
- group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], c2s);
+ group_load(tmp).load(item, &buff_state2[0], c2s);
/*
@@ -2116,7 +2116,7 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state
// 5. Repeat (3) 8 times for top 8 values in 256
// 6. store with byte index
auto *tmp = ltacc_T1.get_multi_ptr().get();
- group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals);
+ group_load(tmp).load(item, &buff_p[0], p_vals);
if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
@@ -2168,7 +2168,7 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state
// 5. Repeat (3) 8 times for top 8 values in 256
// 6. store with byte index
auto *tmp = stacc_T.get_multi_ptr().get();
- group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals);
+ group_store(tmp).store(item, &buff_p[0], p_vals);
/*
@@ -2187,7 +2187,7 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state
// 5. Repeat (3) 8 times for top 8 values in 256
// 6.
store with byte index auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], c1s); + group_store(tmp).store(item, &buff_state1[0], c1s); /* @@ -2206,7 +2206,7 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc_float2.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_state2[0], c2s); + group_store(tmp).store(item, &buff_state2[0], c2s); /* @@ -2289,7 +2289,7 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + group_load(tmp).load(item, &buff_g[0], g_vals); /* DPCT1065:136: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2307,7 +2307,7 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], m_c1); + group_load(tmp).load(item, &buff_state1[0], m_c1); #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) @@ -2439,7 +2439,7 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + group_load(tmp).load(item, &buff_g[0], g_vals); /* @@ -2458,7 +2458,7 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], c1s); + group_load(tmp).load(item, &buff_state1[0], c1s); /* @@ -2477,7 +2477,7 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_load(tmp).load(item, &buff_p[0], p_vals); if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } @@ -2545,7 +2545,7 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc_T.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_store(tmp).store(item, &buff_p[0], p_vals); /* DPCT1065:143: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2563,7 +2563,7 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], c1s); + group_store(tmp).store(item, &buff_state1[0], c1s); /* DPCT1065:144: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2610,7 +2610,7 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ buff_g, float *gnorm_vec // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], vals); + group_load(tmp).load(item, &buff_g[0], vals); #pragma unroll NUM_VALS for(int j = 0; j < NUM_VALS; j++) @@ -2745,7 +2745,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + group_load(tmp).load(item, &buff_g[0], g_vals); /* DPCT1065:176: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2763,7 +2763,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], c1s); + group_load(tmp).load(item, &buff_state1[0], c1s); /* DPCT1065:177: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -2781,7 +2781,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float2.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state2[0], c2s); + group_load(tmp).load(item, &buff_state2[0], c2s); new_local_abs_max1 = -FLT_MAX; @@ -2857,7 +2857,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_load(tmp).load(item, &buff_p[0], p_vals); // reduce: 2.67/1.69 -> 2.67/1.70 # pragma unroll N_PER_TH @@ -2889,7 +2889,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc_T.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_store(tmp).store(item, &buff_p[0], p_vals); // quantizaztion: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH @@ -2925,7 +2925,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], c1s); + group_store(tmp).store(item, &buff_state1[0], c1s); /* @@ -2944,7 +2944,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc_float2.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_state2[0], c2s); + group_store(tmp).store(item, &buff_state2[0], c2s); } } @@ -3042,7 +3042,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_g[0], g_vals); + group_load(tmp).load(item, &buff_g[0], g_vals); /* DPCT1065:192: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -3060,7 +3060,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_state1[0], c1s); + group_load(tmp).load(item, &buff_state1[0], c1s); /* DPCT1065:193: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. @@ -3078,7 +3078,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = ltacc_T1.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_load(tmp).load(item, &buff_p[0], p_vals); new_local_abs_max1 = -FLT_MAX; @@ -3190,7 +3190,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = stacc_T.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_p[0], p_vals); + group_store(tmp).store(item, &buff_p[0], p_vals); // quantizaztion: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH @@ -3225,7 +3225,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_state1[0], c1s); + group_store(tmp).store(item, &buff_state1[0], c1s); } } @@ -3316,7 +3316,7 @@ template().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], local_data); + group_load(tmp).load(item, &buff_A[0], local_data); #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) @@ -3551,7 +3551,7 @@ template void kdequant_mm_i [=](sycl::nd_item<3> item) { auto *d = dacc.get_multi_ptr().get(); auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item,item.get_local_linear_id(), d, local_values); + group_load(tmp).load(item, d, local_values); }); }); @@ -3681,7 +3681,7 @@ template ().get(); - group_load(tmp).load(item,item.get_local_linear_id(), &buff_A[0], local_data); + group_load(tmp).load(item, &buff_A[0], local_data); float row_stat = 127.0f / smem_row_stats[row]; @@ -3724,7 +3724,7 @@ template ().get(); - group_store(tmp).store(item,item.get_local_linear_id(), &buff_out_row_normed[0], local_quantized_data); + group_store(tmp).store(item, &buff_out_row_normed[0], local_quantized_data); // 2. quantize data with row/col stats #pragma unroll ITEMS_PER_THREAD @@ -3762,7 +3762,7 @@ template item) { auto *d = dacc.get_multi_ptr().get(); auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item,item.get_local_linear_id(), d, local_quantized_data); + group_store(tmp).store(item, d, local_quantized_data); }); }); diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 018507610..5f1d4166a 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -6,19 +6,19 @@ #include #include #include -#include "ops.dp.hpp" +#include "ops.h" #ifndef kernels #define kernels #pragma once -typedef sycl::accessor sycl_la_float; -typedef sycl::accessor sycl_la_T; -typedef sycl::accessor sycl_la_unsigned_char; -typedef sycl::accessor sycl_la_half; -typedef sycl::accessor sycl_la_unsigned; -typedef sycl::accessor sycl_la_char; +typedef sycl::local_accessor sycl_la_float; +typedef sycl::local_accessor sycl_la_T; +typedef sycl::local_accessor sycl_la_unsigned_char; +typedef sycl::local_accessor sycl_la_half; +typedef sycl::local_accessor sycl_la_unsigned; +typedef sycl::local_accessor sycl_la_char; //template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); @@ -26,7 +26,7 @@ template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict_ const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc); extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_float ltacc, sycl_la_unsigned_char stacc); + const sycl::nd_item<3> &item_ct1, float* smem_code, sycl_la_float ltacc, sycl_la_unsigned_char stacc); extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code); @@ -49,7 +49,7 @@ extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2 + const sycl::nd_item<3> 
&item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2,
sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2);
template
@@ -65,7 +65,7 @@ extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n,
const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1,
- sycl_la_T stacc_T, sycl_la_float, stacc_float1);
+ sycl_la_T stacc_T, sycl_la_float stacc_float1);
template
extern SYCL_EXTERNAL void
@@ -138,7 +138,7 @@ template extern SYCL_EX
float *smem_exchange1, float *smem_exchange2,
sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2,
- sycl_la_T stacc_T, sycl_la_float1 stacc_float1, sycl_la_float stacc_float2);
+ sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2);
template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1,
diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp
index 8e424273c..34ec419e5 100644
--- a/csrc/sycl/ops.cpp
+++ b/csrc/sycl/ops.cpp
@@ -63,7 +63,8 @@ template void estimateQuantiles(T *A, float *code, float offset, in
sycl::context ctx = q_ct1.get_context();
int size = NUM_BLOCK;
- *((T **)&buff_A) = sycl::malloc_device(size, A, ctx);
+ T *buff_A;
+ *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx);
q_ct1.memcpy((T*)(buff_A), (T*)(A), NUM_BLOCK);
//sycl::buffer buff_A(A,sycl::range<1>(num_blocks));
@@ -78,7 +79,7 @@ template void estimateQuantiles(T *A, float *code, float offset, in
using group_load = dpct::group::workgroup_load;
using group_radix_sort = dpct::group::radix_sort;
size_t sort_temp_storage_size = group_radix_sort::get_local_memory_size(NUM_BLOCK);
- sycl::local_accessor tacc(temp_storage_size, cgh);
+ sycl::local_accessor tacc(sort_temp_storage_size, cgh);
/*
DPCT1054:293: The type of variable temp_storage is declared in device function with the name type_ct1. Adjust the code to make the type_ct1 declaration visible at the accessor declaration point.
*/
@@ -108,10 +109,12 @@ void quantize(float *code, float *A, unsigned char *out, int n)
sycl::context ctx = q_ct1.get_context();
int size = NUM_BLOCK;
- *((float **)&buff_A) = sycl::malloc_device(size, A, ctx);
- *((unsigned char **)&buff_out = sycl::malloc_device(size, out, ctx);
- q_ct1.memcpy((float*)(buff_A), (float*)(A), NUM_BLOCK);
- q_ct1.memcpy((unsigned char*)(buff_out), (unsigned char*)(out), NUM_BLOCK);
+ float *buff_A;
+ unsigned char *buff_out;
+ *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx);
+ q_ct1.memcpy((void*)(buff_A), (void*)(A), NUM_BLOCK);
+ q_ct1.memcpy((void*)(buff_out), (void*)(out), NUM_BLOCK);
/*
DPCT1049:55: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed.
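These ops.cpp hunks all follow the same USM staging pattern: declare a raw device pointer, allocate it with sycl::malloc_device, memcpy the host data in, run the kernel, then memcpy the result back, destination first and source second, as in the //back memcpy blocks. A minimal self-contained sketch of that round trip, using illustrative names rather than the patch's buff_*/q_ct1/dev_ct1 variables:

#include <sycl/sycl.hpp>
#include <vector>

int main() {
    sycl::queue q;
    std::vector<float> host(256, 1.0f);
    // stage the host data in a USM device allocation
    float *buff = sycl::malloc_device<float>(host.size(), q);
    q.memcpy(buff, host.data(), host.size() * sizeof(float)).wait();
    // operate on the staged device copy
    q.parallel_for(sycl::range<1>(host.size()),
                   [=](sycl::id<1> i) { buff[i] *= 2.0f; }).wait();
    // back memcpy: destination first, then source
    q.memcpy(host.data(), buff, host.size() * sizeof(float)).wait();
    sycl::free(buff, q);
    return 0;
}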
@@ -139,8 +142,8 @@ void quantize(float *code, float *A, unsigned char *out, int n) }); } //back memcpy - q_ct1.memcpy((float*)(A), (float*)(buff_A), NUM_BLOCK); - q_ct1.memcpy((unsigned char*)(out), (unsigned char*)(buff_out), NUM_BLOCK); + q_ct1.memcpy((void*)(A), (void *)(buff_A), NUM_BLOCK); + q_ct1.memcpy((void*)(out), (void*)(buff_out), NUM_BLOCK); /* DPCT1010:232: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ @@ -156,10 +159,12 @@ void dequantize(float *code, unsigned char *A, float *out, int n) sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - *((unsigned char **)&buff_A) = sycl::malloc_device(size, A, ctx); - *((float **)&buff_out = sycl::malloc_device(size, out, ctx); - q_ct1.memcpy((float*)(buff_out), (float*)(out), NUM_BLOCK); - q_ct1.memcpy((unsigned char*)(buff_A), (unsigned char*)(A), NUM_BLOCK); + unsigned char *buff_A; + float *buff_out; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_out), (void*)(out), NUM_BLOCK); + q_ct1.memcpy((void*)(buff_A), (void*)(A), NUM_BLOCK); /* DPCT1049:56: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. @@ -182,8 +187,8 @@ void dequantize(float *code, unsigned char *A, float *out, int n) } //back memcpy - q_ct1.memcpy((float*)(out), (float*)(buff_out), NUM_BLOCK); - q_ct1.memcpy((unsigned char*)(A), (unsigned char*)(buff_A), NUM_BLOCK); + q_ct1.memcpy((void *)(out), (void*)(buff_out), NUM_BLOCK); + q_ct1.memcpy((void*)(A), (void*)(buff_A), NUM_BLOCK); /* DPCT1010:233: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ @@ -199,14 +204,17 @@ template void quantizeBlockwise(floa sycl::context ctx = q_ct1.get_context(); int size= NUM_BLOCK; - *((T **)&buff_A) = sycl::malloc_device(size, A, ctx); - *((unsigned char **)&buff_out = sycl::malloc_device(size, out, ctx); - *((float **)&buff_rand = sycl::malloc_device(size, rand, ctx); + T *buff_A; + unsigned char *buff_out; + float *buff_rand; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); + *((void **)&buff_rand) = sycl::malloc_device(size, dev_ct1, ctx); q_ct1.memcpy((T*)(buff_A), (T*)(A), NUM_BLOCK); - q_ct1.memcpy((unsigned char*)(buff_out), (unsigned char*)(out), NUM_BLOCK); - q_ct1.memcpy((float*)(buff_rand), (float*)(rand), NUM_BLOCK); + q_ct1.memcpy((void*)(buff_out), (void*)(out), NUM_BLOCK); + q_ct1.memcpy((void*)(buff_rand), (void*)(rand), NUM_BLOCK); - for(int i=0; i< NUM_BLOCK; i++){ buff_out[i]=buff_out[(DATA_TYPE > 0) ? i/2 : i]}; + for(int i=0; i< NUM_BLOCK; i++){ buff_out[i]=buff_out[(DATA_TYPE > 0) ? i/2 : i];}; if(blocksize == 4096) /* @@ -418,9 +426,9 @@ template void quantizeBlockwise(floa DPCT1010:234: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
*/ //back memcpy - q_ct1.memcpy((T*)(A), (T*)(buff_A), NUM_BLOCK); - q_ct1.memcpy((unsigned char*)(out), (unsigned char*)(buff_out), NUM_BLOCK); - q_ct1.memcpy((float*)(rand), (float*)(buff_rand), NUM_BLOCK); + q_ct1.memcpy((void*)(A), (void*)(buff_A), NUM_BLOCK); + q_ct1.memcpy((void*)(out), (void*)(buff_out), NUM_BLOCK); + q_ct1.memcpy((void*)(rand), (void*)(buff_rand), NUM_BLOCK); //CUDA_CHECK_RETURN(0); } @@ -433,9 +441,11 @@ template void dequantizeBlockwise(float *code, unsign int tile_size = (DATA_TYPE > 0) ? 1024 : 512; sycl::context ctx = q_ct1.get_context(); - *((unsigned char **)&buff_A) = sycl::malloc_device(tile_size, A, ctx); - *((T **)&buff_out = sycl::malloc_device(tile_size, out, ctx); - q_ct1.memcpy((unsigned char*)(buff_A), (unsigned char*)(A), tile_size); + unsigned char *buff_A; + T *buff_out; + *((void **)&buff_A) = sycl::malloc_device(tile_size, dev_ct1, ctx); + *((T **)&buff_out) = sycl::malloc_device(tile_size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_A), (void*)(A), tile_size); q_ct1.memcpy((T*)(buff_out), (T*)(out), tile_size); if(DATA_TYPE > 0) @@ -485,7 +495,7 @@ template void dequantizeBlockwise(float *code, unsign DPCT1010:235: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ //back memcpy - q_ct1.memcpy((unsigned char*)(A), (unsigned char*)(buff_A), tile_size); + q_ct1.memcpy((void*)(A), (void*)(buff_A), tile_size); q_ct1.memcpy((T*)(out), (T*)(buff_out), tile_size); //CUDA_CHECK_RETURN(0); @@ -512,14 +522,16 @@ template void optimizer32bit(T* g, T* p, num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; int size= NUM_BLOCK; - *((T **)&buff_g) = sycl::malloc_device(size, g, ctx); - *((T **)&buff_p) = sycl::malloc_device(size, p, ctx); - *((float **)&buff_state1 = sycl::malloc_device(size, state1, ctx); - *((float **)&buff_state2 = sycl::malloc_device(size, state2, ctx); - q_ct1.memcpy((T*)(buff_g), (T*)(g), size); - q_ct1.memcpy((T*)(buff_p), (T*)(p), size); - q_ct1.memcpy((float*)(buff_state1), (float*)(state1), size); - q_ct1.memcpy((float*)(buff_state2), (float*)(state2), size); + T *buff_g,*buff_p; + float *buff_state1,*buff_state2; + *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx); + *((void **)&buff_p) = sycl::malloc_device(size, dev_ct1, ctx); + *((void **)&buff_state1) = sycl::malloc_device(size, dev_ct1, ctx); + *((void **)&buff_state2) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_g), (void*)(g), size); + q_ct1.memcpy((void*)(buff_p), (void*)(p), size); + q_ct1.memcpy((void*)(buff_state1), (void*)(state1), size); + q_ct1.memcpy((void*)(buff_state2), (void*)(state2), size); switch(OPTIMIZER) @@ -785,10 +797,10 @@ template void optimizer32bit(T* g, T* p, } //back memcpy - q_ct1.memcpy((T*)(g), (T*)(buff_g), size); - q_ct1.memcpy((T*)(p), (T*)(buff_p), size); - q_ct1.memcpy((float*)(state1), (float*)(buff_state1), size); - q_ct1.memcpy((float*)(state2), (float*)(buff_state2), size); + q_ct1.memcpy((void*)(g), (void*)(buff_g), size); + q_ct1.memcpy((void*)(p), (void*)(buff_p), size); + q_ct1.memcpy((void*)(state1), (void*)(buff_state1), size); + q_ct1.memcpy((void*)(state2), (void*)(buff_state2), size); } catch (sycl::exception const &exc) { @@ -813,14 +825,16 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - *((T **)&buff_g) = sycl::malloc_device(size, g, ctx); - *((T **)&buff_p) = sycl::malloc_device(size, p, ctx); - *((float **)&buff_state1 = 
sycl::malloc_device(size, state1, ctx);
- *((float **)&buff_state2 = sycl::malloc_device(size, state2, ctx);
- q_ct1.memcpy((T*)(buff_g), (T*)(g), size);
- q_ct1.memcpy((T*)(buff_p), (T*)(p), size);
- q_ct1.memcpy((float*)(buff_state1), (float*)(state1), size);
- q_ct1.memcpy((float*)(buff_state2), (float*)(state2), size);
+ T *buff_g,*buff_p;
+ float *buff_state1,*buff_state2;
+ *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_p) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_state1) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_state2) = sycl::malloc_device(size, dev_ct1, ctx);
+ q_ct1.memcpy((void*)(buff_g), (void*)(g), size);
+ q_ct1.memcpy((void*)(buff_p), (void*)(p), size);
+ q_ct1.memcpy((void*)(buff_state1), (void*)(state1), size);
+ q_ct1.memcpy((void*)(buff_state2), (void*)(state2), size);
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); }
@@ -1093,10 +1107,10 @@ template void optimizerStatic8bit(T* buff_p, T* buff_
}
//back memcpy
- q_ct1.memcpy((T*)(buff_g), (T*)(g), size);
- q_ct1.memcpy((T*)(buff_p), (T*)(p), size);
- q_ct1.memcpy((float*)(buff_state1), (float*)(state1), size);
- q_ct1.memcpy((float*)(buff_state2), (float*)(state2), size);
+ q_ct1.memcpy((void*)(g), (void*)(buff_g), size);
+ q_ct1.memcpy((void*)(p), (void*)(buff_p), size);
+ q_ct1.memcpy((void*)(state1), (void*)(buff_state1), size);
+ q_ct1.memcpy((void*)(state2), (void*)(buff_state2), size);
}
catch (sycl::exception const &exc)
{
@@ -1120,14 +1134,16 @@ template void optimizerStatic8bitBlockwise(T* buff_p,
int num_blocks = 0;
int size = NUM_BLOCK;
- *((T **)&buff_g) = sycl::malloc_device(size, g, ctx);
- *((T **)&buff_p) = sycl::malloc_device(size, p, ctx);
- *((float **)&buff_state1 = sycl::malloc_device(size, state1, ctx);
- *((float **)&buff_state2 = sycl::malloc_device(size, state2, ctx);
- q_ct1.memcpy((T*)(buff_g), (T*)(g), size);
- q_ct1.memcpy((T*)(buff_p), (T*)(p), size);
- q_ct1.memcpy((float*)(buff_state1), (float*)(state1), size);
- q_ct1.memcpy((float*)(buff_state2), (float*)(state2), size);
+ T *buff_g,*buff_p;
+ float *buff_state1,*buff_state2;
+ *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_p) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_state1) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_state2) = sycl::malloc_device(size, dev_ct1, ctx);
+ q_ct1.memcpy((void*)(buff_g), (void*)(g), size);
+ q_ct1.memcpy((void*)(buff_p), (void*)(p), size);
+ q_ct1.memcpy((void*)(buff_state1), (void*)(state1), size);
+ q_ct1.memcpy((void*)(buff_state2), (void*)(state2), size);
switch(OPTIMIZER)
@@ -1252,10 +1268,10 @@ template void optimizerStatic8bitBlockwise(T* buff_p,
//CUDA_CHECK_RETURN(0);
break;
}
- q_ct1.memcpy((T*)(g), (T*)(buff_g), size);
- q_ct1.memcpy((T*)(p), (T*)(buff_p), size);
- q_ct1.memcpy((float*)(state1), (float*)(buff_state1), size);
- q_ct1.memcpy((float*)(state2), (float*)(buff_state2), size);
+ q_ct1.memcpy((void*)(g), (void*)(buff_g), size);
+ q_ct1.memcpy((void*)(p), (void*)(buff_p), size);
+ q_ct1.memcpy((void*)(state1), (void*)(buff_state1), size);
+ q_ct1.memcpy((void*)(state2), (void*)(buff_state2), size);
}
catch (sycl::exception const &exc)
{
@@ -1274,8 +1290,9 @@ template void percentileClipping(T * g, float *gnorm_vec, int step,
int num_blocks = n/2048;
num_blocks = n % 2048 == 0 ?
num_blocks : num_blocks + 1; int size = NUM_BLOCK; - *((T **)&buff_g) = sycl::malloc_device(size, g, ctx); - q_ct1.memcpy((T*)(buff_g), (T*)(g), size); + T *buff_g; + *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_g), (void*)(g), size); CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait())); @@ -1302,7 +1319,7 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, */ //CUDA_CHECK_RETURN(0); //back memcpy - q_ct1.memcpy((T*)(g), (T*)(buff_g), size); + q_ct1.memcpy((void*)(g), (void*)(buff_g), size); } @@ -1397,7 +1414,7 @@ template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); -#endif +//#endif template int get_leading_dim(int dim1, int dim2) From 14f20e44ffed6fee6a99f594fd4a7580577190de Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 9 Apr 2024 04:30:40 -0700 Subject: [PATCH 29/66] fix build --- CMakeLists.txt | 1 + csrc/sycl/ops.cpp | 182 ++++++++++++++++++++++++---------------------- 2 files changed, 96 insertions(+), 87 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e5487a6c5..1a187c709 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,6 +199,7 @@ elseif(BUILD_SYCL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ferror-limit=590") list(APPEND SRC_FILES ${SYCL_FILES}) else() diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 34ec419e5..8b4e1db35 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -761,7 +761,8 @@ template void optimizer32bit(T* g, T* p, if(max_unorm > 0.0f) { - CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); + //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); /* DPCT1049:64: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ @@ -1460,7 +1461,7 @@ template void trans int ldOut = get_leading_dim(dim1, dim2); int ldAOut = get_leading_dim(dim1, dim2); - dnnl::engine engine = sycL_interop::make_engine(dev, ctx); + dnnl::engine engine = dnnl::sycl_interop::make_engine(dev, ctx); // column major const memory::dims a_strides = memory::dims {1, ldA}; const auto a_md = DTYPE_OUT ==32 ? 
memory::desc({dim1, dim2}, dt::s32, a_strides) : memory::desc({dim1, dim2}, dt::s8, a_strides);
@@ -1863,13 +1864,14 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o
sycl::context ctx = q_ct1.get_context();
int num_blocks = 0;
int size = NUM_BLOCK;
-
- *((sycl::half **)&buff_A) = sycl::malloc_device(size, A, ctx);
- *((char **)&buff_out_row_normed) = sycl::malloc_device(size, out_row_normed, ctx);
- *((char **)&buff_out_col_normed = sycl::malloc_device(size, out_col_normed, ctx);
- q_ct1.memcpy((sycl::half*)(buff_A), (sycl::half*)(A), size);
- q_ct1.memcpy((char*)(buff_out_row_normed), (char*)(out_row_normed), size);
- q_ct1.memcpy((char*)(buff_out_col_normed), (char*)(out_col_normed), size);
+ sycl::half *buff_A;
+ char *buff_out_row_normed, *buff_out_col_normed;
+ *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_out_row_normed) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_out_col_normed) = sycl::malloc_device(size, dev_ct1, ctx);
+ q_ct1.memcpy((void*)(buff_A), (void*)(A), size);
+ q_ct1.memcpy((void*)(buff_out_row_normed), (void*)(out_row_normed), size);
+ q_ct1.memcpy((void*)(buff_out_col_normed), (void*)(out_col_normed), size);
int threads = 64;
int items_per_thread = 4;
@@ -1940,9 +1942,9 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o
DPCT1010:285: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
*/
//CUDA_CHECK_RETURN(0);
- q_ct1.memcpy((sycl::half*)(A), (sycl::half*)(buff_A), size);
- q_ct1.memcpy((char*)(out_row_normed), (char*)(buff_out_row_normed), size);
- q_ct1.memcpy((char*)(out_col_normed), (char*)(buff_out_col_normed), size);
+ q_ct1.memcpy((void*)(A), (void*)(buff_A), size);
+ q_ct1.memcpy((void*)(out_row_normed), (void*)(buff_out_row_normed), size);
+ q_ct1.memcpy((void*)(out_col_normed), (void*)(buff_out_col_normed), size);
}
@@ -1954,11 +1956,11 @@ template void transformRowToFormat(char * A, char *o
sycl::context ctx = q_ct1.get_context();
int num_blocks = 0;
int size = NUM_BLOCK;
-
- *((char **)&buff_A) = sycl::malloc_device(size, A, ctx);
- *((char **)&buff_out) = sycl::malloc_device(size, out, ctx);
- q_ct1.memcpy((char*)(buff_A), (sycl::half*)(A), size);
- q_ct1.memcpy((char*)(buff_out), (char*)(out), size);
+ char *buff_A, *buff_out;
+ *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx);
+ *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx);
+ q_ct1.memcpy((void*)(buff_A), (void*)(A), size);
+ q_ct1.memcpy((void*)(buff_out), (void*)(out), size);
int threads = 256;
@@ -2022,8 +2024,8 @@ template void transformRowToFormat(char * A, char *o
//CUDA_CHECK_RETURN(0);
- q_ct1.memcpy((char*)(A), (sycl::half*)(buff_A), size);
- q_ct1.memcpy((char*)(out), (char*)(buff_out), size);
+ q_ct1.memcpy((void*)(A), (void*)(buff_A), size);
+ q_ct1.memcpy((void*)(out), (void*)(buff_out), size);
}
@@ -2068,7 +2070,7 @@ void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_v
descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major);
// allocate an external buffer if needed
//CHECK_CUSPARSE( DPCT_CHECK_ERROR(bufferSize = 0) );
- bufferSize = 0
+ bufferSize = 0;
//CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1)) );
dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1);
@@ -2164,13 +2166,14 @@ template void gemm_host(int m, int n, int k,
T * A, T* B, T * out sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - *((T **)&buff_A) = sycl::malloc_device(size, A, ctx); - q_ct1.memcpy((T*)(buff_A), (T*)(A), size); - *((T **)&buff_B) = sycl::malloc_device(size, B, ctx); - q_ct1.memcpy((T*)(buff_B), (T*)(B), size); - *((T **)&buff_out) = sycl::malloc_device(size, out, ctx); - q_ct1.memcpy((T*)(buff_out), (T*)(out), size); - + T *buff_A, *buff_B, *buff_out; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_A), (void*)(A), size); + *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_B), (void*)(B), size); + *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_out), (void*)(out), size); + //cout << num_blocks << endl; //cout << lda << endl; @@ -2212,9 +2215,9 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //back memcpy - q_ct1.memcpy((T*)(A), (T*)(buff_A), size); - q_ct1.memcpy((T*)(B), (T*)(buff_B), size); - q_ct1.memcpy((T*)(out), (T*)(buff_out), size); + q_ct1.memcpy((void*)(A), (void*)(buff_A), size); + q_ct1.memcpy((void*)(B), (void*)(buff_B), size); + q_ct1.memcpy((void*)(out), (void*)(buff_out), size); } @@ -2237,12 +2240,15 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - *((T **)&buff_A) = sycl::malloc_device(size, A, ctx); - q_ct1.memcpy((T*)(buff_A), (T*)(A), size); - *(( unsigned char**)&buff_B) = sycl::malloc_device(size, B, ctx); - q_ct1.memcpy((unsigned char*)(buff_B), (unsigned char*)(B), size); - *((T **)&buff_out) = sycl::malloc_device(size, out, ctx); - q_ct1.memcpy((T*)(buff_out), (T*)(out), size); + T *buff_A, *buff_out; + unsigned char *buff_B; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_A), (void*)(A), size); + *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_B), (void*)(B), size); + *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_out), (void*)(out), size); + { dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); @@ -2272,9 +2278,9 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //back memcpy - q_ct1.memcpy((T*)(A), (T*)(buff_A), size); - q_ct1.memcpy((unsigned char*)(B), (unsigned char*)(buff_B), size); - q_ct1.memcpy((T*)(out), (T*)(buff_out), size); + q_ct1.memcpy((void*)(A), (void*)(buff_A), size); + q_ct1.memcpy((void*)(B), (void*)(buff_B), size); + q_ct1.memcpy((void*)(out), (void*)(buff_out), size); } @@ -2287,12 +2293,14 @@ template void gemm_4bit_inference_naive(int m, int n, int sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - *((T **)&buff_A) = sycl::malloc_device(size, A, ctx); - q_ct1.memcpy((T*)(buff_A), (T*)(A), size); - *(( unsigned char**)&buff_B) = sycl::malloc_device(size, B, ctx); - q_ct1.memcpy((unsigned char*)(buff_B), (unsigned char*)(B), size); - *((T **)&buff_out) = sycl::malloc_device(size, out, ctx); - q_ct1.memcpy((T*)(buff_out), (T*)(out), size); + T *buff_A, *buff_out; + 
unsigned char *buff_B; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_A), (void*)(A), size); + *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_B), (void*)(B), size); + *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_out), (void*)(out), size); { dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); @@ -2311,9 +2319,9 @@ template void gemm_4bit_inference_naive(int m, int n, int DPCT1010:291: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ //CUDA_CHECK_RETURN(0); - q_ct1.memcpy((T*)(A), (T*)(buff_A), size); - q_ct1.memcpy((unsigned char*)(B), (unsigned char*)(buff_B), size); - q_ct1.memcpy((T*)(out), (T*)(buff_out), size); + q_ct1.memcpy((void*)(A), (void*)(buff_A), size); + q_ct1.memcpy((void*)(B), (void*)(buff_B), size); + q_ct1.memcpy((void*)(out), (void*)(buff_out), size); } @@ -2346,18 +2354,18 @@ template void func(unsigned char *A, unsigned char *B, unsi template void func(float *A, float *B, float value, long n); template void func(float *A, float *B, float value, long n); -template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, sycl::half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, bfloat16 * A, unsigned char* B, float *absmax, float *datatype, bfloat16 * out, int lda, int ldb, int ldc, int blocksize); template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); -template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int 
*max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
@@ -2373,31 +2381,31 @@ template void transformRowToFormat(char * A, char *out, int rows,
template void transformRowToFormat(char * A, char *out, int rows, int cols);
template void transformRowToFormat(char * A, char *out, int rows, int cols);
-template void estimateQuantiles(half *A, float *code, float offset, int n);
+template void estimateQuantiles(sycl::half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
-template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise(float * code, sycl::half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise(float * code, sycl::half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise(float * code, sycl::half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise(float * code, sycl::half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float
*absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); -template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); -template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, bfloat16 *out, int blocksize, const int n); #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, \ @@ -2405,17 +2413,17 @@ template void optimizer32bit(gtype* g, gtype* p, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -MAKE_optimizer32bit(ADAM, half) +MAKE_optimizer32bit(ADAM, sycl::half) MAKE_optimizer32bit(ADAM, float) -MAKE_optimizer32bit(ADAM, __nv_bfloat16) -MAKE_optimizer32bit(MOMENTUM, half) +MAKE_optimizer32bit(ADAM, bfloat16) +MAKE_optimizer32bit(MOMENTUM, sycl::half) MAKE_optimizer32bit(MOMENTUM, float) -MAKE_optimizer32bit(RMSPROP, half) +MAKE_optimizer32bit(RMSPROP, sycl::half) 
MAKE_optimizer32bit(RMSPROP, float) -MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, sycl::half) MAKE_optimizer32bit(LION, float) -MAKE_optimizer32bit(LION, __nv_bfloat16) -MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(LION, bfloat16) +MAKE_optimizer32bit(ADAGRAD, sycl::half) MAKE_optimizer32bit(ADAGRAD, float) #define MAKE_optimizerStatic8bit(name, gtype) \ @@ -2428,13 +2436,13 @@ template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char float weight_decay, \ const float gnorm_scale, int n); \ -MAKE_optimizerStatic8bit(ADAM, half) +MAKE_optimizerStatic8bit(ADAM, sycl::half) MAKE_optimizerStatic8bit(ADAM, float) -MAKE_optimizerStatic8bit(MOMENTUM, half) +MAKE_optimizerStatic8bit(MOMENTUM, sycl::half) MAKE_optimizerStatic8bit(MOMENTUM, float) -MAKE_optimizerStatic8bit(RMSPROP, half) +MAKE_optimizerStatic8bit(RMSPROP, sycl::half) MAKE_optimizerStatic8bit(RMSPROP, float) -MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, sycl::half) MAKE_optimizerStatic8bit(LION, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ @@ -2444,17 +2452,17 @@ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g MAKE_optimizerStatic8bitBlockwise(half, ADAM); MAKE_optimizerStatic8bitBlockwise(float, ADAM); -MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(sycl::half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); -MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(sycl::half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); -MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(sycl::half, LION); MAKE_optimizerStatic8bitBlockwise(float, LION); -MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION); -MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(sycl::half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); -template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(sycl::half * g, float *gnorm_vec, int step, const int n); -MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); +MAKE_optimizerStatic8bitBlockwise(bfloat16, ADAM); From d151d0c0ef9d60d819b2253b0f54be13ad1f0dbc Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 10 Apr 2024 03:53:31 -0700 Subject: [PATCH 30/66] fix build --- CMakeLists.txt | 3 +- csrc/sycl/kernels.cpp | 107 ++++------ csrc/sycl/kernels.h | 13 +- csrc/sycl/ops.cpp | 464 +++++++++++++----------------------------- 4 files changed, 184 insertions(+), 403 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1a187c709..5332a5536 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -200,6 +200,7 @@ elseif(BUILD_SYCL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ferror-limit=590") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") list(APPEND SRC_FILES ${SYCL_FILES}) else() @@ -246,7 +247,7 @@ if(BUILD_MPS) endif() if(BUILD_SYCL) - target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread oneDNN) + target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread onednn mkl_dnn) endif() if(WIN32) 
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 1f226ac5a..fe56651b7 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -1595,7 +1595,7 @@ void kOptimizer32bit1State(T *buff_g, T *buff_p, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, - sycl_la_T stacc_T, sycl_la_float, stacc_float1) + sycl_la_T stacc_T, sycl_la_float stacc_float1) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -3433,7 +3433,7 @@ template void kgetColRowStats(sycl::half * __res #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) -template void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats) +template void kdequant_mm_int32_fp16(int *__restrict__ const buff_A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, sycl_la_T ltacc_T, sycl_la_float exacc) { // Strategy: To dequantize we need to load col/row statistics. This can be very expensive @@ -3490,8 +3490,7 @@ template void kdequant_mm_i //typedef cub::BlockLoad LoadInt32; //typedef cub::BlockExchange ExchangeInt32; - sycl::buffer buff_A(A, sycl::range<1>(ITEMS_PER_THREAD)); - + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. float colStat = col >= numCols ? 0.0f : colStats[col]; @@ -3532,54 +3531,29 @@ template void kdequant_mm_i DPCT1007:206: Migration of cub::BlockLoad::Load is not supported. */ //LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_load = dpct::group::workgroup_load; - size_t temp_storage_size = group_load::get_local_memory_size(ITEMS_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_A[subtile_idx], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item, d, local_values); - }); - - }); - /* + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + auto *tmp = ltacc_T.get_multi_ptr().get(); + group_load(tmp).load(item, &buff_A[0], local_values); + + /* DPCT1007:207: Migration of cub::BlockExchange::BlockedToWarpStriped is not supported. */ //ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - - - using group_exchange = dpct::group::exchange; - size_t temp_storage_size = group_exchange::get_local_memory_size(ITEMS_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *tmp = tacc.get_multi_ptr().get(); - group_exchange(tmp).blocked_to_warpstriped(item, local_values); - }); - - }); + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + auto *tmp = exacc.get_multi_ptr().get(); + group_exchange(tmp).blocked_to_warpstriped(item, local_values); #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) @@ -3743,29 +3717,16 @@ template ; - size_t temp_storage_size = group_store::get_local_memory_size(ITEMS_PER_THREAD); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_out_col_normed[i], h, sycl::read_write); - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); - auto *tmp = tacc.get_multi_ptr().get(); - group_store(tmp).store(item, d, local_quantized_data); - }); - - }); + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. 
store with byte index + + auto *tmp = stacc_char2.get_multi_ptr().get(); + group_store(tmp).store(item, &buff_out_col_normed[0], local_quantized_data); + } } @@ -5335,7 +5296,7 @@ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_T stacc_T, sycl_la_float, stacc_float1); \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_T stacc_T, sycl_la_float stacc_float1); \ SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, float) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 5f1d4166a..8060eaec3 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -65,7 +65,7 @@ extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, - sycl_la_T stacc_T, sycl_la_float,sycl_la_float stacc_float1); + sycl_la_T stacc_T, sycl_la_float stacc_float1); template extern SYCL_EXTERNAL void @@ -159,7 +159,7 @@ template extern SYCL_EX template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T); -void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, +extern SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, const sycl::nd_item<3> &item_ct1); template @@ -177,7 +177,7 @@ extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( float *__restrict__ const colStats, sycl::half *out, float *newRowStats, float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_rowStats); + const sycl::nd_item<3> &item_ct1, float *smem_rowStats, sycl_la_T ltacc_T, sycl_la_float exacc); template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1,float *smem_row_absmax_values,int *smem_row_nnz_values, sycl_la_half ltacc_half, sycl_la_unsigned exacc); @@ -202,8 +202,11 @@ template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1); -template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); -template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * 
__restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B); +template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, + sycl::half *smem_A, + sycl::half *smem_B, + sycl::half *smem_C); template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map); diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 8b4e1db35..f87ab6e8a 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -28,7 +28,9 @@ #define THREADS_ESTIMATE 512 #define NUM_ESTIMATE 8 #define BLOCK_ESTIMATE 4096 +using namespace dnnl; +typedef sycl::ext::oneapi::bfloat16 bf16; using namespace BinSearch; using std::cout; @@ -59,7 +61,8 @@ template void estimateQuantiles(T *A, float *code, float offset, in sycl::queue &q_ct1 = dev_ct1.in_order_queue(); int num_blocks = n/4096; num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(code, 0, 256*sizeof(float)).wait())); + //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(code, 0, 256*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(code, 0, 256*sizeof(float)).wait()); sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; @@ -444,7 +447,7 @@ template void dequantizeBlockwise(float *code, unsign unsigned char *buff_A; T *buff_out; *((void **)&buff_A) = sycl::malloc_device(tile_size, dev_ct1, ctx); - *((T **)&buff_out) = sycl::malloc_device(tile_size, dev_ct1, ctx); + *((void **)&buff_out) = sycl::malloc_device(tile_size, dev_ct1, ctx); q_ct1.memcpy((void*)(buff_A), (void*)(A), tile_size); q_ct1.memcpy((T*)(buff_out), (T*)(out), tile_size); @@ -486,7 +489,7 @@ template void dequantizeBlockwise(float *code, unsign cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { - kDequantizeBlockwise(code, buff_A, absmax, buff_out, blocksize, n, item_ct1); + kDequantizeBlockwise(code, buff_A, absmax, buff_out, blocksize, n, item_ct1, ltacc, stacc); }); }); } @@ -539,7 +542,8 @@ template void optimizer32bit(T* g, T* p, case ADAM: if(max_unorm > 0.0f) { - CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); + //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); /* DPCT1049:61: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ @@ -632,7 +636,8 @@ template void optimizer32bit(T* g, T* p, case ADAGRAD: if(max_unorm > 0.0f) { - CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); + //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); /* DPCT1049:62: The work-group size passed to the SYCL kernel may exceed the limit. 
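   A minimal sketch of such a runtime check (illustrative; only q_ct1 and the
   512-wide launch shape come from the surrounding code):

       size_t max_wg = q_ct1.get_device().get_info<sycl::info::device::max_work_group_size>();
       size_t wg = std::min<size_t>(512, max_wg);  // clamp the requested work-group size
       auto launch = sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, wg),
                                       sycl::range<3>(1, 1, wg));
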
To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ @@ -658,7 +663,7 @@ template void optimizer32bit(T* g, T* p, cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float); + kPreconditionOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1); }); }); } @@ -809,8 +814,8 @@ catch (sycl::exception const &exc) { std::exit(1); } -template void optimizerStatic8bit(T* buff_p, T* buff_g, - unsigned char* buff_state1, unsigned char* buff_state2, +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, float *unorm, float max_unorm, float param_norm, float beta1, float beta2, float eps, int step, float lr, @@ -827,7 +832,7 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ int size = NUM_BLOCK; T *buff_g,*buff_p; - float *buff_state1,*buff_state2; + unsigned char *buff_state1,*buff_state2; *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx); *((void **)&buff_p) = sycl::malloc_device(size, dev_ct1, ctx); *((void **)&buff_state1) = sycl::malloc_device(size, dev_ct1, ctx); @@ -838,13 +843,16 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ q_ct1.memcpy((void*)(buff_state2), (void*)(state2), size); - if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); } + if(max_unorm > 0.0f){ //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); } switch(OPTIMIZER) { case ADAM: - CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); - CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait())); + //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); + //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); + DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait()); { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( @@ -890,8 +898,8 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); - sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); + //sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); /* DPCT1054:301: The type of variable temp_storage is declared in device function with the name type_ct7. Adjust the code to make the type_ct7 declaration visible at the accessor declaration point. 
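   One way to make the declaration visible, sketched under the assumption that
   the scratch area is only ever addressed as raw bytes (type_ct7 itself is
   generated by DPCT inside the device function and is not visible here;
   temp_bytes stands for the storage size computed by the load/store helpers):

       // Reserve the shared scratch as untyped bytes at handler scope...
       sycl::local_accessor<uint8_t, 1> temp_storage_acc(sycl::range<1>(temp_bytes), cgh);
       // ...and take a raw pointer inside the kernel, which the kernel can
       // then reinterpret as its own temp-storage type.
       auto *scratch = temp_storage_acc.get_multi_ptr<sycl::access::decorated::no>().get();
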
*/ @@ -943,7 +951,8 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ case MOMENTUM: case RMSPROP: case ADAGRAD: - CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); + //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( @@ -962,7 +971,7 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); //__shared__ vars - sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh) + sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), @@ -1063,7 +1072,7 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1State(`buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); + kOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); }); }); } @@ -1072,7 +1081,8 @@ template void optimizerStatic8bit(T* buff_p, T* buff_ */ //CUDA_CHECK_RETURN(0); - CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); + //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( @@ -1124,8 +1134,8 @@ catch (sycl::exception const &exc) { #define BLOCKSIZE_1STATE 2048 #define NUM_1STATE 8 -template void optimizerStatic8bitBlockwise(T* buff_p, T* buff_g, - unsigned char* buff_state1, unsigned char* buff_state2, float beta1, float beta2, float eps, int step, float lr, +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) try { @@ -1136,11 +1146,11 @@ template void optimizerStatic8bitBlockwise(T* buff_p, int size = NUM_BLOCK; T *buff_g,*buff_p; - float *buff_state1,*buff_state2; + unsigned char *buff_state1,*buff_state2; *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx); *((void **)&buff_p) = sycl::malloc_device(size, dev_ct1, ctx); *((void **)&buff_state1) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_state2) = sycl::malloc_device(size, dev_Ct1, ctx); + *((void **)&buff_state2) = sycl::malloc_device(size, dev_ct1, ctx); q_ct1.memcpy((void*)(buff_g), (void*)(g), size); q_ct1.memcpy((void*)(buff_p), (void*)(p), size); q_ct1.memcpy((void*)(buff_state1), (void*)(state1), size); @@ -1296,7 +1306,8 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, 
q_ct1.memcpy((void*)(buff_g), (void*)(g), size); - CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait())); + //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait()); /* DPCT1049:68: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ @@ -1311,7 +1322,7 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPercentileClipping(g, gnorm_vec, step, n, item_ct1, ltacc); + kPercentileClipping(g, gnorm_vec, step, n, item_ct1, ltacc_T); }); }); } @@ -1336,12 +1347,8 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in const void * beta = &fbeta; int status; - status = DPCT_CHECK_ERROR(dpct::gemm(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, B, dpct::library_data_t::real_int8, ldb, beta, C, dpct::library_data_t::real_int32, ldc, dpct::library_data_t::real_int32)); + DPCT_CHECK_ERROR(dpct::gemm(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, B, dpct::library_data_t::real_int8, ldb, beta, C, dpct::library_data_t::real_int32, ldc, dpct::library_data_t::real_int32)); - if (status != 0) - { - std::cout << "CUBLAS ERROR: Status " << status << std::endl; - } } catch (sycl::exception const &exc) { @@ -1364,12 +1371,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i //printf("%i %i %i\n", strideA, strideB, strideC); //printf("%i\n", batchCount); - status = DPCT_CHECK_ERROR(dpct::gemm_batch(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, (long long int)strideA, B, dpct::library_data_t::real_int8, ldb, (long long int)strideB, beta, C, dpct::library_data_t::real_int32, ldc, (long long int)strideC, batchCount, dpct::library_data_t::real_int32)); - - if (status != 0) - { - std::cout << "CUBLAS ERROR: Status " << status << std::endl; - } + DPCT_CHECK_ERROR(dpct::gemm_batch(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? 
oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, (long long int)strideA, B, dpct::library_data_t::real_int8, ldb, (long long int)strideB, beta, C, dpct::library_data_t::real_int32, ldc, (long long int)strideC, batchCount, dpct::library_data_t::real_int32)); } catch (sycl::exception const &exc) { @@ -1382,42 +1384,6 @@ int roundoff(int v, int d) { } -//#ifdef NO_CUBLASLT -//#else -template cublasLtOrder_t get_order() -{ - switch(ORDER) - { - case ROW: - return CUBLASLT_ORDER_ROW; - break; - case COL: - return CUBLASLT_ORDER_COL; - break; - case COL32: - return CUBLASLT_ORDER_COL32; - break; - case COL_TURING: - return CUBLASLT_ORDER_COL4_4R2_8C; - break; - case COL_AMPERE: - return CUBLASLT_ORDER_COL32_2R_4R4; - break; - default: - break; - } - - return CUBLASLT_ORDER_ROW; -} - -template cublasLtOrder_t get_order(); -template cublasLtOrder_t get_order(); -template cublasLtOrder_t get_order(); -template cublasLtOrder_t get_order(); -template cublasLtOrder_t get_order(); -//#endif - - template int get_leading_dim(int dim1, int dim2) { switch(ORDER) @@ -1449,12 +1415,13 @@ template int get_leading_dim(int dim1, int dim2); template int get_leading_dim(int dim1, int dim2); template int get_leading_dim(int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) +template void transform( T *A, T *out, int dim1, int dim2) { using namespace dnnl; using tag = memory::format_tag; using dt = memory::data_type; + void *Aout; auto dev = sycl::device(sycl::gpu_selector_v); auto ctx = sycl::context(dev); int ldA = get_leading_dim(dim1, dim2); @@ -1464,16 +1431,16 @@ template void trans dnnl::engine engine = dnnl::sycl_interop::make_engine(dev, ctx); // column major const memory::dims a_strides = memory::dims {1, ldA}; - const auto a_md = DTYPE_OUT ==32 ? memory::desc({dim1, dim2}, dt::s32, a_strides) : memory::desc({dim1, dim2}, dt::s8, a_strides); + const auto a_md = DTYPE ==32 ? memory::desc({dim1, dim2}, dt::s32, a_strides) : memory::desc({dim1, dim2}, dt::s8, a_strides); const memory::dims out_strides = memory::dims {ldOut, 1}; - const auto out_md = DTYPE_OUT ==32 ? memory::desc({dim1, dim2}, dt::s32, out_strides) : memory::desc({dim1, dim2}, dt::s8, out_strides); + const auto out_md = DTYPE ==32 ? memory::desc({dim1, dim2}, dt::s32, out_strides) : memory::desc({dim1, dim2}, dt::s8, out_strides); const memory::dims Aout_strides = memory::dims {ldAOut, 1}; - const auto aout_md = DTYPE_OUT == 32 ? memory::desc({dim1, dim2}, dt::s32) : memory::desc({dim1, dim2}, dt::s8); + const auto aout_md = DTYPE == 32 ? 
memory::desc({dim1, dim2}, dt::s32) : memory::desc({dim1, dim2}, dt::s8); //memory align - memory a_mem(a_md, engine A); - memory out_mem(out_md, engine, Out); - memory aout_mem(aout_md, engine, AOut); + memory a_mem(a_md, engine, A); + memory out_mem(out_md, engine, out); + memory aout_mem(aout_md, engine, Aout); //create dnnl stream auto q_ct1 = sycl::queue(ctx, dev); @@ -1491,98 +1458,20 @@ template void trans matmul_prim.execute(stream, matmul_args); stream.wait(); -/* -#ifdef NO_CUBLASLT -#else - cublasLtOrder_t orderA = get_order(); - cublasLtOrder_t orderOut = get_order(); - int ldA = get_leading_dim(dim1, dim2); - int ldOut = get_leading_dim(dim1, dim2); - - cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL; - cublasLtMatrixTransformDesc_t A2Out_desc = NULL; - oneapi::mkl::transpose opTranspose = oneapi::mkl::transpose::trans; - float transformAlpha = 1.0f, transformBeta = 0.0f; - - - - - if(DTYPE == 8) - { - - DPCT1007:251: Migration of cublasLtMatrixLayoutCreate is not supported. - - checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, dpct::library_data_t::real_int8, dim1, dim2, ldA)); - - DPCT1007:252: Migration of cublasLtMatrixLayoutCreate is not supported. - - checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, dpct::library_data_t::real_int8, dim1, dim2, ldOut)); - } - else if(DTYPE == 32) - { - - DPCT1007:253: Migration of cublasLtMatrixLayoutCreate is not supported. - - checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, dpct::library_data_t::real_int32, dim1, dim2, ldA)); - - DPCT1007:254: Migration of cublasLtMatrixLayoutCreate is not supported. - - checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, dpct::library_data_t::real_int32, dim1, dim2, ldOut)); - } - else - { - printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); - } - - - DPCT1007:255: Migration of cublasLtMatrixLayoutSetAttribute is not supported. - - checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); - - DPCT1007:256: Migration of cublasLtMatrixLayoutSetAttribute is not supported. - - checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); - - - DPCT1007:257: Migration of cublasLtMatrixTransformDescCreate is not supported. - - checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, dpct::library_data_t::real_float)); - - - DPCT1007:258: Migration of cublasLtMatrixTransformDescSetAttribute is not supported. - - if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } - - checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0)); - - - DPCT1007:259: Migration of cublasLtMatrixLayoutDestroy is not supported. - - if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); - - DPCT1007:260: Migration of cublasLtMatrixLayoutDestroy is not supported. - - if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); - - DPCT1007:261: Migration of cublasLtMatrixTransformDescDestroy is not supported. 
- - if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); -#endif -*/ } -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int32_t *A, int32_t *out, int dim1, int dim2); +template void transform( int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int8_t *A, int8_t *out, int dim1, int dim2); +template void transform( int32_t *A, int32_t *out, int dim1, int dim2); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) try { - using namespace dnnl; + using tag = memory::format_tag; using dt = memory::data_type; auto dev = sycl::device(sycl::gpu_selector_v); @@ -1595,17 +1484,17 @@ template int igemmlt(cublasLtHandle const memory::dims b_strides = memory::dims {ldb, 1}; const auto b_md = memory::desc({k, n}, dt::s8, b_strides); const memory::dims c_strides = memory::dims {ldc, 1}; - const auto c_md = DTYPE_OUT == 32 ? memory::desc({m, n}, dt::s32, c_strides) : memory::desc({m, n}, dt::s8, c_strides); + const auto c_md = DTYPE == 32 ? memory::desc({m, n}, dt::s32, c_strides) : memory::desc({m, n}, dt::s8, c_strides); //memory align - memory a_mem(a_md, engine A); + memory a_mem(a_md, engine, A); memory b_mem(b_md, engine, B); memory c_mem(c_md, engine, C); memory scales_C_mem({{1}, dt::f32, {1}}, engine, row_scale); //create dnnl stream auto q_ct1 = sycl::queue(ctx, dev); - dnnl::stream stream = sycl_interop::make_stream(q_ct1); + dnnl::stream stream = dnnl::sycl_interop::make_stream(q_ct1); primitive_attr attr; if (SCALE_ROWS) { @@ -1625,127 +1514,6 @@ template int igemmlt(cublasLtHandle matmul_prim.execute(stream, matmul_args); stream.wait(); -//#ifdef NO_CUBLASLT -// return ERR_NOT_IMPLEMENTED; -//#else - //int has_error = 0; - //cublasLtMatmulDesc_t matmulDesc = NULL; - //cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; - //oneapi::mkl::transpose opT = oneapi::mkl::transpose::trans; - //cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; - //cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; - //cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; - //cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; - - /* - DPCT1007:262: Migration of cublasLtMatrixLayoutCreate is not supported. 
- */ - //has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, dpct::library_data_t::real_int8, m, k, lda)); - /* - DPCT1007:263: Migration of cublasLtMatrixLayoutCreate is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, dpct::library_data_t::real_int8, n, k, ldb)); - - /* - DPCT1007:264: Migration of cublasLtMatrixLayoutSetAttribute is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - //if(FORMATB == COL_TURING) - /* - DPCT1007:265: Migration of cublasLtMatrixLayoutSetAttribute is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); - //else - /* - DPCT1007:266: Migration of cublasLtMatrixLayoutSetAttribute is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); - - //if(DTYPE_OUT == 32) - //{ - /* - DPCT1007:267: Migration of cublasLtMatmulDescCreate is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, dpct::library_data_t::real_int32)); - /* - DPCT1007:268: Migration of cublasLtMatmulDescSetAttribute is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); - /* - DPCT1007:269: Migration of cublasLtMatrixLayoutCreate is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, dpct::library_data_t::real_int32, m, n, ldc)); - /* - DPCT1007:270: Migration of cublasLtMatrixLayoutSetAttribute is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - //int alpha = 1, beta = 0; - /* - DPCT1007:271: Migration of cublasLtMatmul is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); - //} - //else - //{ - /* - DPCT1007:272: Migration of cublasLtMatmulDescCreate is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, dpct::library_data_t::real_float)); - /* - DPCT1007:273: Migration of cublasLtMatmulDescSetAttribute is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); - /* - DPCT1007:274: Migration of cublasLtMatrixLayoutCreate is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, dpct::library_data_t::real_int8, m, n, ldc)); - /* - DPCT1007:275: Migration of cublasLtMatrixLayoutSetAttribute is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - //if(!SCALE_ROWS) - //{ - //float alpha = 1.0f, beta = 0.0f; - /* - DPCT1007:276: Migration of cublasLtMatmul is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); - //} - //else - //{ - /* - DPCT1007:277: Migration of cublasLtMatmulDescSetAttribute is not supported. 
- */ - //has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); - /* - DPCT1007:278: Migration of cublasLtMatmul is not supported. - */ - //has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, &q_ct1)); - //} - //} - - - /* - DPCT1007:279: Migration of cublasLtMatrixLayoutDestroy is not supported. - */ - //if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); - /* - DPCT1007:280: Migration of cublasLtMatrixLayoutDestroy is not supported. - */ - //if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); - /* - DPCT1007:281: Migration of cublasLtMatrixLayoutDestroy is not supported. - */ - //if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); - /* - DPCT1007:282: Migration of cublasLtMatmulDescDestroy is not supported. - */ - //if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); - //if(has_error == 1) - //printf("error detected"); - - //return has_error; -//#endif // NO_CUBLASLT } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; @@ -1768,12 +1536,43 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; num_blocks = num_blocks*(tileCols/32); assert(threads <= tilesize); + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); - kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); + int size= NUM_BLOCK; + int *buff_A; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_A), (void*)(A), size); + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load_T = dpct::group::workgroup_load; + using group_exchange = dpct::group::exchange; + size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); + size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK); + + sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); + sycl::local_accessor exacc(exchange_temp_storage_size, cgh); + + + //__shared__ vars + sycl::local_accessor smem_rowStats_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), + [=](sycl::nd_item<3> item_ct1) { + kdequant_mm_int32_fp16<4, 128, 512>(buff_A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n, item_ct1,smem_rowStats_acc_ct1.get_pointer(), ltacc_T, exacc ); + }); + + }); /* DPCT1010:283: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
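   The pattern already used elsewhere in this file is a try/catch around the
   queue work; a minimal sketch (ptr and bytes are placeholders):

       try {
           q_ct1.memset(ptr, 0, bytes).wait(); // any call formerly wrapped in CUDA_CHECK_RETURN
       } catch (sycl::exception const &exc) {
           std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
           std::exit(1);
       }
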
*/ - CUDA_CHECK_RETURN(0); + //CUDA_CHECK_RETURN(0); + q_ct1.memcpy((void*)(A), (void*)(buff_A), size); } @@ -1800,8 +1599,9 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c int num_blocks = row_tiles * col_tiles; int size = NUM_BLOCK; - *((sycl::half **)&buff_A) = sycl::malloc_device(size, A, ctx); - q_ct1.memcpy((sycl::half*)(buff_A), (sycl::half*)(A), size); + sycl::half *buff_A; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_A), (void*)(A), size); if(nnz_threshold == 0.0) @@ -1818,12 +1618,17 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c sycl::local_accessor exacc(exchange_temp_storage_size, cgh); sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); + + //__shared__ vars + sycl::local_accessor smem_row_absmax_values_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_row_nnz_values_acc_ct1(sycl::range<1>(256), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { kgetColRowStats(buff_A, rowStats, colStats, - nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols, ltacc_half, exacc); + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), ltacc_half, exacc); }); }); } @@ -1840,12 +1645,17 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c sycl::local_accessor exacc(exchange_temp_storage_size, cgh); sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); + + //__shared__ vars + sycl::local_accessor smem_row_absmax_values_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_row_nnz_values_acc_ct1(sycl::range<1>(256), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { kgetColRowStats(buff_A, rowStats, colStats, - nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols, ltacc_half, exacc); + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), ltacc_half, exacc); }); }); } @@ -1854,6 +1664,8 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c DPCT1010:284: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. */ //CUDA_CHECK_RETURN(0); + q_ct1.memcpy((void*)(A), (void*)(buff_A), size); + } @@ -1864,7 +1676,7 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::context ctx = q_ct1.get_context(); int num_blocks = 0; int size = NUM_BLOCK; - sycl::half *buff_A, + sycl::half *buff_A; char *buff_out_row_normed, *buff_out_col_normed; *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); *((void **)&buff_out_row_normed) = sycl::malloc_device(size, dev_ct1, ctx); @@ -1883,7 +1695,7 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o int col_tiles = (tiledCols/tile_cols); row_tiles = row_tiles > 0 ? row_tiles : 1; col_tiles = col_tiles > 0 ? 
col_tiles : 1; - int num_blocks = row_tiles * col_tiles; + num_blocks = row_tiles * col_tiles; if(threshold > 0.0f) @@ -1902,12 +1714,16 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); sycl::local_accessor stacc_char1(store_temp_storage_size_char1, cgh); sycl::local_accessor stacc_char2(store_temp_storage_size_char2, cgh); + + //__shared__ vars + sycl::local_accessor smem_row_stats_acc_ct1(sycl::range<1>(256), cgh); + sycl::local_accessor smem_nnz_row_idx_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kDoubleRowColQuant(buff_A, rowStats, colStats, buff_out_col_normed, buff_out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, ltacc_half, stacc_char1, stacc_char2); + kDoubleRowColQuant(buff_A, rowStats, colStats, buff_out_col_normed, buff_out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), ltacc_half, stacc_char1, stacc_char2); }); }); } @@ -1933,7 +1749,7 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, ltacc_half, stacc_char1, stacc_char2); + kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols,item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), ltacc_half, stacc_char1, stacc_char2); }); }); @@ -1974,7 +1790,7 @@ template void transformRowToFormat(char * A, char *o int col_tiles = (tiledCols/tile_cols); row_tiles = row_tiles > 0 ? row_tiles : 1; col_tiles = col_tiles > 0 ? 
col_tiles : 1; - int num_blocks = row_tiles * col_tiles; + num_blocks = row_tiles * col_tiles; int outCols = fill_up_to_nearest_multiple(cols, 32); int outRows = fill_up_to_nearest_multiple(rows, 32); @@ -2354,25 +2170,25 @@ template void func(unsigned char *A, unsigned char *B, unsi template void func(float *A, float *B, float value, long n); template void func(float *A, float *B, float value, long n); -template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, sycl::half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, bfloat16 * A, unsigned char* B, float *absmax, float *datatype, bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference(int m, int n, int k, sycl::half * A, unsigned char* B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, sycl::half * A, unsigned char* B, float *absmax, float *datatype, sycl::half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, bf16 * A, unsigned char* B, float *absmax, float *datatype, bf16 * out, int lda, int ldb, int ldc, int blocksize); template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); -template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, sycl::half * A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, ycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, 
int lda, int ldb, int ldc);
-template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
-template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);

 template void transformRowToFormat(char * A, char *out, int rows, int cols);
 template void transformRowToFormat(char * A, char *out, int rows, int cols);
@@ -2392,10 +2208,10 @@ template void quantizeBlockwise(float * code, float *A, f
 template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
 template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
 template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<bfloat16, 1, General8bit>(float * code, bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<bfloat16, 0, General8bit>(float * code, bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<bfloat16, 0, FP4>(float * code, bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<bfloat16, 0, NF4>(float * code, bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<bf16, 1, General8bit>(float * code, bf16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<bf16, 0, General8bit>(float * code, bf16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<bf16, 0, FP4>(float * code, bf16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<bf16, 0, NF4>(float * code, bf16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);

 template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
 template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
@@ -2403,9 +2219,9 @@ template void dequantizeBlockwise(float *code, unsigned char *A, flo
 template void dequantizeBlockwise<sycl::half, General8bit>(float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n);
 template void dequantizeBlockwise<sycl::half, FP4>(float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n);
 template void dequantizeBlockwise<sycl::half, NF4>(float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n);
-template void dequantizeBlockwise<bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, bfloat16 *out, int blocksize, const int n);
-template void dequantizeBlockwise<bfloat16, FP4>(float *code, unsigned char *A, float *absmax, bfloat16 *out, int blocksize, const int n);
-template void dequantizeBlockwise<bfloat16, NF4>(float *code, unsigned char *A, float *absmax, bfloat16 *out, int blocksize, const int n);
+template void dequantizeBlockwise<bf16, General8bit>(float *code, unsigned char *A, float *absmax, bf16 *out, int blocksize, const int n);
+template void dequantizeBlockwise<bf16, FP4>(float *code, unsigned char *A, float *absmax, bf16 *out, int blocksize, const int n);
+template void dequantizeBlockwise<bf16, NF4>(float *code, unsigned char *A, float *absmax, bf16 *out, int blocksize, const int n);

 #define MAKE_optimizer32bit(name, gtype) \
 template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
@@ -2415,14 +2231,14 @@ MAKE_optimizer32bit(ADAM, sycl::half)
 MAKE_optimizer32bit(ADAM, float)
-MAKE_optimizer32bit(ADAM, bfloat16)
+MAKE_optimizer32bit(ADAM, bf16)
 MAKE_optimizer32bit(MOMENTUM, sycl::half)
 MAKE_optimizer32bit(MOMENTUM, float)
 MAKE_optimizer32bit(RMSPROP, sycl::half)
 MAKE_optimizer32bit(RMSPROP, float)
 MAKE_optimizer32bit(LION, sycl::half)
 MAKE_optimizer32bit(LION, float)
-MAKE_optimizer32bit(LION, bfloat16)
+MAKE_optimizer32bit(LION, bf16)
 MAKE_optimizer32bit(ADAGRAD, sycl::half)
 MAKE_optimizer32bit(ADAGRAD, float)
@@ -2450,7 +2266,7 @@ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g
 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \

-MAKE_optimizerStatic8bitBlockwise(half, ADAM);
+MAKE_optimizerStatic8bitBlockwise(sycl::half, ADAM);
 MAKE_optimizerStatic8bitBlockwise(float, ADAM);
 MAKE_optimizerStatic8bitBlockwise(sycl::half, MOMENTUM);
@@ -2458,11 +2274,11 @@ MAKE_optimizerStatic8bitBlockwise(sycl::half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(sycl::half, LION);
 MAKE_optimizerStatic8bitBlockwise(float, LION);
-MAKE_optimizerStatic8bitBlockwise(bfloat16, LION);
+MAKE_optimizerStatic8bitBlockwise(bf16, LION);
 MAKE_optimizerStatic8bitBlockwise(sycl::half, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);

 template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
 template void percentileClipping(sycl::half * g, float *gnorm_vec, int step, const int n);

-MAKE_optimizerStatic8bitBlockwise(bfloat16, ADAM);
+MAKE_optimizerStatic8bitBlockwise(bf16, ADAM);

From dfcd9d8fcd5240e31511719aa63be5e352c9de2e Mon Sep 17 00:00:00 2001
From: abhilash1910
Date: Wed, 10 Apr 2024 04:48:10 -0700
Subject: [PATCH 31/66] fix dnnl

---
 csrc/sycl/ops.cpp | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp
index f87ab6e8a..54449d7b9 100644
--- a/csrc/sycl/ops.cpp
+++ b/csrc/sycl/ops.cpp
@@ -1434,8 +1434,8 @@ template void trans
 const auto a_md = DTYPE ==32 ?
From dfcd9d8fcd5240e31511719aa63be5e352c9de2e Mon Sep 17 00:00:00 2001
From: abhilash1910
Date: Wed, 10 Apr 2024 04:48:10 -0700
Subject: [PATCH 31/66] fix dnnl

---
 csrc/sycl/ops.cpp | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp
index f87ab6e8a..54449d7b9 100644
--- a/csrc/sycl/ops.cpp
+++ b/csrc/sycl/ops.cpp
@@ -1434,8 +1434,8 @@ template void trans
 const auto a_md = DTYPE ==32 ? memory::desc({dim1, dim2}, dt::s32, a_strides) : memory::desc({dim1, dim2}, dt::s8, a_strides);
 const memory::dims out_strides = memory::dims {ldOut, 1};
 const auto out_md = DTYPE ==32 ? memory::desc({dim1, dim2}, dt::s32, out_strides) : memory::desc({dim1, dim2}, dt::s8, out_strides);
-const memory::dims Aout_strides = memory::dims {ldAOut, 1};
-const auto aout_md = DTYPE == 32 ? memory::desc({dim1, dim2}, dt::s32) : memory::desc({dim1, dim2}, dt::s8);
+const memory::dims aout_strides = memory::dims {ldAOut, 1};
+const auto aout_md = DTYPE == 32 ? memory::desc({dim1, dim2}, dt::s32, aout_strides) : memory::desc({dim1, dim2}, dt::s8, aout_strides);
 
 //memory align
 memory a_mem(a_md, engine, A);
@@ -1477,14 +1477,14 @@ template int igemmlt( int m, int n,
 auto dev = sycl::device(sycl::gpu_selector_v);
 auto ctx = sycl::context(dev);
-dnnl::engine engine = sycL_interop::make_engine(dev, ctx);
+dnnl::engine engine = dnnl::sycl_interop::make_engine(dev, ctx);
 // column major
 const memory::dims a_strides = memory::dims {1, lda};
 const auto a_md = memory::desc({m, k}, dt::s8, a_strides);
 const memory::dims b_strides = memory::dims {ldb, 1};
 const auto b_md = memory::desc({k, n}, dt::s8, b_strides);
 const memory::dims c_strides = memory::dims {ldc, 1};
-const auto c_md = DTYPE == 32 ? memory::desc({m, n}, dt::s32, c_strides) : memory::desc({m, n}, dt::s8, c_strides);
+const auto c_md = DTYPE_OUT == 32 ? memory::desc({m, n}, dt::s32, c_strides) : memory::desc({m, n}, dt::s8, c_strides);
 
 //memory align
 memory a_mem(a_md, engine, A);
@@ -1723,7 +1723,7 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o
 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
 [=](sycl::nd_item<3> item_ct1) {
-  kDoubleRowColQuant(buff_A, rowStats, colStats, buff_out_col_normed, buff_out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), ltacc_half, stacc_char1, stacc_char2);
+  kDoubleRowColQuant(buff_A, rowStats, colStats, buff_out_col_normed, buff_out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), ltacc_half, stacc_char1, stacc_char2);
 });
 });
 }
@@ -1744,12 +1744,16 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o
 sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh);
 sycl::local_accessor stacc_char1(store_temp_storage_size_char1, cgh);
 sycl::local_accessor stacc_char2(store_temp_storage_size_char2, cgh);
+//__shared__ vars
+sycl::local_accessor smem_row_stats_acc_ct1(sycl::range<1>(256), cgh);
+sycl::local_accessor smem_nnz_row_idx_acc_ct1(sycl::range<1>(256), cgh);
+
 cgh.parallel_for(
 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
 [=](sycl::nd_item<3> item_ct1) {
-  kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols,item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), ltacc_half, stacc_char1, stacc_char2);
+  kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols,item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), ltacc_half, stacc_char1, stacc_char2);
 });
 });
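The fix above supplies explicit strides to the aout_md descriptor and corrects the interop namespace to dnnl::sycl_interop. For orientation, a minimal standalone sketch of the oneDNN-on-SYCL pattern this code relies on; names (m, n, k, lda, ldb, ldc) mirror the hunk, but this is an illustrative sketch under those assumptions, not the repository's implementation:

    #include <sycl/sycl.hpp>
    #include <oneapi/dnnl/dnnl.hpp>
    #include <oneapi/dnnl/dnnl_sycl.hpp>

    using dt = dnnl::memory::data_type;

    // Minimal s8 x s8 -> s32 matmul with explicit strides, mirroring the
    // engine/descriptor setup in the hunk above. Illustrative only.
    void igemm_sketch(int m, int n, int k,
                      const int8_t *A, const int8_t *B, int32_t *C,
                      int lda, int ldb, int ldc) {
        auto dev = sycl::device(sycl::gpu_selector_v);
        auto ctx = sycl::context(dev);
        sycl::queue q(ctx, dev);

        dnnl::engine engine = dnnl::sycl_interop::make_engine(dev, ctx);
        dnnl::stream stream = dnnl::sycl_interop::make_stream(engine, q);

        // Strides must be passed explicitly; omitting them (as the removed
        // aout_md line did) silently falls back to a default dense layout.
        const dnnl::memory::desc a_md({m, k}, dt::s8, dnnl::memory::dims{1, lda});
        const dnnl::memory::desc b_md({k, n}, dt::s8, dnnl::memory::dims{ldb, 1});
        const dnnl::memory::desc c_md({m, n}, dt::s32, dnnl::memory::dims{ldc, 1});

        dnnl::memory a_mem(a_md, engine, const_cast<int8_t*>(A));
        dnnl::memory b_mem(b_md, engine, const_cast<int8_t*>(B));
        dnnl::memory c_mem(c_md, engine, C);

        auto pd = dnnl::matmul::primitive_desc(engine, a_md, b_md, c_md);
        dnnl::matmul(pd).execute(stream, {{DNNL_ARG_SRC, a_mem},
                                          {DNNL_ARG_WEIGHTS, b_mem},
                                          {DNNL_ARG_DST, c_mem}});
        stream.wait();
    }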
From 886751c2ed821ebe544198546a3a56d435f9bf4f Mon Sep 17 00:00:00 2001
From: abhilash1910
Date: Fri, 10 May 2024 05:40:15 -0700
Subject: [PATCH 32/66] refactor with new api and use accessor

---
 csrc/sycl/kernels.cpp | 509 +++++++++++++++---------------------------
 csrc/sycl/ops.cpp     | 373 ++++++++++---------------------
 2 files changed, 293 insertions(+), 589 deletions(-)

diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp
index fe56651b7..4819558e8 100644
--- a/csrc/sycl/kernels.cpp
+++ b/csrc/sycl/kernels.cpp
@@ -632,35 +632,34 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons
 #define THREADS_ESTIMATE 512
 #define NUM_ESTIMATE 8
 #define BLOCK_ESTIMATE 4096
-typedef sycl::local_accessor sycl_la_float;
-typedef sycl::local_accessor sycl_la_T;
-typedef sycl::local_accessor sycl_la_unsigned_char;
-typedef sycl::local_accessor sycl_la_half;
-typedef sycl::local_accessor sycl_la_unsigned;
-typedef sycl::local_accessor sycl_la_char;
+
+
+//================typedefs===================================
+
+typedef sycl::local_accessor sycl_la;
+typedef sycl::accessor sycl_dacc;
+typedef sycl::accessor sycl_dacc_float;
+typedef sycl::accessor sycl_dacc_uc;
+
+//===========================================================
 
 template
 SYCL_EXTERNAL
-void kEstimateQuantiles(const T *buff_A, float *code, const float offset, const T max_val, const int n,
-                        const sycl::nd_item<3> &item_ct1, sycl_la_T tacc)
+void kEstimateQuantiles(const T *A, float *code, const float offset, const T max_val, const int n,
+                        const sycl::nd_item<3> &item_ct1, sycl_la tacc, sycl_dacc dacc)
 {
   const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE);
   int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*BLOCK_ESTIMATE) : BLOCK_ESTIMATE;
   const int base_idx = (item_ct1.get_group(2) * BLOCK_ESTIMATE);
   const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE));
-
+
+  using group_load = dpct::group::workgroup_load>;
+  using group_radix_sort = dpct::group::radix_sort;
+
   T vals[NUM_ESTIMATE];
+  auto *d = dacc.get_multi_ptr().get();
 
-  //typedef cub::BlockRadixSort BlockRadixSort;
-  //typedef cub::BlockLoad LoadFloat;
-  /*
-  union type_ct1{
-    typename LoadFloat::TempStorage loadf;
-    typename BlockRadixSort::TempStorage sort;
-    int smem_qidx[BLOCK_ESTIMATE];
-  };
-  type_ct1 &temp_storage = *(type_ct1 *)temp_storage_ct1;
-  */
   int smem_qidx[BLOCK_ESTIMATE];
 
   for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_ESTIMATE)
@@ -674,10 +673,8 @@ void kEstimateQuantiles(const T *buff_A, float *code, const float offset, const
     for(int j = 0; j < NUM_ESTIMATE; j++)
       vals[j] = max_val;
 
-    /*
-    DPCT1065:76: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
     // 1. load 8 values per thread
     // 2. compute 2-max in registers (64 max per warp)
     // 3. do warp reduction + broadcast back
     // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
     // 5. Repeat (3) 8 times for top 8 values in 256
     // 6. store with byte index
     auto *tmp = tacc.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_A[0], vals);
+    group_load(tmp).load(item_ct1, d, vals);
 
 #pragma unroll 4
     for(int j = 0; j < NUM_ESTIMATE; j++)
       vals[j] = ((float)vals[j]) * reciprocal_num_blocks;
 
-    /*
-    DPCT1065:77: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
     item_ct1.barrier();
     // sort into striped pattern to mitigate bank conflicts
     // striped pattern index for thread 0 [0, 1024, 2048, 3096]
     // striped pattern index for thread 1 [1, 1025, 2049, 3097]
@@ -709,41 +704,35 @@ void kEstimateQuantiles(const T *buff_A, float *code, const float offset, const
     // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
     // 5. Repeat (3) 8 times for top 8 values in 256
     // 6. store with byte index
-    group_radix_sort(tmp).sort_blocked_to_striped(item, vals);
+    group_radix_sort(tmp).sort_blocked_to_striped(item_ct1, vals);
 
-    /*
-    DPCT1065:78: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
     for(int j = item_ct1.get_local_id(2); j < BLOCK_ESTIMATE; j+=item_ct1.get_local_range(2))
-      temp_storage.smem_qidx[j] = -1;
+      smem_qidx[j] = -1;
 
-    /*
-    DPCT1065:79: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
     if(item_ct1.get_local_id(2) < 256)
     {
       float q_interval = (1.0f-(2.0f*offset))/255.0f;
-      /*
-      DPCT1064:83: Migrated round call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
-      */
+
      int local_idx = sycl::round(((offset+(item_ct1.get_local_id(2)*q_interval))*(valid_items-1)));
      smem_qidx[local_idx] = item_ct1.get_local_id(2);
    }
 
-    /*
-    DPCT1065:80: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
    for(int i = item_ct1.get_local_id(2); i < BLOCK_ESTIMATE; i+=item_ct1.get_local_range(2))
    {
      if(smem_qidx[i] != -1)
        dpct::atomic_fetch_add(&code[smem_qidx[i]], vals[i/THREADS_ESTIMATE]);
    }
+  }
+
 }
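kEstimateQuantiles above accumulates each work-item's candidate quantile into the shared code array with dpct::atomic_fetch_add. That dpct helper wraps the standard SYCL 2020 atomic_ref pattern; a minimal plain-SYCL equivalent, for reference (illustrative, not the repository's code):

    #include <sycl/sycl.hpp>

    // Atomically add v to *addr in global memory; this is what
    // dpct::atomic_fetch_add boils down to for a float operand.
    inline void atomic_add_f32(float *addr, float v) {
        sycl::atomic_ref<float, sycl::memory_order::relaxed,
                         sycl::memory_scope::device,
                         sycl::access::address_space::global_space> ref(*addr);
        ref.fetch_add(v);
    }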
@@ -841,15 +830,6 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_
   //float local_abs_max = -FLT_MAX;
   float local_abs_max = 0.0f;
   int local_rand_idx = 0;
-
-
-
-
-  //typedef cub::BlockLoad LoadT;
-  //typedef cub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
-
-  //typedef cub::BlockLoad LoadFloat;
-
   if(DATA_TYPE == General8bit)
     for(int i = item_ct1.get_local_id(2); i < 256; i+=item_ct1.get_local_range(2))
@@ -860,15 +840,9 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_
     valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
     local_abs_max = -FLT_MAX;
 
-    /*
-    DPCT1065:84: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
     item_ct1.barrier(sycl::access::fence_space::local_space);
-    /*
-    DPCT1007:87: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);
-
+
     // 1. load 8 values per thread
     // 2. compute 2-max in registers (64 max per warp)
     // 3. do warp reduction + broadcast back
     // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
     // 5. Repeat (3) 8 times for top 8 values in 256
     // 6. store with byte index
@@ -962,9 +936,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_
       break;
   }
 
-    /*
-    DPCT1065:86: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
   item_ct1.barrier(sycl::access::fence_space::local_space);
   /*
   DPCT1007:89: Migration of cub::BlockStore::Store is not supported.
   */
@@ -985,8 +957,8 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_
 }
 
 template
-SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, float * absmax, T *buff_out, const int blocksize, const int n,
-                const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc)
+SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n,
+                const sycl::nd_item<3> &item_ct1, sycl_la tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out )
 {
 
   const int n_load = (item_ct1.get_group_range(2) * TILE_SIZE);
@@ -998,10 +970,14 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, flo
   unsigned char qvals[NUM_PER_TH];
   float local_abs_max = -FLT_MAX;
 
-  //typedef cub::BlockLoad LoadChar;
-  //typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
-
-
+  using group_load_uc = dpct::group::workgroup_load>;
+
+  using group_store = dpct::group::workgroup_store>;
+
+
+  auto *d_A = dacc_A.template get_multi_ptr().get();
+  auto *d_out = dacc_out.template get_multi_ptr().get();
+
   for (unsigned int i = base_idx; i < n_load; i += item_ct1.get_group_range(2)*TILE_SIZE)
   {
     if(DATA_TYPE > 0)
@@ -1014,30 +990,20 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, flo
       valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i;
       valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i;
     }
-    /*
-    DPCT1098:92: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues.
-    */
-    /*
-    DPCT1064:96: Migrated __ldg call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
-    */
+
     local_abs_max = sycl::ext::oneapi::experimental::cuda::ldg(&absmax[(i+item_ct1.get_local_id(2)*NUM_PER_TH)/(blocksize)]);
 
-    /*
-    DPCT1065:90: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
    item_ct1.barrier(sycl::access::fence_space::local_space);
-    /*
-    DPCT1007:93: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);
+
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_A[0], qvals);
+    auto *tmp = tacc.get_multi_ptr().get();
+    group_load_uc(tmp).load(item_ct1, d_A, qvals);
 
@@ -1048,12 +1014,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, flo
        // load code through read-only cache via __ldg
 #pragma unroll NUM_PER_TH
        for(int j = 0; j < NUM_PER_TH; j++)
-          /*
-          DPCT1098:94: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues.
-          */
-          /*
-          DPCT1064:228: Migrated __ldg call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
-          */
+
          vals[j] = sycl::ext::oneapi::experimental::cuda::ldg(&code[qvals[j]]*local_abs_max);
        break;
      case FP4:
@@ -1074,16 +1035,9 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * buff_A, flo
        break;
    }
 
-    /*
-    DPCT1065:91: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
    item_ct1.barrier();
-    /*
-    DPCT1007:95: Migration of cub::BlockStore::Store is not supported.
-    */
-    //StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store);
-
-
+
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = stacc.get_multi_ptr().get();
-    group_store(tmp).store(item, &buff_out[0], vals);
+    group_store(tmp).store(item_ct1, d_out, vals);
  }
 }
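The dpct::group::workgroup_load / workgroup_store helpers above replace cub::BlockLoad / cub::BlockStore: each work-item cooperatively moves NUM_PER_TH elements between global memory and registers. For the direct (blocked) arrangement this boils down to the following hand-rolled sketch (illustration only; the dpct helpers also handle the warp-transpose and partial-tile cases):

    #include <sycl/sycl.hpp>

    // Each work-item loads NUM_PER_TH consecutive elements starting at
    // local_id * NUM_PER_TH within the tile pointed to by src.
    template <int NUM_PER_TH, typename T>
    void blocked_load(const sycl::nd_item<3> &item, const T *src, T (&vals)[NUM_PER_TH]) {
        const int base = item.get_local_id(2) * NUM_PER_TH;
    #pragma unroll
        for (int j = 0; j < NUM_PER_TH; j++)
            vals[j] = src[base + j];
    }

    // Symmetric store of the per-work-item register fragment.
    template <int NUM_PER_TH, typename T>
    void blocked_store(const sycl::nd_item<3> &item, T *dst, const T (&vals)[NUM_PER_TH]) {
        const int base = item.get_local_id(2) * NUM_PER_TH;
    #pragma unroll
        for (int j = 0; j < NUM_PER_TH; j++)
            dst[base + j] = vals[j];
    }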
@@ -1124,11 +1077,11 @@ template
 /*
 DPCT1110:1: The total declared local variable size in device function kPreconditionOptimizer32bit2State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure.
 */
 SYCL_EXTERNAL
-void kPreconditionOptimizer32bit2State(T* buff_g, T* p,
-                float* buff_state1, float* buff_state2, float *unorm,
+void kPreconditionOptimizer32bit2State(T* g, T* p,
+                float* state1, float* state2, float *unorm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const int n,
-                const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2)
+                const sycl::nd_item<3> &item_ct1, sycl_la &tacc, sycl_dacc_float &dacc_state1, sycl_dacc_float &dacc_state2, sycl::accessor &dacc_g)
 {
   const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
@@ -1143,7 +1096,13 @@ void kPreconditionOptimizer32bit2State(T* buff_g, T* p,
   const float correction1 = 1.0f/(1.0f - dpct::pow(beta1, step));
   const float correction2 = 1.0f/(1.0f - dpct::pow(beta2, step));
 
+  using group_load = dpct::group::workgroup_load>;
+  using group_load_float = dpct::group::workgroup_load>;
+
+  auto *d_g = dacc_g.template get_multi_ptr().get();
+  auto *d_state1 = dacc_state1.get_multi_ptr().get();
+  auto *d_state2 = dacc_state2.get_multi_ptr().get();
 
   //typedef cub::BlockLoad Load;
   //typedef cub::BlockLoad LoadFloat;
@@ -1161,13 +1120,9 @@ void kPreconditionOptimizer32bit2State(T* buff_g, T* p,
   {
     valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
 
-    /*
-    DPCT1065:97: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
    item_ct1.barrier(sycl::access::fence_space::local_space);
-    /*
-    DPCT1007:101: Migration of cub::BlockLoad::Load is not supported.
-    */
+
    //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f);
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_A[0], g_vals);
+    auto *tmp = tacc.get_multi_ptr().get();
+    group_load(tmp).load(item_ct1, d_g, g_vals);
+
    item_ct1.barrier(sycl::access::fence_space::local_space);
-    /*
-    DPCT1007:102: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_float1.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_state[0], s1_vals);
+    group_load_float(tmp).load(item_ct1, d_state1, s1_vals);
+
    item_ct1.barrier(sycl::access::fence_space::local_space);
-    /*
-    DPCT1007:103: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f);
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_float2.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_state2[0], s2_vals);
+    group_load_float(tmp).load(item_ct1, d_state2, s2_vals);
 
 # pragma unroll NUM_VALS
    for(unsigned int j = 0; j < NUM_VALS; j++)
@@ -1226,7 +1167,7 @@ void kPreconditionOptimizer32bit2State(T* buff_g, T* p,
    {
      switch(OPTIMIZER)
      {
-        case ADAM:
+        case 1:
          s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
          s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
          s1_vals[j] *= correction1;
@@ -1241,9 +1182,7 @@ void kPreconditionOptimizer32bit2State(T* buff_g, T* p,
    for(unsigned int j = 1; j < NUM_VALS; j++)
      s1_vals[0] += s1_vals[j];
 
-    /*
-    DPCT1065:100: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
    item_ct1.barrier(sycl::access::fence_space::local_space);
    s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>());
@@ -1254,18 +1193,15 @@ void kPreconditionOptimizer32bit2State(T* buff_g, T* p,
  }
 }
 
-
-
 #define NUM_PER_THREAD 4
 
 template
 SYCL_EXTERNAL
-void kOptimizer32bit2State(T* buff_g, T* buff_p,
-                float* buff_state1, float* buff_state2, float *unorm, const float max_unorm, const float param_norm,
+void kOptimizer32bit2State(T* g, T* p,
+                float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n,
-                const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2
-                sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2)
+                const sycl::nd_item<3> &item_ct1, sycl_la tacc, sycl::accessor &dacc_g, sycl::accessor &dacc_p, sycl_dacc_float &dacc_state1, sycl_dacc_float &dacc_state2)
 {
   const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
@@ -1289,98 +1225,72 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p,
      else{ update_scale = 1.0f; }
    }
    else{ update_scale = 1.0f; }
 
-
-  //typedef cub::BlockLoad Load;
-  //typedef cub::BlockStore Store;
-
-  //typedef cub::BlockLoad LoadFloat;
-  //typedef cub::BlockStore StoreFloat;
-  /*
-  union type_ct3{
-    typename Load::TempStorage load;
-    typename Store::TempStorage store;
-    typename LoadFloat::TempStorage loadf;
-    typename StoreFloat::TempStorage storef;
-  };
-  */
-  //type_ct3 &temp_storage = *(type_ct3 *)temp_storage_ct1;
-
+
+  using group_load = dpct::group::workgroup_load>;
+  using group_load_float = dpct::group::workgroup_load>;
+
+  using group_store = dpct::group::workgroup_store>;
+  using group_store_float = dpct::group::workgroup_store>;
+
+
+  auto *d_g = dacc_g.template get_multi_ptr().get();
+  auto *d_p = dacc_p.template get_multi_ptr().get();
+  auto *d_state1 = dacc_state1.get_multi_ptr().get();
+  auto *d_state2 = dacc_state2.get_multi_ptr().get();
+
+
   for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*TH*NUM_PER_THREAD)
   {
     valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
 
-    /*
-    DPCT1065:104: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
    item_ct1.barrier(sycl::access::fence_space::local_space);
-    /*
-    DPCT1007:111: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
-
+
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_T.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_g[0], g_vals);
+    auto *tmp = tacc.get_multi_ptr().get();
+    group_load(tmp).load(item_ct1, d_g , g_vals);
 
-    /*
-    DPCT1065:105: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
-    /*
-    DPCT1007:112: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_float1.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_state1[0], s1_vals);
+    group_load_float(tmp).load(item_ct1, d_state1, s1_vals);
 
-    /*
-    DPCT1065:106: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
-    /*
-    DPCT1007:113: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items);
-
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_float2.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_state2[0], s2_vals);
-    /*
-    DPCT1065:107: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
-    /*
-    DPCT1007:114: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
-
+    group_load_float(tmp).load(item_ct1, d_state2, s2_vals);
+
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_T1.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_p[0], p_vals);
+
+    group_load(tmp).load(item_ct1, d_p, p_vals);
 
@@ -1393,7 +1303,7 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p,
    {
      switch(OPTIMIZER)
      {
-        case ADAM:
+        case 1:
          if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
          {
            s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
@@ -1407,58 +1317,40 @@ void kOptimizer32bit2State(T* buff_g, T* buff_p,
      }
    }
 
-    /*
-    DPCT1065:108: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
-    /*
-    DPCT1007:115: Migration of cub::BlockStore::Store is not supported.
-    */
-    //Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = stacc_T.get_multi_ptr().get();
-    group_store(tmp).store(item, &buff_p[0], p_vals);
+    group_store(tmp).store(item_ct1, d_p , p_vals);
 
-    /*
-    DPCT1065:109: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
-    /*
-    DPCT1007:116: Migration of cub::BlockStore::Store is not supported.
-    */
-    //StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
-
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = stacc_float1.get_multi_ptr().get();
-    group_store(tmp).store(item, &buff_state1[0], s1_vals);
+    group_store_float(tmp).store(item_ct1, d_state1, s1_vals);
 
-    /*
-    DPCT1065:110: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
-    /*
-    DPCT1007:117: Migration of cub::BlockStore::Store is not supported.
-    */
-    //StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items);
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = stacc_float2.get_multi_ptr().get();
-    group_store(tmp).store(item, &buff_state2[0], s2_vals);
+    group_store_float(tmp).store(item_ct1, d_state2, s2_vals);
  }
 }
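The ADAM branch above updates the two moment buffers in registers: s1 = s1*beta1 + (1-beta1)*g and s2 = s2*beta2 + (1-beta2)*g*g, with bias corrections correction1 = 1/(1-beta1^step) and correction2 = 1/(1-beta2^step) computed once per call. A scalar reference of the textbook update, for orientation (the kernel folds the corrections into its arithmetic slightly differently, so this is reference math, not the kernel's exact code):

    #include <cmath>

    // One Adam step for a single parameter element. s1/s2 are the running
    // first/second moments; returns the updated parameter value.
    float adam_step(float &s1, float &s2, float g, float p,
                    float beta1, float beta2, float eps, float lr,
                    float correction1, float correction2) {
        s1 = s1 * beta1 + (1.0f - beta1) * g;       // first moment
        s2 = s2 * beta2 + (1.0f - beta2) * g * g;   // second moment
        const float m_hat = s1 * correction1;       // bias-corrected moments
        const float v_hat = s2 * correction2;
        return p - lr * m_hat / (sqrtf(v_hat) + eps);
    }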
@@ -1466,11 +1358,11 @@
 template
 SYCL_EXTERNAL
-void kPreconditionOptimizer32bit1State(T* buff_g, T* buff_p,
+void kPreconditionOptimizer32bit1State(T* g, T* p,
                 float* buff_state1, float *unorm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const int n,
-                const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1)
+                const sycl::nd_item<3> &item_ct1, sycl_la &tacc, sycl::accessor &dacc_g, sycl_dacc_float &dacc_state1)
 {
   const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
@@ -1480,31 +1372,23 @@ void kPreconditionOptimizer32bit1State(T* buff_g, T* buff_p,
   T g_vals[NUM_VALS];
   float s1_vals[NUM_VALS];
 
-
-  //typedef cub::BlockLoad Load;
-  //typedef cub::BlockLoad LoadFloat;
-  //typedef sycl::group<3> BlockReduce;
-  /*
-  union type_ct4{
-    typename Load::TempStorage load;
-    typename LoadFloat::TempStorage loadf;
-    typename BlockReduce::TempStorage reduce;
-  };
-  type_ct4 &temp_storage = *(type_ct4 *)temp_storage_ct1;
-  */
+
+
+  using group_load = dpct::group::workgroup_load>;
+  using group_load_float = dpct::group::workgroup_load>;
+
+
+  auto *d_g = dacc_g.template get_multi_ptr().get();
+  auto *d_state1 = dacc_state1.get_multi_ptr().get();
+
+
   for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE)
   {
     valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
 
-    /*
-    DPCT1065:118: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
    item_ct1.barrier();
-    /*
-    DPCT1007:121: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f);
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_T.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_g[0], g_vals);
+    auto *tmp = tacc.get_multi_ptr().get();
+    group_load(tmp).load(item_ct1, d_g, g_vals);
 
-    /*
-    DPCT1065:119: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
-    /*
-    DPCT1007:122: Migration of cub::BlockLoad::Load is not supported.
-    */
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
    //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
 
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_float1.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_stat1[0], s1_vals);
+    group_load_float(tmp).load(item_ct1, d_state1, s1_vals);
 
 # pragma unroll NUM_VALS
@@ -1571,13 +1450,9 @@ void kPreconditionOptimizer32bit1State(T* buff_g, T* buff_p,
    for(unsigned int j = 1; j < NUM_VALS; j++)
      s1_vals[0] += s1_vals[j];
 
-    /*
-    DPCT1065:120: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
-    /*
-    DPCT1007:2: Migration of cub::Sum is not supported.
-    */
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
    //s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items);
    s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>());
@@ -1588,14 +1463,15 @@ void kPreconditionOptimizer32bit1State(T* buff_g, T* buff_p,
  }
 }
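kPreconditionOptimizer32bit1State ends each tile with sycl::reduce_over_group, the SYCL 2020 group algorithm that takes over from the commented-out cub::BlockReduce::Sum. A minimal illustration of the call in a self-contained kernel (illustrative names; data and result are assumed to be USM device pointers):

    #include <sycl/sycl.hpp>

    // Each work-item contributes one float; every work-item receives the
    // work-group-wide sum. Work-group size 256 here is arbitrary.
    void group_sum_demo(sycl::queue &q, const float *data, float *result) {
        q.parallel_for(sycl::nd_range<1>(256, 256), [=](sycl::nd_item<1> it) {
            float v = data[it.get_global_id(0)];
            float sum = sycl::reduce_over_group(it.get_group(), v, sycl::plus<>());
            if (it.get_local_id(0) == 0)
                result[0] = sum;  // one representative writes the total
        });
    }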
+
 template
 SYCL_EXTERNAL
 void kOptimizer32bit1State(T *buff_g, T *buff_p,
                float *buff_state1, float *unorm, const float max_unorm, const float param_norm,
                const float beta1, const float beta2, const float eps, const float weight_decay,
                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n,
-                const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1,
-                sycl_la_T stacc_T, sycl_la_float stacc_float1)
+                const sycl::nd_item<3> &item_ct1, sycl_la &tacc, sycl::accessor &dacc_g, sycl::accessor &dacc_p,
+                sycl_dacc_float &dacc_state1)
 {
   const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
@@ -1615,82 +1491,57 @@ void kOptimizer32bit1State(T *buff_g, T *buff_p,
   T p_vals[NUM_PER_THREAD];
   float s1_vals[NUM_PER_THREAD];
 
-
-  //typedef cub::BlockLoad Load;
-  //typedef cub::BlockStore Store;
-
-  //typedef cub::BlockLoad LoadFloat;
-  //typedef cub::BlockStore StoreFloat;
-
-  /*
-  union type_ct5{
-    typename Load::TempStorage load;
-    typename Store::TempStorage store;
-    typename LoadFloat::TempStorage loadf;
-    typename StoreFloat::TempStorage storef;
-  };
-  type_ct5 &temp_storage = *(type_ct5 *)temp_storage_ct1;
-  */
+
+  using group_load = dpct::group::workgroup_load>;
+  using group_load_float = dpct::group::workgroup_load>;
+
+  using group_store = dpct::group::workgroup_store>;
+  using group_store_float = dpct::group::workgroup_store>;
+
+
+  auto *d_g = dacc_g.template get_multi_ptr().get();
+  auto *d_p = dacc_p.template get_multi_ptr().get();
+  auto *d_state1 = dacc_state1.get_multi_ptr().get();
+
   for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*TH*NUM_PER_THREAD)
   {
     valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
 
-    /*
-    DPCT1065:123: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
    item_ct1.barrier();
-    /*
-    DPCT1007:128: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
-
+
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_T.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_g[0], g_vals);
+    auto *tmp = tacc.get_multi_ptr().get();
+    group_load(tmp).load(item_ct1, d_g, g_vals);
 
-    /*
-    DPCT1065:124: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
    item_ct1.barrier();
-    /*
-    DPCT1007:129: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
-
+
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_float1.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_state1[0], s1_vals);
+    group_load_float(tmp).load(item_ct1, d_state1, s1_vals);
 
-    /*
-    DPCT1065:125: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
+
    item_ct1.barrier();
-    /*
-    DPCT1007:130: Migration of cub::BlockLoad::Load is not supported.
-    */
-    //Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
-
+
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = ltacc_T1.get_multi_ptr().get();
-    group_load(tmp).load(item, &buff_p[0], p_vals);
+    group_load(tmp).load(item_ct1, d_p, p_vals);
 
 # pragma unroll 4
    for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
@@ -1735,37 +1586,25 @@ void kOptimizer32bit1State(T *buff_g, T *buff_p,
    DPCT1065:126: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
    */
    item_ct1.barrier();
-    /*
-    DPCT1007:131: Migration of cub::BlockStore::Store is not supported.
-    */
-    //Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
-
+
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = stacc_T.get_multi_ptr().get();
-    group_store(tmp).store(item, &buff_p[0], p_vals);
+    group_store(tmp).store(item_ct1, d_p, p_vals);
+
 
-    /*
-    DPCT1065:127: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
-    */
    item_ct1.barrier();
-    /*
-    DPCT1007:132: Migration of cub::BlockStore::Store is not supported.
-    */
-    //StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
-
+
    // 1. load 8 values per thread
    // 2. compute 2-max in registers (64 max per warp)
    // 3. do warp reduction + broadcast back
    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
    // 5. Repeat (3) 8 times for top 8 values in 256
    // 6. store with byte index
-    auto *tmp = stacc_float1.get_multi_ptr().get();
-    group_store(tmp).store(item, &buff_state1[0], s1_vals);
+    group_store_float(tmp).store(item_ct1, d_state1, s1_vals);
  }
 }
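The ops.cpp changes below carry the same refactor on the host side: the malloc_device + memcpy staging buffers are replaced by sycl::buffer views over the caller's pointers, and accessors declared inside the command group hand the kernels a device pointer via get_multi_ptr(). A minimal sketch of that pattern (illustrative names, not the repository's code):

    #include <sycl/sycl.hpp>

    void buffer_accessor_sketch(sycl::queue &q, float *host_ptr, int n) {
        // Buffer aliases the caller's memory; no explicit memcpy needed.
        sycl::buffer<float, 1> buf(host_ptr, sycl::range<1>(n));
        q.submit([&](sycl::handler &cgh) {
            sycl::accessor acc(buf, cgh, sycl::read_write);
            cgh.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) {
                acc[i] *= 2.0f;  // kernel works through the accessor
            });
        });
        // The buffer's destructor (or a host accessor) writes results back,
        // which is why the old "back memcpy" calls disappear in the hunks below.
    }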
- */ + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( @@ -186,16 +174,13 @@ void dequantize(float *code, unsigned char *A, float *out, int n) kDequantize(code, buff_A, buff_out, n, item_ct1, smem_code_acc_ct1.get_pointer()); }); }); - //q_ct1.wait(); + } //back memcpy q_ct1.memcpy((void *)(out), (void*)(buff_out), NUM_BLOCK); q_ct1.memcpy((void*)(A), (void*)(buff_A), NUM_BLOCK); - /* - DPCT1010:233: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + } template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) @@ -425,14 +410,12 @@ template void quantizeBlockwise(floa } - /* - DPCT1010:234: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ + //back memcpy q_ct1.memcpy((void*)(A), (void*)(buff_A), NUM_BLOCK); q_ct1.memcpy((void*)(out), (void*)(buff_out), NUM_BLOCK); q_ct1.memcpy((void*)(rand), (void*)(buff_rand), NUM_BLOCK); - //CUDA_CHECK_RETURN(0); + } template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) @@ -443,13 +426,10 @@ template void dequantizeBlockwise(float *code, unsign num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; int tile_size = (DATA_TYPE > 0) ? 1024 : 512; sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; - unsigned char *buff_A; - T *buff_out; - *((void **)&buff_A) = sycl::malloc_device(tile_size, dev_ct1, ctx); - *((void **)&buff_out) = sycl::malloc_device(tile_size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), tile_size); - q_ct1.memcpy((T*)(buff_out), (T*)(out), tile_size); + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); if(DATA_TYPE > 0) { @@ -457,19 +437,18 @@ template void dequantizeBlockwise(float *code, unsign q_ct1.submit( [&](sycl::handler &cgh){ - using group_load = dpct::group::workgroup_load; - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(tile_size); - size_t load_temp_storage_size = group_load::get_local_memory_size(tile_size); - sycl::local_accessor ltacc(load_temp_storage_size, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); - - + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); q_ct1.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { - kDequantizeBlockwise(code, buff_A, absmax, buff_out, blocksize/2, n, item_ct1, ltacc, stacc); + kDequantizeBlockwise(code, A, absmax, out, blocksize/2, n, item_ct1, tacc, dacc_A, dacc_out); }); }); } @@ -479,29 +458,22 @@ template void dequantizeBlockwise(float *code, unsign q_ct1.submit( [&](sycl::handler &cgh){ - using group_load = dpct::group::workgroup_load; - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(tile_size); - size_t load_temp_storage_size = group_load::get_local_memory_size(tile_size); - 
sycl::local_accessor ltacc(load_temp_storage_size, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { - kDequantizeBlockwise(code, buff_A, absmax, buff_out, blocksize, n, item_ct1, ltacc, stacc); + kDequantizeBlockwise(code, buff_A, absmax, buff_out, blocksize, n, item_ct1, tacc, dacc_A, dacc_out); }); }); } - /* - DPCT1010:235: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //back memcpy - q_ct1.memcpy((void*)(A), (void*)(buff_A), tile_size); - q_ct1.memcpy((T*)(out), (T*)(buff_out), tile_size); - - //CUDA_CHECK_RETURN(0); } @@ -525,16 +497,10 @@ template void optimizer32bit(T* g, T* p, num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; int size= NUM_BLOCK; - T *buff_g,*buff_p; - float *buff_state1,*buff_state2; - *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_p) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_state1) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_state2) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_g), (void*)(g), size); - q_ct1.memcpy((void*)(buff_p), (void*)(p), size); - q_ct1.memcpy((void*)(buff_state1), (void*)(state1), size); - q_ct1.memcpy((void*)(buff_state2), (void*)(state2), size); + sycl::buffer buff_g(g,sycl::range<1>(size)); + sycl::buffer buff_p(p,sycl::range<1>(size)); + sycl::buffer buff_state1(state1,sycl::range<1>(size)); + sycl::buffer buff_state2(state2,sycl::range<1>(size)); switch(OPTIMIZER) @@ -542,271 +508,170 @@ template void optimizer32bit(T* g, T* p, case ADAM: if(max_unorm > 0.0f) { - //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); - DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); - /* - DPCT1049:61: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
- */ + + //DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_T = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - using group_load_float2 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); - /* - DPCT1054:294: The type of variable temp_storage is declared in device function with the name type_ct2. Adjust the code to make the type_ct2 declaration visible at the accessor declaration point. - */ cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizer32bit2State(g, p, buff_state1, buff_state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1, ltacc_float2); + kPreconditionOptimizer32bit2State(g, p, buff_state1, buff_state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_state1, dacc_state2, dacc_g); }); }); } - /* - DPCT1010:236: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + } - /* - DPCT1049:59: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1054:295: The type of variable temp_storage is declared in device function with the name type_ct3. Adjust the code to make the type_ct3 declaration visible at the accessor declaration point. 
- */ - - using group_load_T = dpct::group::workgroup_load; - using group_load_T1 = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - using group_load_float2 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); - + - using group_store_T = dpct::group::workgroup_store; - using group_store_float1 = dpct::group::workgroup_store; - using group_store_float2 = dpct::group::workgroup_store; - size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_float2 = group_store_float2::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); - sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); - sycl::local_accessor stacc_float2(store_temp_storage_size_float2, cgh); - - - - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizer32bit2State(buff_g, buff_p, buff_state1, buff_state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2,stacc_T, stacc_float1, stacc_float2); + kOptimizer32bit2State(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1, dacc_state2); }); }); } - /* - DPCT1010:237: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + break; case MOMENTUM: case RMSPROP: case ADAGRAD: if(max_unorm > 0.0f) { - //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); - DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); - /* - DPCT1049:62: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
- */ + //DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1054:296: The type of variable temp_storage is declared in device function with the name type_ct4. Adjust the code to make the type_ct4 declaration visible at the accessor declaration point. - */ - using group_load_T = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - - cgh.parallel_for( + + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float1); + kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_g, dacc_state1); }); }); } - /* - DPCT1010:238: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); - } + } - /* - DPCT1049:60: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1054:297: The type of variable temp_storage is declared in device function with the name type_ct5. Adjust the code to make the type_ct5 declaration visible at the accessor declaration point. 
- */ - using group_load_T = dpct::group::workgroup_load; - using group_load_T1 = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - - - using group_store_T = dpct::group::workgroup_store; - using group_store_float1 = dpct::group::workgroup_store; - size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); - sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - - - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - - cgh.parallel_for( + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, stacc_T,stacc_float1); + kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1); }); }); } - /* - DPCT1010:239: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + break; case LION: // in lion, the momentum update after the parameter update - /* - DPCT1049:63: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1054:298: The type of variable temp_storage is declared in device function with the name type_ct5. Adjust the code to make the type_ct5 declaration visible at the accessor declaration point. 
- */ - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - using group_load_T = dpct::group::workgroup_load; - using group_load_T1 = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - - - using group_store_T = dpct::group::workgroup_store; - using group_store_float1 = dpct::group::workgroup_store; - size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); - sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, ltacc_T, ltacc_T1, ltacc_float1, stacc_T,stacc_float1); + kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1); }); }); } - /* - DPCT1010:240: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + if(max_unorm > 0.0f) { - //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); - DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); - /* - DPCT1049:64: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + q_ct1.memset(unorm, 0, 1*sizeof(float)).wait(); + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1054:299: The type of variable temp_storage is declared in device function with the name type_ct4. Adjust the code to make the type_ct4 declaration visible at the accessor declaration point.
- */ - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - using group_load_T = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizer32bit1State(buff_g, buff_p, buff_state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, ltacc_T, ltacc_float); + kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_g, dacc_state1); }); }); } - /* - DPCT1010:241: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + } break; } - - //back memcpy - q_ct1.memcpy((void*)(g), (void*)(buff_g), size); - q_ct1.memcpy((void*)(p), (void*)(buff_p), size); - q_ct1.memcpy((void*)(state1), (void*)(buff_state1), size); - q_ct1.memcpy((void*)(state2), (void*)(buff_state2), size); } catch (sycl::exception const &exc) { From f90c06dfcc0ee6d217518721cf3d4ddc8b76df72 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 13 May 2024 00:49:55 -0700 Subject: [PATCH 33/66] full 32 optimizer fixed update --- csrc/sycl/kernels.cpp | 29 ++++++++++++----------------- csrc/sycl/ops.cpp | 10 +++++----- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 4819558e8..7599bbeea 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -972,7 +972,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * using group_load_uc = dpct::group::workgroup_load>; - using group_store = dpct::group::workgroup_store>; + using group_store = dpct::group::workgroup_store>; auto *d_A = dacc_A.template get_multi_ptr().get(); @@ -1081,7 +1081,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la &tacc, sycl_dacc_float &dacc_state1, sycl_dacc_float &dacc_state2,sycl::accessor &dacc_g) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); @@ -1122,9 +1122,6 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, item_ct1.barrier(sycl::access::fence_space::local_space); - - //Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); - // 1. load 8 values per thread // 2. 
compute 2-max in registers (64 max per warp) @@ -1201,7 +1198,7 @@ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la tacc, sycl::accessor &dacc_g, sycl::accessor &dacc_p, sycl_dacc_float &dacc_state1, sycl_dacc_float &dacc_state2) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -1362,7 +1359,7 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, float* buff_state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la &tacc, sycl::accessor &dacc_g, sycl_dacc_float &dacc_state1) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); @@ -1376,8 +1373,8 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, - using group_load = dpct::group::workgroup_load>; - using group_load_float = dpct::group::workgroup_load>; + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; auto *d_g = dacc_g.template get_multi_ptr().get(); @@ -1466,11 +1463,11 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, template SYCL_EXTERNAL -void kOptimizer32bit1State(T *buff_g, T *buff_p, - float *buff_state1, float *unorm, const float max_unorm, const float param_norm, +void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la &tacc, sycl::accessor &dacc_g, sycl::accessor &dacc_p, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1) { @@ -1582,10 +1579,8 @@ void kOptimizer32bit1State(T *buff_g, T *buff_p, } } - /* - DPCT1065:126: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); + + item_ct1.barrier(sycl::access::fence_space::local_space); // 1. load 8 values per thread // 2. 
compute 2-max in registers (64 max per warp) @@ -1621,7 +1616,7 @@ DPCT1110:6: The total declared local variable size in device function kPrecondit */ SYCL_EXTERNAL void -kPreconditionOptimizerStatic8bit2State(T* buff_p, T* __restrict__ const buff_g, unsigned char*__restrict__ const buff_state1, unsigned char* __restrict__ const buff_state2, +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const buff_state2, float *unorm, const float beta1, const float beta2, const float eps, const int step, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 65d4ca93b..0fb0665a8 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -530,7 +530,7 @@ template void optimizer32bit(T* g, T* p, cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizer32bit2State(g, p, buff_state1, buff_state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_state1, dacc_state2, dacc_g); + kPreconditionOptimizer32bit2State(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_state1, dacc_state2, dacc_g); }); }); } @@ -575,8 +575,8 @@ template void optimizer32bit(T* g, T* p, q_ct1.submit( [&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load>; - using group_load_float = dpct::group::workgroup_load>; + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); @@ -653,8 +653,8 @@ template void optimizer32bit(T* g, T* p, dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load>; - using group_load_float = dpct::group::workgroup_load>; + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); From 873eadb71761b1fe42ecfd5aa8711ad6b1faca22 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 13 May 2024 05:07:45 -0700 Subject: [PATCH 34/66] refine 8 bit optimizers --- csrc/sycl/kernels.cpp | 506 ++++++++++++++---------------------------- csrc/sycl/ops.cpp | 258 ++++++--------------- 2 files changed, 241 insertions(+), 523 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 7599bbeea..8b30031a5 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -1072,6 +1072,9 @@ SYCL_EXTERNAL void kDequantize(float *code, unsigned char *buff_A, float *buff_o +//===================32 bit optimizer======================== + + template /* DPCT1110:1: The total declared local variable size in device function kPreconditionOptimizer32bit2State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
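The refactor that follows collapses the per-type ltacc_*/stacc_* local accessors into a single byte scratch accessor (tacc) sized once and reused by every load/store phase; the dpct::group::workgroup_load/workgroup_store helpers partition it internally (their template arguments are elided in this diff and are not reconstructed here). A standard-SYCL sketch of the underlying idea, with assumed phase sizes and alignment:

#include <sycl/sycl.hpp>
#include <algorithm>
#include <cstdint>

int main() {
  sycl::queue q;
  size_t need_a = 256 * sizeof(float);   // scratch needed by phase A (assumed)
  size_t need_b = 256 * sizeof(int);     // scratch needed by phase B (assumed)
  q.submit([&](sycl::handler &cgh) {
    // one allocation sized for the largest consumer, like tacc above
    sycl::local_accessor<std::uint8_t, 1> tacc(sycl::range<1>(std::max(need_a, need_b)), cgh);
    cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(256), sycl::range<1>(256)),
                     [=](sycl::nd_item<1> it) {
      auto *bytes = tacc.get_multi_ptr<sycl::access::decorated::no>().get();
      auto *as_f = reinterpret_cast<float *>(bytes);        // phase A view of the scratch
      as_f[it.get_local_id(0)] = 0.0f;
      it.barrier(sycl::access::fence_space::local_space);   // phases must not overlap
      auto *as_i = reinterpret_cast<int *>(bytes);          // phase B reuses the same bytes
      as_i[it.get_local_id(0)] = 0;
    });
  }).wait();
  return 0;
}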
@@ -1606,17 +1609,18 @@ void kOptimizer32bit1State(T *g, T *p, } +//===================8 bit optimizer======================== + + #define NUM8BIT 16 #define NUM_THREADS 256 #define NUM_PER_BLOCK 4096 template -/* -DPCT1110:6: The total declared local variable size in device function kPreconditionOptimizerStatic8bit2State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. -*/ -SYCL_EXTERNAL void -kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const buff_state2, +SYCL_EXTERNAL void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + unsigned char* __restrict__ const state2, float *unorm, const float beta1, const float beta2, const float eps, const int step, @@ -1625,8 +1629,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float* smem_quantiles1, float* smem_quantiles2, - sycl_la_T ltacc_T, - sycl_la_float ltacc_float1, sycl_la_float ltacc_float2) + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); @@ -1642,90 +1645,59 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c unsigned char m_c1[NUM8BIT]; unsigned char r_c2[NUM8BIT]; - //typedef cub::BlockLoad LoadT; - //typedef cub::BlockLoad LoadUInt8; - //typedef sycl::group<3> BlockReduce; - - /* - union type_ct6{ - typename LoadT::TempStorage loadh; - typename LoadUInt8::TempStorage loadc; - typename BlockReduce::TempStorage reduce; - }; - type_ct6 &temp_storage = *(type_ct6 *)temp_storage_ct1; - */ - - - + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + auto *d_state2 = dacc_state2.get_multi_ptr().get(); + if(item_ct1.get_local_id(2) < 256) { smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; smem_quantiles2[item_ct1.get_local_id(2)] = quantiles2[item_ct1.get_local_id(2)]; } - /* - DPCT1065:150: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); + + item_ct1.barrier(sycl::access::fence_space::local_space); for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*item_ct1.get_group_range(2)*NUM8BIT) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; - /* - DPCT1007:156: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index - auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_g[0], g_vals); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); - /* - DPCT1065:153: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:157: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); - + + item_ct1.barrier(sycl::access::fence_space::local_space); + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_state1[0], m_c1); + group_load_uc(tmp).load(item_ct1, d_state1, m_c1); + + + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1065:154: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:158: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); - // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ;ltacc_float2.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_state2[0], r_c2); + group_load_uc(tmp).load(item_ct1, d_state2, r_c2); - /* - DPCT1065:155: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); + + item_ct1.barrier(sycl::access::fence_space::local_space); #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) @@ -1762,34 +1734,19 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c } } - /* - DPCT1065:151: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:7: Migration of cub::Reduce is not supported. - */ - //local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, sycl::maximum<>(), valid_items); + + item_ct1.barrier(sycl::access::fence_space::local_space); + local_max_s1 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s1, sycl::maximum<>()); - /* - DPCT1065:152: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:8: Migration of cub::Reduce is not supported. 
- */ - //local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, sycl::maximum<>(), valid_items); + + item_ct1.barrier(sycl::access::fence_space::local_space); + local_max_s2 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s2, sycl::maximum<>()); if(unorm != NULL) { - /* - DPCT1065:159: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:9: Migration of cub::Reduce is not supported. - */ - //local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, sycl::plus<>(), valid_items); + + item_ct1.barrier(sycl::access::fence_space::local_space); + local_unorm = sycl::reduce_over_group(item_ct1.get_group(), local_unorm, sycl::plus<>()); } @@ -1808,7 +1765,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c template SYCL_EXTERNAL void -kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state1, unsigned char* buff_state2, +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, const float *unorm, const float max_unorm, const float param_norm, \ const float beta1, const float beta2, const float eps, const int step, const float lr, @@ -1817,9 +1774,8 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state float weight_decay, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float* smem_quantiles1, float* smem_quantiles2, - sycl_la_T ltacc_T, sycl_la_T ltacc_T1, - sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, - sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2 + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2 ) { @@ -1850,20 +1806,18 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state T p_vals[NUM_PER_THREAD2]; T g_vals[NUM_PER_THREAD2]; + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_uc = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + auto *d_state2 = dacc_state2.get_multi_ptr().get(); - //typedef cub::BlockLoad LoadT; - //typedef cub::BlockLoad LoadChar; - //typedef cub::BlockStore StoreChar; - //typedef cub::BlockStore StoreT; - /* - union type_ct7{ - typename LoadT::TempStorage loadh; - typename LoadChar::TempStorage loadc; - typename StoreChar::TempStorage storec; - typename StoreT::TempStorage storeh; - }; - type_ct7 &temp_storage = *(type_ct7 *)temp_storage_ct1; - */ if(item_ct1.get_local_id(2) < 512) { @@ -1873,75 +1827,48 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state smem_quantiles2[item_ct1.get_local_id(2)-256] = quantiles2[item_ct1.get_local_id(2)-256]; } - /* - DPCT1065:160: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
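Per element, kOptimizerStatic8bit2State dequantizes both moments through the 256-entry code books, applies the Adam step, and requantizes against the fresh absmax. A scalar host-side sketch of that update with textbook bias correction (the kernel folds the corrections into correction1/correction2 and rescales with max1/max2; all names here are illustrative, and the requantization is sketched separately further below):

#include <cmath>

void adam8bit_element(float g, float &p, unsigned char &c1, unsigned char &c2,
                      const float *quantiles1, const float *quantiles2,
                      float max1, float max2,
                      float beta1, float beta2, float eps, float lr, int step) {
  float s1 = quantiles1[c1] * max1;           // dequantize first moment
  float s2 = quantiles2[c2] * max2;           // dequantize second moment
  s1 = s1 * beta1 + (1.0f - beta1) * g;       // moment updates
  s2 = s2 * beta2 + (1.0f - beta2) * g * g;
  float m_hat = s1 / (1.0f - std::pow(beta1, (float)step));  // bias corrections
  float v_hat = s2 / (1.0f - std::pow(beta2, (float)step));
  p -= lr * m_hat / (std::sqrt(v_hat) + eps);
  // c1/c2 would now be refreshed by a nearest-quantile search against the new absmax
}

int main() {
  float q1[256], q2[256], p = 1.0f;
  for (int i = 0; i < 256; ++i) { q1[i] = -1.0f + 2.0f * i / 255.0f; q2[i] = i / 255.0f; }
  unsigned char c1 = 128, c2 = 0;
  adam8bit_element(0.5f, p, c1, c2, q1, q2, 1.0f, 1.0f, 0.9f, 0.999f, 1e-8f, 1e-3f, 1);
  return 0;
}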
- */ - item_ct1.barrier(); + + item_ct1.barrier(sycl::access::fence_space::local_space); for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_THREADS2*NUM_PER_THREAD2) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; - /* - DPCT1007:167: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_g[0], g_vals); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1065:161: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:168: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_state1[0], c1s); + group_load_uc(tmp).load(item_ct1, d_state1, c1s); - /* - DPCT1065:162: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:169: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); - + + item_ct1.barrier(sycl::access::fence_space::local_space); + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_float2.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_state2[0], c2s); + group_load_uc(tmp).load(item_ct1, d_state2, c2s); - /* - DPCT1065:163: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:170: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + item_ct1.barrier(sycl::access::fence_space::local_space); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -1949,8 +1876,7 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index - auto *tmp = ltacc_T1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_p[0], p_vals); + group_load(tmp).load(item_ct1, d_p, p_vals); if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } @@ -1990,48 +1916,28 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); } - /* - DPCT1007:171: Migration of cub::BlockStore::Store is not supported. - */ - //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc_T.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_p[0], p_vals); + group_store(tmp).store(item_ct1, d_p, p_vals); - /* - DPCT1065:164: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:172: Migration of cub::BlockStore::Store is not supported. - */ - //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - + + item_ct1.barrier(sycl::access::fence_space::local_space); + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_state1[0], c1s); + group_store_uc(tmp).store(item_ct1, d_state1, c1s); - - /* - DPCT1065:165: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:173: Migration of cub::BlockStore::Store is not supported. - */ - //StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + item_ct1.barrier(sycl::access::fence_space::local_space); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2039,25 +1945,15 @@ kOptimizerStatic8bit2State(T* buff_p, T* const buff_g, unsigned char* buff_state // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc_float2.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_state2[0], c2s); + group_store_uc(tmp).store(item_ct1, d_state2, c2s); - - /* - DPCT1065:166: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } } - template -/* -DPCT1110:3: The total declared local variable size in device function kPreconditionOptimizerStatic8bit1State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
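All of the bare item_ct1.barrier() calls in this region become local-space barriers, which settles the recurring DPCT1065 hint: between load, update, and store phases these kernels only hand data to each other through local memory, so fencing global memory as well would be wasted work. Minimal sketch of why the weaker fence is sufficient:

#include <sycl/sycl.hpp>

int main() {
  sycl::queue q;
  q.submit([&](sycl::handler &cgh) {
    sycl::local_accessor<float, 1> scratch(sycl::range<1>(256), cgh);
    cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(256), sycl::range<1>(256)),
                     [=](sycl::nd_item<1> it) {
      size_t lid = it.get_local_id(0);
      scratch[lid] = (float)lid;                            // produce into local memory
      it.barrier(sycl::access::fence_space::local_space);   // fences local memory only
      float neighbour = scratch[(lid + 1) % 256];           // safely read another lane
      (void)neighbour;
    });
  }).wait();
  return 0;
}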
-*/ SYCL_EXTERNAL void - -kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, unsigned char*__restrict__ const buff_state1, +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, float *unorm, const float beta1, const float beta2, const float eps, const int step, @@ -2067,7 +1963,7 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float* smem_quantiles1, - sycl_la_T ltacc_T, sycl_la_float ltacc_float1) + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); @@ -2079,60 +1975,36 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, float s1_vals[NUM8BIT]; T g_vals[NUM8BIT]; unsigned char m_c1[NUM8BIT]; - - //typedef cub::BlockLoad LoadT; - //typedef cub::BlockLoad LoadUInt8; - //typedef sycl::group<3> BlockReduce; - - /* - union type_ct8{ - typename LoadT::TempStorage loadh; - typename LoadUInt8::TempStorage loadc; - typename BlockReduce::TempStorage reduce; - }; - type_ct8 &temp_storage = *(type_ct8 *)temp_storage_ct1; - - */ - - if(item_ct1.get_local_id(2) < 256) + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + + if(item_ct1.get_local_id(2) < 256) smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; - /* - DPCT1065:133: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); + + item_ct1.barrier(sycl::access::fence_space::local_space); for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_THREADS*NUM8BIT) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; - /* - DPCT1065:135: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:137: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - + + item_ct1.barrier(sycl::access::fence_space::local_space); + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_g[0], g_vals); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); - /* - DPCT1065:136: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:138: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + item_ct1.barrier(sycl::access::fence_space::local_space); // 1. 
load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2140,8 +2012,7 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_state1[0], m_c1); + group_load_uc(tmp).load(item_ct1, d_state1, m_c1); #pragma unroll 16 for(int j = 0; j < NUM8BIT; j++) @@ -2171,27 +2042,17 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, } } - /* - DPCT1065:134: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:4: Migration of cub::Reduce is not supported. - */ + + item_ct1.barrier(sycl::access::fence_space::local_space); + local_max_s1 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s1, sycl::maximum<>()); - //local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, sycl::maximum<>(), valid_items); if(item_ct1.get_local_id(2) == 0){ atomicMax(&new_max1[0], local_max_s1); } if(unorm != NULL) { - /* - DPCT1065:139: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:5: Migration of cub::Reduce is not supported. - */ - //local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, sycl::plus<>(), valid_items); - local_norm = sycl::reduce_over_group(item_ct1.get_group(), local_norm, sycl::plus<>()); + + item_ct1.barrier(sycl::access::fence_space::local_space); + + local_unorm = sycl::reduce_over_group(item_ct1.get_group(), local_unorm, sycl::plus<>()); if(item_ct1.get_local_id(2) == 0){ dpct::atomic_fetch_add(&unorm[0], local_unorm); } } @@ -2200,7 +2061,7 @@ kPreconditionOptimizerStatic8bit1State(T* buff_p, T* __restrict__ const buff_g, template SYCL_EXTERNAL void -kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state1, +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, @@ -2208,9 +2069,9 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state float* max1, float* new_max1, float weight_decay, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, - sycl_la_float ltacc_float1, sycl_la_T stacc_T, - sycl_la_float stacc_float1) + const sycl::nd_item<3> &item_ct1,float *smem_quantiles1, const sycl_la &tacc, + const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1) { const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2; @@ -2234,56 +2095,42 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state T g_vals[NUM_PER_THREAD2]; - //typedef cub::BlockLoad LoadT; - //typedef cub::BlockLoad LoadChar; - - //typedef cub::BlockStore StoreChar; - //typedef cub::BlockStore StoreT; - - /* - union type_ct9{ - typename LoadT::TempStorage loadh; - typename LoadChar::TempStorage loadc; - typename StoreChar::TempStorage storec; - typename StoreT::TempStorage storeh; - }; 
- type_ct9 &temp_storage = *(type_ct9 *)temp_storage_ct1; - */ + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_uc = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + + + if(item_ct1.get_local_id(2) < 256) smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; - /* - DPCT1065:140: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); + + item_ct1.barrier(sycl::access::fence_space::local_space); for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*NUM_THREADS2*NUM_PER_THREAD2) { valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; - /* - DPCT1007:145: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_g[0], g_vals); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); - /* - DPCT1065:141: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:146: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + + item_ct1.barrier(sycl::access::fence_space::local_space); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2291,18 +2138,11 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_state1[0], c1s); + group_load_uc(tmp).load(item_ct1, d_state1, c1s); - /* - DPCT1065:142: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:147: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + item_ct1.barrier(sycl::access::fence_space::local_space); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2310,8 +2150,7 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index - auto *tmp = ltacc_T1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_p[0], p_vals); + group_load(tmp).load(item_ct1, d_p, p_vals); if((i + (item_ct1.get_local_id(2)*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } @@ -2367,10 +2206,6 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state } } - /* - DPCT1007:148: Migration of cub::BlockStore::Store is not supported. - */ - //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2378,17 +2213,10 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc_T.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_p[0], p_vals); + group_store(tmp).store(item_ct1, d_p, p_vals); - /* - DPCT1065:143: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); - /* - DPCT1007:149: Migration of cub::BlockStore::Store is not supported. - */ - //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + + item_ct1.barrier(sycl::access::fence_space::local_space); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2396,17 +2224,17 @@ kOptimizerStatic8bit1State(T* buff_p, T* const buff_g, unsigned char* buff_state // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_state1[0], c1s); + group_store_uc(tmp).store(item_ct1, d_state1, c1s); - /* - DPCT1065:144: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } } +//=========================================================================== + + + template SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ buff_g, float *gnorm_vec, int step, const int n, const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T) @@ -2431,7 +2259,7 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ buff_g, float *gnorm_vec /* DPCT1065:202: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:203: Migration of cub::BlockLoad::Load is not supported. */ @@ -2549,7 +2377,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:174: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
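The new_max1/absmax bookkeeping in these kernels is a two-level reduction: each group reduces its local maxima, then a single work-item publishes the result, either atomically into one scalar (the atomicMax(&new_max1[0], ...) path) or as a plain store into a per-block slot (the blockwise absmax1[i/BLOCK_SIZE] path). A standard-SYCL sketch of the atomic variant:

#include <sycl/sycl.hpp>

int main() {
  sycl::queue q;
  float *global_max = sycl::malloc_shared<float>(1, q);
  *global_max = 0.0f;
  q.parallel_for(sycl::nd_range<1>(sycl::range<1>(4096), sycl::range<1>(256)),
                 [=](sycl::nd_item<1> it) {
    float local_abs = sycl::fabs((float)it.get_global_id(0));   // stand-in magnitude
    float group_max =
        sycl::reduce_over_group(it.get_group(), local_abs, sycl::maximum<float>());
    if (it.get_local_id(0) == 0) {
      sycl::atomic_ref<float, sycl::memory_order::relaxed, sycl::memory_scope::device,
                       sycl::access::address_space::global_space> amax(global_max[0]);
      amax.fetch_max(group_max);
    }
  }).wait();
  sycl::free(global_max, q);
  return 0;
}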
*/ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #pragma unroll for(int k = 0; k < QUAD; k++) @@ -2566,7 +2394,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:175: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:183: Migration of cub::BlockLoad::Load is not supported. */ @@ -2584,7 +2412,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:176: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:184: Migration of cub::BlockLoad::Load is not supported. */ @@ -2602,7 +2430,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:177: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:185: Migration of cub::BlockLoad::Load is not supported. */ @@ -2662,7 +2490,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:178: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); if(item_ct1.get_local_id(2) == 0) { @@ -2678,7 +2506,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:179: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:186: Migration of cub::BlockLoad::Load is not supported. */ @@ -2710,7 +2538,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:180: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:187: Migration of cub::BlockStore::Store is not supported. */ @@ -2746,7 +2574,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:181: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:188: Migration of cub::BlockStore::Store is not supported. 
*/ @@ -2765,7 +2593,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:182: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:189: Migration of cub::BlockStore::Store is not supported. */ @@ -2850,7 +2678,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:190: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #pragma unroll for(int k = 0; k < QUAD; k++) @@ -2863,7 +2691,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:191: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:197: Migration of cub::BlockLoad::Load is not supported. */ @@ -2881,7 +2709,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:192: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:198: Migration of cub::BlockLoad::Load is not supported. */ @@ -2899,7 +2727,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:193: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:199: Migration of cub::BlockLoad::Load is not supported. */ @@ -2974,7 +2802,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:194: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); if(item_ct1.get_local_id(2) == 0) absmax1[i/BLOCK_SIZE] = new_local_abs_max1; @@ -3011,7 +2839,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:195: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:200: Migration of cub::BlockStore::Store is not supported. */ @@ -3046,7 +2874,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns /* DPCT1065:196: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
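For context on the c1s/c2s bytes flowing through these hunks: an 8-bit state value is an index into a sorted 256-entry code book, and requantization maps a normalized value back to its nearest entry. A plain C++ sketch of such a lookup (a binary search; illustrative only, not the kernels' actual search routine):

#include <cstddef>

unsigned char nearest_code(const float *quantiles /* sorted, 256 entries */, float x) {
  std::size_t lo = 0, hi = 255;
  while (lo < hi) {                          // find the first entry >= x
    std::size_t mid = (lo + hi) / 2;
    if (quantiles[mid] < x) lo = mid + 1; else hi = mid;
  }
  // the nearest code is either that entry or its left neighbour
  if (lo > 0 && (x - quantiles[lo - 1]) < (quantiles[lo] - x)) lo -= 1;
  return (unsigned char)lo;
}

int main() {
  float quantiles[256];
  for (int i = 0; i < 256; ++i) quantiles[i] = -1.0f + 2.0f * i / 255.0f; // toy code book
  return nearest_code(quantiles, 0.1f) > 0 ? 0 : 1;
}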
*/ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); /* DPCT1007:201: Migration of cub::BlockStore::Store is not supported. */ @@ -3122,7 +2950,7 @@ template items_per_load ? items_per_load : cols - base_col; int i = base_idx; @@ -3137,7 +2965,7 @@ template()); if(SPARSE_DECOMP) @@ -3194,7 +3022,7 @@ template()); } // we store the data temporarily in shared memory so we @@ -3210,7 +3038,7 @@ templatevoid kdequant_mm_i /* DPCT1065:205: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); // each block processes SUBTILE_ROWS*32 elements @@ -3466,7 +3294,7 @@ template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, float *unorm, float max_unorm, float param_norm, @@ -696,50 +701,35 @@ template void optimizerStatic8bit(T* p, T* g, sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - T *buff_g,*buff_p; - unsigned char *buff_state1,*buff_state2; - *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_p) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_state1) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_state2) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_g), (void*)(g), size); - q_ct1.memcpy((void*)(buff_p), (void*)(p), size); - q_ct1.memcpy((void*)(buff_state1), (void*)(state1), size); - q_ct1.memcpy((void*)(buff_state2), (void*)(state2), size); - + sycl::buffer buff_g(g,sycl::range<1>(size)); + sycl::buffer buff_p(p,sycl::range<1>(size)); + sycl::buffer buff_state1(state1,sycl::range<1>(size)); + sycl::buffer buff_state2(state2,sycl::range<1>(size)); - if(max_unorm > 0.0f){ //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait())); - DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); } + if(max_unorm > 0.0f){ + q_ct1.memset(unorm, 0, 1*sizeof(float)).wait(); } switch(OPTIMIZER) { case ADAM: - //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); - //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait())); - DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); - DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait()); + + //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); + //DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait()); { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1054:300: The type of variable temp_storage is declared in device function with the name type_ct6. Adjust the code to make the type_ct6 declaration visible at the accessor declaration point. 
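The move from sycl::malloc_device + memcpy pairs to sycl::buffer objects is also what made the explicit back-copy block in the previous patch unnecessary: a buffer constructed over a host pointer copies data to the device on first use and writes it back when the buffer goes out of scope. Sketch of that lifecycle:

#include <sycl/sycl.hpp>
#include <vector>

int main() {
  sycl::queue q_ct1;
  std::vector<float> g(1024, 1.0f);
  {
    sycl::buffer<float, 1> buff_g(g.data(), sycl::range<1>(g.size()));
    q_ct1.submit([&](sycl::handler &cgh) {
      sycl::accessor dacc_g(buff_g, cgh, sycl::read_write);
      cgh.parallel_for(sycl::range<1>(g.size()),
                       [=](sycl::id<1> i) { dacc_g[i] *= 2.0f; });
    });
  } // destruction waits for the kernel and writes the results back into g
  return g[0] == 2.0f ? 0 : 1;
}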
- */ - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); - //sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); + - using group_load_T = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - using group_load_float2 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); @@ -748,54 +738,25 @@ template void optimizerStatic8bit(T* p, T* g, cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit2State(buff_p, buff_g, buff_state1, buff_state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), ltacc_T, ltacc_float1, ltacc_float2); + kPreconditionOptimizerStatic8bit2State(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1, dacc_state2); }); }); } - /* - DPCT1010:242: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); - /* - DPCT1049:65: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - //sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); - //sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); - /* - DPCT1054:301: The type of variable temp_storage is declared in device function with the name type_ct7. Adjust the code to make the type_ct7 declaration visible at the accessor declaration point. 
- */ + - using group_load_T = dpct::group::workgroup_load; - using group_load_T1 = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - using group_load_float2 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); - + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - using group_store_T = dpct::group::workgroup_store; - using group_store_float1 = dpct::group::workgroup_store; - using group_store_float2 = dpct::group::workgroup_store; - size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_float2 = group_store_float2::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); - sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); - sycl::local_accessor stacc_float2(store_temp_storage_size_float2, cgh); - - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); @@ -804,132 +765,79 @@ template void optimizerStatic8bit(T* p, T* g, cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit2State(buff_p, buff_g, buff_state1, buff_state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2, stacc_T, stacc_float1, stacc_float2); + kOptimizerStatic8bit2State(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_state2); }); }); } - /* - DPCT1010:243: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
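Every submit block in this function repeats the same guard-and-launch shape; a condensed sketch, with the work-group size illustrative rather than taken from any one branch:

    // fail fast if the device cannot run fp16 kernels, then launch num_blocks
    // work-groups of wg_size work-items on the flattened third axis
    dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
    q_ct1.submit([&](sycl::handler &cgh) {
      const int wg_size = 256;
      cgh.parallel_for(
          sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, wg_size),
                            sycl::range<3>(1, 1, wg_size)),
          [=](sycl::nd_item<3> item_ct1) { /* kernel body */ });
    });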
- */ - //CUDA_CHECK_RETURN(0); + break; case MOMENTUM: case RMSPROP: case ADAGRAD: - //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); - DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); + + //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1054:302: The type of variable temp_storage is declared in device function with the name type_ct8. Adjust the code to make the type_ct8 declaration visible at the accessor declaration point. - */ - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - ; - using group_load_T = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - //__shared__ vars + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + + //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_float1); + kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1); }); }); } - /* - DPCT1010:244: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); - /* - DPCT1049:66: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - - /* - DPCT1054:303: The type of variable temp_storage is declared in device function with the name type_ct9. Adjust the code to make the type_ct9 declaration visible at the accessor declaration point. 
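The dropped CUDA_CHECK_RETURN and DPCT1010 blocks reflect that SYCL reports failures through exceptions rather than status codes; a sketch of the error path, reusing the queue name from this file:

    try {
      // synchronous errors throw at the call site; .wait() surfaces async ones
      q_ct1.memset(new_max1, 0, 1 * sizeof(float)).wait();
    } catch (sycl::exception const &exc) {
      std::cerr << exc.what() << " at " << __FILE__ << ":" << __LINE__ << std::endl;
    }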
- */ - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - - using group_load_T = dpct::group::workgroup_load; - using group_load_T1 = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - - - using group_store_T = dpct::group::workgroup_store; - using group_store_float1 = dpct::group::workgroup_store; - size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); - sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1,smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); + kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1); }); }); } - /* - DPCT1010:245: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + break; case LION: - // in lion, the momentum update happens after the parameter update - /* - DPCT1049:67: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1054:304: The type of variable temp_storage is declared in device function with the name type_ct9. Adjust the code to make the type_ct9 declaration visible at the accessor declaration point. 
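Inside the kernels, the accessors wired up here are turned back into raw pointers before the group helpers run. A device-side sketch under the same assumptions (the template arguments of get_multi_ptr are elided by the diff rendering and assumed here):

    // scratch pointer from the shared local allocation, data pointer from the
    // buffer accessor; each work-item receives NUM_PER_TH elements in registers
    T g_vals[NUM_PER_TH];
    auto *tmp = tacc.get_multi_ptr<sycl::access::decorated::yes>().get();
    auto *d_g = dacc_g.template get_multi_ptr<sycl::access::decorated::yes>().get();
    group_load(tmp).load(item_ct1, d_g, g_vals);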
- */ - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - using group_load_T = dpct::group::workgroup_load; - using group_load_T1 = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - - - using group_store_T = dpct::group::workgroup_store; - using group_store_float1 = dpct::group::workgroup_store; - size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); - sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); @@ -937,57 +845,39 @@ template void optimizerStatic8bit(T* p, T* g, cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); + kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1); }); }); } - /* - DPCT1010:246: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); - - //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait())); + DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1054:305: The type of variable temp_storage is declared in device function with the name type_ct8. Adjust the code to make the type_ct8 declaration visible at the accessor declaration point. 
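The LION branch keeps the update kernel ahead of the precondition pass because Lion applies the parameter step with the old momentum and refreshes the momentum afterwards. A scalar sketch of that ordering (illustrative only, not the kernel's blocked form):

    #include <cmath>

    inline void lion_step_sketch(float &p, float &m, float g,
                                 float beta1, float beta2, float lr, float wd)
    {
      // parameter update uses the previous momentum
      float c = std::copysign(1.0f, beta1 * m + (1.0f - beta1) * g);
      p -= lr * (c + wd * p);
      // the momentum update happens after the parameter update
      m = beta2 * m + (1.0f - beta2) * g;
    }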
- */ - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - using group_load_float1 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit1State(buff_p, buff_g, buff_state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), ltacc_T, ltacc_float1); + kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1); }); }); } - /* - DPCT1010:247: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + break; default: break; } - //back memcpy - q_ct1.memcpy((void*)(buff_g), (void*)(g), size); - q_ct1.memcpy((void*)(buff_p), (void*)(p), size); - q_ct1.memcpy((void*)(buff_state1), (void*)(state1), size); - q_ct1.memcpy((void*)(buff_state2), (void*)(state2), size); - } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; From f71243b25841083b8a6c6f9d726be79f8ac970ee Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 13 May 2024 06:57:57 -0700 Subject: [PATCH 35/66] refine 8 bit blockwise opt --- csrc/sycl/kernels.cpp | 257 ++++++++++++++---------------------------- csrc/sycl/ops.cpp | 118 +++++-------------- 2 files changed, 115 insertions(+), 260 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 8b30031a5..8856a074f 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -13,6 +13,8 @@ #include #include +#define FLT_MAX std::numeric_limits::max() +#define FLT_MIN std::numeric_limits::min() #define HLF_MAX 65504 @@ -404,7 +406,7 @@ unsigned char dQuantize(float* smem_code, const float rand, float x) } template -__dpct_inline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) +__dpct_inline__ unsigned char quantize_2D(float *__restrict__ quadrants, float x) { int pivot = 127; int upper_pivot = 255; @@ -436,7 +438,7 @@ __dpct_inline__ unsigned char quantize_2D(float *__restrict__ quadrants, float * //val = i == 64 ? quadrants[0] : smem_code[pivot]; local_pivot -= offset; } - val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; + val = i >= 64 ? 
quadrants[local_pivot] : 0;//smem_code[pivot]; offset -= 1; } @@ -2300,15 +2302,16 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ buff_g, float *gnorm_vec } +//=========================8 bit blockwise==================================== + + #define LANES 2 #define QUAD 3 template -/* -DPCT1110:10: The total declared local variable size in device function kOptimizerStatic8bit2StateBlockwise exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. -*/ + SYCL_EXTERNAL void -kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, unsigned char* buff_state1, unsigned char* buff_state2, +kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, @@ -2319,9 +2322,9 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2, - sycl_la_T ltacc_T, sycl_la_T ltacc_T1, - sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, - sycl_la_T stacc_T, sycl_la_float1 stacc_float1, sycl_la_float stacc_float2) + const sycl_la &tacc, const sycl::accessor &dacc_g, + const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2) { //const int n_full = n + (n%BLOCK_SIZE); @@ -2346,22 +2349,22 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; - //typedef cub::BlockLoad LoadT; - //typedef cub::BlockLoad LoadChar; - - //typedef cub::BlockStore StoreChar; - //typedef cub::BlockStore StoreT; - - /* - union type_ct10{ - typename LoadT::TempStorage loadh; - typename LoadChar::TempStorage loadc; - typename StoreChar::TempStorage storec; - typename StoreT::TempStorage storeh; - }; - type_ct10 &temp_storage = *(type_ct10 *)temp_storage_ct1; - */ + + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_uc = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); + auto *d_state2 = dacc_state2.get_multi_ptr().get(); + + + // init: 0.2 -> 0.23 // 0.23 -> 0.23 @@ -2374,9 +2377,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns smem_quantiles2[j][item_ct1.get_local_id(2)] = smem_quantiles2[0][item_ct1.get_local_id(2)]; } - /* - DPCT1065:174: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); #pragma unroll @@ -2391,32 +2392,20 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns { // loads: 0.23 -> 0.85/1.44 valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; - /* - DPCT1065:175: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. 
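The recurring DPCT1065 rewrite in these kernels narrows each barrier to local memory; a sketch of why that is sufficient for the smem_exchange handoffs:

    // only work-group-local data is shared across work-items here, so ordering
    // local accesses is enough; the default barrier also fences global memory
    if (item_ct1.get_local_id(2) == 0)
      smem_exchange1[0] = new_local_abs_max1;                  // leader publishes
    item_ct1.barrier(sycl::access::fence_space::local_space);
    new_local_abs_max1 = smem_exchange1[0];                    // all lanes read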
- */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:183: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_g[0], g_vals); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); + - /* - DPCT1065:176: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:184: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2424,17 +2413,10 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_state1[0], c1s); + group_load_uc(tmp).load(item_ct1, d_state1, c1s); + - /* - DPCT1065:177: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:185: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2442,8 +2424,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_float2.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_state2[0], c2s); + group_load_uc(tmp).load(item_ct1, d_state2, c2s); new_local_abs_max1 = -FLT_MAX; @@ -2487,9 +2468,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns smem_exchange2[0] = new_local_abs_max2; } - /* - DPCT1065:178: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); if(item_ct1.get_local_id(2) == 0) @@ -2503,14 +2482,9 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns new_local_abs_max2 = smem_exchange2[0]; } - /* - DPCT1065:179: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:186: Migration of cub::BlockLoad::Load is not supported. 
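Each loop iteration clamps the tail with valid_items; a sketch of the guard, assuming BLOCK_SIZE elements per work-group as in these kernels (note that the migrated group loads no longer take valid_items the way the guarded cub overloads did, so tail handling rests on this arithmetic):

    for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2) * BLOCK_SIZE)
    {
      // the last block may be partial: only n - i elements remain
      int valid_items = (n - i) >= BLOCK_SIZE ? BLOCK_SIZE : (n - i);
      // ... loads, update math, stores ...
    }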
- */ - //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2518,8 +2492,7 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_T1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_p[0], p_vals); + group_load(tmp).load(item_ct1, d_p, p_vals); // reduce: 2.67/1.69 -> 2.67/1.70 # pragma unroll N_PER_TH @@ -2535,14 +2508,9 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns } // store: 0.85/1.44 -> 2.48/1.57 - /* - DPCT1065:180: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:187: Migration of cub::BlockStore::Store is not supported. - */ - //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2550,15 +2518,17 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc_T.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_p[0], p_vals); + group_store(tmp).store(item_ct1, d_p, p_vals); // quantization: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], s1_vals[j] / new_local_abs_max1); - c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], s2_vals[j] / new_local_abs_max2); + //c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], s1_vals[j] / new_local_abs_max1); + //c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], s2_vals[j] / new_local_abs_max2); + + c1s[j] = quantize_2D<1>(quadrants1, s1_vals[j] / new_local_abs_max1); + c2s[j] = quantize_2D<0>(quadrants2, s2_vals[j] / new_local_abs_max2); // make sure the state1 term still has the same sign after quantization // (not needed for the state2 term, which has only positive values) @@ -2571,14 +2541,8 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns } } - /* - DPCT1065:181: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:188: Migration of cub::BlockStore::Store is not supported. - */ - //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2586,18 +2550,10 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index - auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_state1[0], c1s); - + group_store_uc(tmp).store(item_ct1, d_state1, c1s); - /* - DPCT1065:182: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:189: Migration of cub::BlockStore::Store is not supported. - */ - //StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2605,8 +2561,9 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc_float2.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_state2[0], c2s); + group_store_uc(tmp).store(item_ct1, d_state2, c2s); + + item_ct1.barrier(sycl::access::fence_space::local_space); } } @@ -2615,12 +2572,9 @@ kOptimizerStatic8bit2StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns #define LANES 2 #define QUAD 3 template -/* -DPCT1110:11: The total declared local variable size in device function kOptimizerStatic8bit1StateBlockwise exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. -*/ SYCL_EXTERNAL void -kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, unsigned char* buff_state1, +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, @@ -2630,9 +2584,10 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1, - sycl_la_T ltacc_T, sycl_la_T ltacc_T1, - sycl_la_float ltacc_float1,sycl_la_T stacc_T, - sycl_la_float stacc_float1 + const sycl_la &tacc, + const sycl::accessor &dacc_g, + const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1 ) { @@ -2651,22 +2606,18 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; - //typedef cub::BlockLoad LoadT; - //typedef cub::BlockLoad LoadChar; - - //typedef cub::BlockStore StoreChar; - //typedef cub::BlockStore StoreT; - - /* - union type_ct11{ - typename LoadT::TempStorage loadh; - typename LoadChar::TempStorage loadc; - typename StoreChar::TempStorage storec; - typename StoreT::TempStorage storeh; - }; + using group_load = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; + + using group_store = dpct::group::workgroup_store>; + using group_store_uc = dpct::group::workgroup_store>; + + + auto *d_g = dacc_g.template get_multi_ptr().get(); + auto *d_p = dacc_p.template get_multi_ptr().get(); + auto *d_state1 = dacc_state1.get_multi_ptr().get(); - type_ct11 &temp_storage = *(type_ct11 *)temp_storage_ct1; - */ + // init: 0.2 -> 0.23 // 0.23 -> 0.23 @@ -2675,9 +2626,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns for(unsigned int j = 1; j < LANES; j++) 
smem_quantiles1[j][item_ct1.get_local_id(2)] = smem_quantiles1[0][item_ct1.get_local_id(2)]; - /* - DPCT1065:190: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); #pragma unroll @@ -2688,32 +2637,22 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns { // loads: 0.23 -> 0.85/1.44 valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; - /* - DPCT1065:191: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:197: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_g[0], g_vals); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, g_vals); /* DPCT1065:192: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. */ item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:198: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2721,17 +2660,10 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_float1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_state1[0], c1s); + group_load_uc(tmp).load(item_ct1, d_state1, c1s); + - /* - DPCT1065:193: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:199: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2739,8 +2671,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index - auto *tmp = ltacc_T1.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_p[0], p_vals); + group_load(tmp).load(item_ct1, d_p, p_vals); new_local_abs_max1 = -FLT_MAX; @@ -2799,9 +2730,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns if(item_ct1.get_local_id(2) == 0) smem_exchange1[0] = new_local_abs_max1; - /* - DPCT1065:194: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); if(item_ct1.get_local_id(2) == 0) @@ -2836,30 +2765,23 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns } // store: 0.85/1.44 -> 2.48/1.57 - /* - DPCT1065:195: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:200: Migration of cub::BlockStore::Store is not supported. - */ - //StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc_T.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_p[0], p_vals); + group_store(tmp).store(item_ct1, d_p, p_vals); // quantization: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], s1_vals[j] / new_local_abs_max1); - + //c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], s1_vals[j] / new_local_abs_max1); + c1s[j] = quantize_2D<1>(quadrants1, s1_vals[j] / new_local_abs_max1); // make sure the state1 term still has the same sign after quantization // (not needed for the state2 term, which has only positive values) if(sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) != sycl::signbit(s1_vals[j])) @@ -2871,14 +2793,8 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns } } - /* - DPCT1065:196: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:201: Migration of cub::BlockStore::Store is not supported. - */ - //StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2886,8 +2802,7 @@ kOptimizerStatic8bit1StateBlockwise(T* buff_p, T* __restrict__ const buff_g, uns // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index - auto *tmp = stacc_float1.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_state1[0], c1s); + group_store_uc(tmp).store(item_ct1, d_state1, c1s); } } diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 4df6bdb05..5837ead0b 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -884,6 +884,10 @@ catch (sycl::exception const &exc) { std::exit(1); } + + + + #define BLOCKSIZE_2STATE 2048 #define NUM_2STATE 8 #define BLOCKSIZE_1STATE 2048 @@ -898,18 +902,12 @@ template void optimizerStatic8bitBlockwise(T* p, T* g sycl::queue &q_ct1 = dev_ct1.in_order_queue(); sycl::context ctx = q_ct1.get_context(); int num_blocks = 0; - int size = NUM_BLOCK; + int size = BLOCKSIZE_2STATE; - T *buff_g,*buff_p; - unsigned char *buff_state1,*buff_state2; - *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_p) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_state1) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_state2) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_g), (void*)(g), size); - q_ct1.memcpy((void*)(buff_p), (void*)(p), size); - q_ct1.memcpy((void*)(buff_state1), (void*)(state1), size); - q_ct1.memcpy((void*)(buff_state2), (void*)(state2), size); + sycl::buffer buff_g(g,sycl::range<1>(size)); + sycl::buffer buff_p(p,sycl::range<1>(size)); + sycl::buffer buff_state1(state1,sycl::range<1>(size)); + sycl::buffer buff_state2(state2,sycl::range<1>(size)); switch(OPTIMIZER) @@ -921,43 +919,15 @@ template void optimizerStatic8bitBlockwise(T* p, T* g dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1101:306: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. - */ - /* - DPCT1101:307: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. - */ - - /* - DPCT1054:308: The type of variable temp_storage is declared in device function with the name type_ct10. Adjust the code to make the type_ct10 declaration visible at the accessor declaration point. 
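A sketch of the launch arithmetic behind the blockwise cases that follow, using the constants defined above (BLOCKSIZE_2STATE = 2048, NUM_2STATE = 8):

    // one work-group per 2048-element block; each work-item handles 8 values,
    // so the work-group size is 2048 / 8 = 256
    int num_blocks = n / BLOCKSIZE_2STATE;
    num_blocks = (n % BLOCKSIZE_2STATE == 0) ? num_blocks : num_blocks + 1;
    int wg_size = BLOCKSIZE_2STATE / NUM_2STATE; // 256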
- */ - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - using group_load_T = dpct::group::workgroup_load; - using group_load_T1 = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - using group_load_float2 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_float2 = group_load_float2::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - sycl::local_accessor ltacc_float2(load_temp_storage_size_float2, cgh); - - - using group_store_T = dpct::group::workgroup_store; - using group_store_float1 = dpct::group::workgroup_store; - using group_store_float2 = dpct::group::workgroup_store; - size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_float2 = group_store_float2::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); - sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); - sycl::local_accessor stacc_float2(store_temp_storage_size_float2, cgh); + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); @@ -969,14 +939,11 @@ template void optimizerStatic8bitBlockwise(T* p, T* g cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE), sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit2StateBlockwise(buff_p, buff_g, buff_state1, buff_state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n,item_ct1, smem_quantiles1_acc_ct1, smem_quantiles2_acc_ct1,smem_exchange1_acc_ct1.get_pointer(), smem_exchange2_acc_ct1.get_pointer(),ltacc_T, ltacc_T1, ltacc_float1, ltacc_float2, stacc_T, stacc_float1, stacc_float2); + kOptimizerStatic8bit2StateBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n,item_ct1, smem_quantiles1_acc_ct1, smem_quantiles2_acc_ct1,smem_exchange1_acc_ct1.get_pointer(), smem_exchange2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_state2); }); }); } - /* - DPCT1010:248: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
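The DPCT1101 notes flag that the literal 2 in these local accessor ranges stands for LANES; a sketch of what the two-dimensional quantile table gives the kernel (the in-kernel fill is assumed from the copy loops earlier in kernels.cpp):

    // one 257-entry quantile row per lane; row 0 is filled from global memory
    // and replicated so each lane reads its own copy
    sycl::local_accessor<float, 2> smem_quantiles1_acc_ct1(
        sycl::range<2>(2 /*LANES*/, 257), cgh);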
- */ - //CUDA_CHECK_RETURN(0); + break; case MOMENTUM: case RMSPROP: @@ -988,35 +955,15 @@ template void optimizerStatic8bitBlockwise(T* p, T* g dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - /* - DPCT1101:309: 'LANES' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. - */ - /* - DPCT1054:310: The type of variable temp_storage is declared in device function with the name type_ct11. Adjust the code to make the type_ct11 declaration visible at the accessor declaration point. - */ - //sycl::local_accessor temp_storage_ct1_acc_ct1(cgh); - using group_load_T = dpct::group::workgroup_load; - using group_load_T1 = dpct::group::workgroup_load; - using group_load_float1 = dpct::group::workgroup_load; - size_t load_temp_storage_size_float1 = group_load_float1::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t load_temp_storage_size_T1 = group_load_T1::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_T1(load_temp_storage_size_T1, cgh); - sycl::local_accessor ltacc_float1(load_temp_storage_size_float1, cgh); - - - using group_store_T = dpct::group::workgroup_store; - using group_store_float1 = dpct::group::workgroup_store; - size_t store_temp_storage_size_float1 = group_store_float1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_T = group_store_T::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor stacc_T(store_temp_storage_size_T, cgh); - sycl::local_accessor stacc_float1(store_temp_storage_size_float1, cgh); - + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); + sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); sycl::local_accessor smem_exchange1_acc_ct1(sycl::range<1>(1), cgh); @@ -1024,21 +971,14 @@ template void optimizerStatic8bitBlockwise(T* p, T* g cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1StateBlockwise(buff_p, buff_g, buff_state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), ltacc_T, ltacc_T1, ltacc_float1, stacc_T, stacc_float1); + kOptimizerStatic8bit1StateBlockwise(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1); }); }); } - /* - DPCT1010:249: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
- */ - //CUDA_CHECK_RETURN(0); + break; } - q_ct1.memcpy((void*)(g), (void*)(buff_g), size); - q_ct1.memcpy((void*)(p), (void*)(buff_p), size); - q_ct1.memcpy((void*)(state1), (void*)(buff_state1), size); - q_ct1.memcpy((void*)(state2), (void*)(buff_state2), size); - + } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; From 493a5ec40d0d5632c50ba49171883f35baf49c08 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Mon, 13 May 2024 21:33:58 -0700 Subject: [PATCH 36/66] refine k quantize blockwise --- csrc/sycl/kernels.cpp | 64 +++++++--------- csrc/sycl/ops.cpp | 168 ++++++++++++++++++------------------------ 2 files changed, 95 insertions(+), 137 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 8856a074f..1cc0c311b 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -814,11 +814,15 @@ void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *b } } + +//===========================k quantize blockwise================================ + template //__launch_bounds__(TH, 4) -SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, +SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, - float *smem_absmax_value, sycl_la_T ltacc_T, sycl_la_float ltacc_float, sycl_la_unsigned_char stacc) + float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, + const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out) { @@ -833,6 +837,16 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ float local_abs_max = 0.0f; int local_rand_idx = 0; + using group_load = dpct::group::workgroup_load>; + using group_load_float = dpct::group::workgroup_load>; + using group_store_uc = dpct::group::workgroup_store>; + + auto *d_A = dacc_A.template get_multi_ptr().get(); + auto *d_rand = dacc_rand.get_multi_ptr().get(); + auto *d_out = dacc_out.get_multi_ptr().get(); + + + if(DATA_TYPE == General8bit) for(int i = item_ct1.get_local_id(2); i < 256; i+=item_ct1.get_local_range(2)) smem_code[i] = code[i]; @@ -851,8 +865,8 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_A[0], vals); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_A, vals); // 1. compute local max // 2. broadcast local max @@ -862,18 +876,12 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ for(int j = 0; j < NUM_PER_TH; j++) local_abs_max = sycl::fmax(local_abs_max, sycl::fabs((float)vals[j])); - /* - DPCT1007:0: Migration of cub::Reduce is not supported. 
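The kQuantizeBlockwise hunk that continues below swaps cub::BlockReduce for the built-in group reduction; a sketch of the replacement, which combines one value per work-item across the whole work-group and returns the result to every work-item:

    local_abs_max = sycl::reduce_over_group(item_ct1.get_group(),
                                            local_abs_max, sycl::maximum<>());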
- */ - //local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, sycl::maximum<>(), valid_items); - local_abs_max = dpct::group::reduce(item_ct1, local_abs_max, sycl::maximum<>()); + local_abs_max = sycl::reduce_over_group(item_ct1.get_group(), local_abs_max, sycl::maximum<>()); if(item_ct1.get_local_id(2) == 0) smem_absmax_value[0] = local_abs_max; - /* - DPCT1065:85: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); if(item_ct1.get_local_id(2) == 0) @@ -888,11 +896,6 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ if(STOCHASTIC) { local_rand_idx = ((item_ct1.get_group(2)*NUM_BLOCK) + (item_ct1.get_local_id(2)*NUM) + rand_offset) % (1024-4); - /* - DPCT1007:88: Migration of cub::BlockLoad::Load is not supported. - */ - - //LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -900,8 +903,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_float.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_rand[0], rand_vals); + group_load_float(tmp).load(item_ct1, d_rand, rand_vals); } @@ -940,24 +942,20 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const buff_ item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:89: Migration of cub::BlockStore::Store is not supported. - */ - // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_out[0], qvals); - - //StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); + group_store_uc(tmp).store(item_ct1, d_out, qvals); + } } +//===========================k dequantize================================ + template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1, sycl_la tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out ) @@ -1109,18 +1107,6 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, auto *d_state1 = dacc_state1.get_multi_ptr().get(); auto *d_state2 = dacc_state2.get_multi_ptr().get(); - //typedef cub::BlockLoad Load; - //typedef cub::BlockLoad LoadFloat; - //typedef sycl::group<3> BlockReduce; - /* - union type_ct2{ - typename Load::TempStorage load; - typename LoadFloat::TempStorage loadf; - typename BlockReduce::TempStorage reduce; - }; - */ - //type_ct2 &temp_storage = *(type_ct2 *)temp_storage_ct1; - for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 5837ead0b..66d4a36d1 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -192,37 +192,30 @@ template void quantizeBlockwise(floa num_blocks = n % blocksize == 0 ? 
num_blocks : num_blocks + 1; sycl::context ctx = q_ct1.get_context(); int size= NUM_BLOCK; + for(int i=0; i< NUM_BLOCK; i++){ out[i]=out[(DATA_TYPE > 0) ? i/2 : i];}; - T *buff_A; - unsigned char *buff_out; - float *buff_rand; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_rand) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((T*)(buff_A), (T*)(A), NUM_BLOCK); - q_ct1.memcpy((void*)(buff_out), (void*)(out), NUM_BLOCK); - q_ct1.memcpy((void*)(buff_rand), (void*)(rand), NUM_BLOCK); - - for(int i=0; i< NUM_BLOCK; i++){ buff_out[i]=buff_out[(DATA_TYPE > 0) ? i/2 : i];}; + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_rand(rand,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); if(blocksize == 4096) - /* - DPCT1049:57: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_T = dpct::group::workgroup_load; - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); - using group_load_float = dpct::group::workgroup_load; - size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + //__shared__ vars for functions sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); @@ -232,31 +225,26 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); }); }); } else if(blocksize == 2048) - /* - DPCT1049:58: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
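Each blocksize branch in this function repeats the same submit with a different work-group size; a condensed sketch of the mapping visible in the branches, where values_per_work_item is an illustrative name for blocksize divided by work-group size:

    // blocksize -> work-group size used in the submits
    //   4096 -> 1024      2048 -> 512      1024 and 512 -> 256
    //    256 -> 128        128 -> 64         64 -> 32
    int wg_size = blocksize / values_per_work_item; // e.g. 4096 / 4 = 1024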
- */ + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - using group_load_T = dpct::group::workgroup_load; - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); - using group_load_float = dpct::group::workgroup_load; - size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); - //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); @@ -264,7 +252,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); }); }); } @@ -274,16 +262,14 @@ template void quantizeBlockwise(floa q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_T = dpct::group::workgroup_load; - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); - using group_load_float = dpct::group::workgroup_load; - size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); //__shared__vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); @@ -292,7 +278,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), 
smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); }); }); } @@ -302,17 +288,16 @@ template void quantizeBlockwise(floa q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_T = dpct::group::workgroup_load; - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); - using group_load_float = dpct::group::workgroup_load; - size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); - //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); @@ -320,7 +305,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); }); }); } @@ -330,16 +315,14 @@ template void quantizeBlockwise(floa q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_T = dpct::group::workgroup_load; - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); - using group_load_float = dpct::group::workgroup_load; - size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); @@ -349,7 +332,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, A, absmax, out, rand, 
rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); }); }); } @@ -359,24 +342,22 @@ template void quantizeBlockwise(floa q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_T = dpct::group::workgroup_load; - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); - using group_load_float = dpct::group::workgroup_load; - size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); - + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, ltacc_float, stacc); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); }); }); } @@ -386,17 +367,15 @@ template void quantizeBlockwise(floa q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_T = dpct::group::workgroup_load; - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); - using group_load_float = dpct::group::workgroup_load; - size_t load_temp_storage_size_float = group_load_float::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor ltacc_float(load_temp_storage_size_float, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); @@ -405,17 +384,10 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, buff_A, absmax, buff_out, buff_rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), ltacc_T, 
ltacc_float, stacc); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); }); }); } - - - - //back memcpy - q_ct1.memcpy((void*)(A), (void*)(buff_A), NUM_BLOCK); - q_ct1.memcpy((void*)(out), (void*)(buff_out), NUM_BLOCK); - q_ct1.memcpy((void*)(rand), (void*)(buff_rand), NUM_BLOCK); } From 760a1203d3f566deb1c8c1c4ff6601d41df6d467 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 14 May 2024 00:12:00 -0700 Subject: [PATCH 37/66] refine k quantize --- csrc/sycl/kernels.cpp | 43 +++++++++++++---------------------- csrc/sycl/ops.cpp | 53 ++++++++++++++++++++++--------------------- 2 files changed, 43 insertions(+), 53 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 1cc0c311b..a503724c6 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -737,10 +737,11 @@ void kEstimateQuantiles(const T *A, float *code, const float offset, const T max } - +//====================================k quantize=========================================== SYCL_EXTERNAL -void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *buff_out, const int n, - const sycl::nd_item<3> &item_ct1, float* smem_code, sycl_la_float ltacc, sycl_la_unsigned_char stacc) +void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, + const sycl::nd_item<3> &item_ct1, float* smem_code, const sycl_la &tacc, const sycl_dacc_float &dacc_A, + const sycl_dacc_uc &dacc_out) { const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*NUM_BLOCK) : NUM_BLOCK; @@ -750,9 +751,11 @@ void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *b unsigned char qvals[NUM]; //const int lane_id = threadIdx.x % 2; - //typedef cub::BlockLoad LoadFloat; - //typedef cub::BlockStore StoreChar; - //__shared__ float smem_code[2][257]; + using group_load_float = dpct::group::workgroup_load>; + using group_store_uc = dpct::group::workgroup_store>; + + auto *d_A = dacc_A.template get_multi_ptr().get(); + auto *d_out = dacc_out.get_multi_ptr().get(); if(item_ct1.get_local_id(2) < 256) { @@ -769,12 +772,7 @@ void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *b // rand_offset % mod value valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; - /* - DPCT1118:50: SYCL group functions and algorithms must be encountered in converged control flow. You may need to adjust the code. - */ - /* - DPCT1065:224: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); // 1. load 8 values per thread @@ -783,22 +781,14 @@ void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *b // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index - auto *tmp = ltacc.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_A[0], vals); + auto *tmp = tacc.get_multi_ptr().get(); + group_load_float(tmp).load(item_ct1, d_A, vals); - //LoadFloat(loadf).Load(&(A[i]), vals, valid_items); - - #pragma unroll 4 for(int j = 0; j < NUM; j++) qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); - /* - DPCT1118:51: SYCL group functions and algorithms must be encountered in converged control flow. You may need to adjust the code. - */ - /* - DPCT1065:225: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); @@ -808,10 +798,9 @@ void kQuantize(float * code, float * __restrict__ const buff_A, unsigned char *b // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_out[0], qvals); - //StoreChar(storec).Store(&(out[i]), qvals, valid_items); - } + group_store_uc(tmp).store(item_ct1, d_out, qvals); + } + } diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 66d4a36d1..e4c972c81 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -56,6 +56,8 @@ void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *sr //CUDA_CHECK_RETURN(0); } + +//============================estimate quantiles=============================== template void estimateQuantiles(T *A, float *code, float offset, int n) { dpct::device_ext &dev_ct1 = dpct::get_current_device(); @@ -94,6 +96,7 @@ template void estimateQuantiles(T *A, float *code, float offset, in } +//============================k quantize =============================== void quantize(float *code, float *A, unsigned char *out, int n) { int num_blocks = n/1024; @@ -103,47 +106,36 @@ void quantize(float *code, float *A, unsigned char *out, int n) sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - float *buff_A; - unsigned char *buff_out; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), NUM_BLOCK); - q_ct1.memcpy((void*)(buff_out), (void*)(out), NUM_BLOCK); - - /* - DPCT1049:55: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
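
The hunk just below makes the same change quantizeBlockwise received above: the malloc_device/memcpy staging gives way to sycl::buffer views over the caller's host pointers, which is also why the explicit "back memcpy" disappears, since a buffer writes its contents back to the wrapped pointer when it is destroyed. A minimal sketch of that pattern, with hypothetical names:

#include <sycl/sycl.hpp>

// Hypothetical example: double every element of a host array on the device.
void scale_in_place(sycl::queue &q, float *host_data, size_t n) {
    sycl::buffer<float, 1> buf(host_data, sycl::range<1>(n)); // view of host memory
    q.submit([&](sycl::handler &cgh) {
        sycl::accessor acc(buf, cgh, sycl::read_write);       // device-side view
        cgh.parallel_for(sycl::range<1>(n),
                         [=](sycl::id<1> i) { acc[i] *= 2.0f; });
    });
}   // ~buffer() waits for the kernel and copies the data back to host_data
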
- */ + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); + + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load; - size_t load_temp_storage_size = group_load::get_local_memory_size(NUM_BLOCK); - using group_store = dpct::group::workgroup_store; - size_t store_temp_storage_size = group_store::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc(load_temp_storage_size, cgh); - sycl::local_accessor stacc(store_temp_storage_size, cgh); - + using group_load = dpct::group::workgroup_load>; + + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kQuantize(code, buff_A, buff_out, n, item_ct1, smem_code_acc_ct1.get_pointer(), ltacc, stacc); + kQuantize(code, A, out, n, item_ct1, smem_code_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out); }); }); } - //back memcpy - q_ct1.memcpy((void*)(A), (void *)(buff_A), NUM_BLOCK); - q_ct1.memcpy((void*)(out), (void*)(buff_out), NUM_BLOCK); - /* - DPCT1010:232: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + } + +//============================k dequantize=============================== void dequantize(float *code, unsigned char *A, float *out, int n) { int num_blocks = n/1024; @@ -184,6 +176,8 @@ void dequantize(float *code, unsigned char *A, float *out, int n) } +//============================quantize blockwise=============================== + template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) { dpct::device_ext &dev_ct1 = dpct::get_current_device(); @@ -391,6 +385,8 @@ template void quantizeBlockwise(floa } + +//============================k dequantize blockwise=============================== template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { dpct::device_ext &dev_ct1 = dpct::get_current_device(); @@ -458,6 +454,8 @@ template void dequantizeBlockwise(float *code, unsign //} + +//============================32 bit optimizer=============================== template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, @@ -655,6 +653,7 @@ catch (sycl::exception const &exc) { +//============================8 bit optimizer=============================== template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, @@ -859,6 +858,7 @@ catch (sycl::exception const &exc) { +//============================8 bit blockwise optimizer=============================== #define BLOCKSIZE_2STATE 2048 #define NUM_2STATE 8 @@ -958,6 +958,7 @@ catch (sycl::exception const &exc) { } +//============================percentile clipping=============================== template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) 
{ From c7d8326482c13e7e648e311ad542c6b202ab6fb1 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 14 May 2024 00:22:22 -0700 Subject: [PATCH 38/66] refine percentile clipping --- csrc/sycl/kernels.cpp | 35 ++++++++++++----------------------- csrc/sycl/ops.cpp | 37 +++++++++++++------------------------ 2 files changed, 25 insertions(+), 47 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index a503724c6..f71a74bf5 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -2208,22 +2208,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, } -//=========================================================================== +//===============================k percentile clipping============================================ template -SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ buff_g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T) +SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n, + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl::accessor &dacc_g) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); int valid_items = 0; - - //typedef cub::BlockLoad LoadT; - - //sycl::buffer buff_g(g, sycl::range<1>(NUM_VALS)); - + using group_load = dpct::group::workgroup_load>; + auto *d_g = dacc_g.template get_multi_ptr().get(); T vals[NUM_VALS]; float local_sum = 0.0f; @@ -2233,33 +2230,25 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ buff_g, float *gnorm_vec valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; local_sum = 0.0f; - /* - DPCT1065:202: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:203: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); - + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_g[0], vals); + auto *tmp = tacc.get_multi_ptr().get(); + group_load(tmp).load(item_ct1, d_g, vals); #pragma unroll NUM_VALS for(int j = 0; j < NUM_VALS; j++) local_sum += ((float)vals[j])*((float)vals[j]); - /* - DPCT1007:12: Migration of cub::Sum is not supported. - */ - locacl_sum = sycl::reduce_over_group(item_ct1.get_group(), local_sum, sycl::plus<>()); - //local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + + local_sum = sycl::reduce_over_group(item_ct1.get_group(), local_sum, sycl::plus<>()); + if(item_ct1.get_local_id(2) == 0) { if(step == 1) diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index e4c972c81..19e3029c6 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -966,45 +966,34 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, sycl::queue &q_ct1 = dev_ct1.in_order_queue(); sycl::context ctx = q_ct1.get_context(); - int num_blocks = n/2048; - num_blocks = n % 2048 == 0 ? 
num_blocks : num_blocks + 1;
-  int size = NUM_BLOCK;
-  T *buff_g;
-  *((void **)&buff_g) = sycl::malloc_device(size, dev_ct1, ctx);
-  q_ct1.memcpy((void*)(buff_g), (void*)(g), size);
+  int num_blocks = n/2048;
+  num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
+  int size = NUM_BLOCK;
+  sycl::buffer buff_g(g,sycl::range<1>(size));
+  q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait();
-  //CUDA_CHECK_RETURN(DPCT_CHECK_ERROR(q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait()));
-  DPCT_CHECK_ERROR(q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait());
-  /*
-  DPCT1049:68: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed.
-  */
   {
     dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16});
     q_ct1.submit(
       [&](sycl::handler &cgh) {
-        using group_load_T = dpct::group::workgroup_load;
-        size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK);
-        sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh);
+
+        using group_load = dpct::group::workgroup_load>;
+        size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE);
+        sycl::local_accessor tacc(temp_storage_size, cgh);
+
+        sycl::accessor dacc_g(buff_g, cgh, sycl::read_write);

         cgh.parallel_for(
           sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)),
           [=](sycl::nd_item<3> item_ct1) {
-            kPercentileClipping(g, gnorm_vec, step, n, item_ct1, ltacc_T);
+            kPercentileClipping(g, gnorm_vec, step, n, item_ct1, tacc, dacc_g);
           });
       });
   }
-  /*
-  DPCT1010:250: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code.
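
DPCT1010, deleted here, explains why the CUDA_CHECK_RETURN calls keep vanishing across these patches: SYCL reports failures through exceptions rather than status codes, and ops.cpp already wraps several entry points in catch (sycl::exception const &exc). A minimal sketch of the exception-style check, with hypothetical names:

#include <sycl/sycl.hpp>
#include <iostream>

// Hypothetical example: a checked memset, the exception-based analogue of
// CUDA_CHECK_RETURN(cudaMemset(...)).
void zero_floats(sycl::queue &q, float *dev_ptr, size_t n) {
    try {
        q.memset(dev_ptr, 0, n * sizeof(float)).wait();
    } catch (sycl::exception const &e) {
        std::cerr << "SYCL error: " << e.what() << std::endl;
    }
}
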
- */ - //CUDA_CHECK_RETURN(0); - //back memcpy - q_ct1.memcpy((void*)(g), (void*)(buff_g), size); + } - - - //========================GEMM============================ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) From 693ca7979dacf4b3ddf0274768cda173cd91b99e Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 14 May 2024 03:44:55 -0700 Subject: [PATCH 39/66] refine estimate quantiles and k dequantize & headers --- CMakeLists.txt | 4 +- csrc/sycl/kernels.cpp | 114 +++++++++++++++++------------------------- csrc/sycl/kernels.h | 64 ++++++++++++------------ csrc/sycl/ops.cpp | 22 ++++---- 4 files changed, 92 insertions(+), 112 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5332a5536..b8eec42b7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,7 @@ set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) -set(SYCL_FILES csrc/sycl/ops.cpp csrc/sycl/kernels.cpp) +set(SYCL_FILES csrc/sycl/ops.cpp csrc/sycl/kernels.cpp) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) @@ -199,7 +199,7 @@ elseif(BUILD_SYCL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ferror-limit=590") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ferror-limit=80") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") list(APPEND SRC_FILES ${SYCL_FILES}) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index f71a74bf5..d4af83b12 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -518,14 +518,6 @@ SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *in } } -void warpreduceKernelMax(int* data, const sycl::nd_item<3> &item_ct1) { - int threadid = item_ct1.get_local_id(2); - int input = data[threadid]; - int output = 0; - output = sycl::reduce_over_group(item_ct1.get_sub_group(), input, sycl::maximum<>()); - data[threadid] = output; -} - template void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n, const sycl::nd_item<3> &item_ct1, int *smem_max_indices, @@ -533,10 +525,6 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons { - //typename WarpReduce::TempStorage temp_storage; - //typedef cub::BlockLoad LoadT; - //typename LoadT::TempStorage loadt; - const int warp_idx = item_ct1.get_local_id(2)/32; const int valid_items = n - (item_ct1.get_group(2)*BLOCK_SIZE) > BLOCK_SIZE ? 
BLOCK_SIZE : n - (item_ct1.get_group(2)*BLOCK_SIZE); @@ -553,14 +541,12 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons sycl::buffer buff_values(smem_max_values, sycl::range<1>(8*BLOCK_SIZE/32)); sycl::buffer buff_A(A,sycl::range<1>(8*BLOCK_SIZE/32)); - dpct::get_in_order_queue().submit([&](sycl::handler &h) { - + dpct::get_in_order_queue().submit([&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load; + using group_load = dpct::group::workgroup_load<8, dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, int, int *, sycl::nd_item<3>>; size_t temp_storage_size = group_load::get_local_memory_size(8*BLOCK_SIZE/32); - sycl::local_accessor tacc( - temp_storage_size, h); - sycl::accessor dacc(buff_A[(item_ct1.get_local_id(2)*BLOCK_SIZE)], h, sycl::read_write); + sycl::local_accessor tacc(temp_storage_size, h); + sycl::accessor dacc_A(buff_A[(item_ct1.get_local_id(2)*BLOCK_SIZE)], cgh, sycl::read_write); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -568,11 +554,11 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - h.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item) { - auto *d = dacc.get_multi_ptr().get(); + auto *d_A = dacc_A.get_multi_ptr().get(); auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item_ct1,item_ct1.get_local_linear_id(), d, values); + group_load(tmp).load(item_ct1, d_A, values); }); @@ -643,12 +629,12 @@ typedef sycl::accessor sycl_dacc; typedef sycl::accessor sycl_dacc_float; typedef sycl::accessor sycl_dacc_uc; -//=========================================================== +//======================estimte quantiles===================================== template SYCL_EXTERNAL void kEstimateQuantiles(const T *A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la tacc, sycl_dacc dacc) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_A) { const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; @@ -659,7 +645,7 @@ void kEstimateQuantiles(const T *A, float *code, const float offset, const T max using group_radix_sort = dpct::group::radix_sort; T vals[NUM_ESTIMATE]; - auto *d = dacc.get_multi_ptr().get(); + auto *d_A = dacc_A.template get_multi_ptr().get(); int smem_qidx[BLOCK_ESTIMATE]; @@ -685,7 +671,7 @@ void kEstimateQuantiles(const T *A, float *code, const float offset, const T max // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index auto *tmp = tacc.get_multi_ptr().get(); - group_load(tmp).load(item_ct1, d, vals); + group_load(tmp).load(item_ct1, d_A, vals); #pragma unroll 4 for(int j = 0; j < NUM_ESTIMATE; j++) @@ -698,8 +684,6 @@ void kEstimateQuantiles(const T *A, float *code, const float offset, const T max // striped pattern index for thread 0 [0, 1024, 2048, 3096] // striped pattern index for thread 1 [1, 1025, 2049, 3097] - //BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); - // 1. load 8 values per thread // 2. 
compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back @@ -947,7 +931,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out ) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out ) { const int n_load = (item_ct1.get_group_range(2) * TILE_SIZE); @@ -4815,15 +4799,15 @@ template unsigned char dQuantize<0>(float* smem_code, const float rand, float x) template unsigned char dQuantize<1>(float* smem_code, const float rand, float x); template SYCL_EXTERNAL void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A); template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A); #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1); \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1); \ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) @@ -4837,7 +4821,7 @@ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_T stacc_T, sycl_la_float stacc_float1); \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1); \ SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, float) @@ -4853,7 +4837,7 @@ SYCL_EXTERNAL MAKE_Optimizer32bit1State(ADAGRAD, float) template void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ float* state1, float* state2, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,sycl_la_T ltacc_T, 
sycl_la_float ltacc_float1, sycl_la_float ltacc_float2); \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g); \ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, float) SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half) @@ -4861,16 +4845,13 @@ SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::ext::oneapi::bfl template SYCL_EXTERNAL void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2 - sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::half* g, sycl::half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2 - sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::ext::oneapi::bfloat16* g, sycl::ext::oneapi::bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2 - sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ @@ -4882,7 +4863,7 @@ template void kPreconditionOptimizerStatic8bit1State(gtype* p, gty float* max1, float* new_max1, \ const float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,sycl_la_T ltacc_T, sycl_la_float ltacc_float1); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1); \ SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(MOMENTUM, float) @@ 
-4901,8 +4882,9 @@ template void kOptimizerStatic8bit1State(gtype* p, gtype* const g, float* max1, float* new_max1, \ float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,sycl_la_T ltacc_T, sycl_la_T ltacc_T1,sycl_la_float ltacc_float1, sycl_la_T stacc_T, \ - sycl_la_float stacc_float1); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, const sycl_la &tacc, \ + const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, \ + const sycl_dacc_uc &dacc_state1); \ SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half) SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, float) @@ -4919,7 +4901,7 @@ template void kPreconditionOptimizerStatic8bit2State(gtype* p, gty float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2,sycl_la_T ltacc_T, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, sycl::half) SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, float) @@ -4933,18 +4915,18 @@ template void kOptimizerStatic8bit2State(gtype* p, gtype* const g, float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2,sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, sycl::half) SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, float) template SYCL_EXTERNAL void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc); template SYCL_EXTERNAL void kPercentileClipping(sycl::half * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc); #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ -template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,sycl_la_T ltacc_T, sycl_la_float ltacc_float,sycl_la_unsigned_char stacc); \ +template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out); \ SYCL_EXTERNAL 
MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit) SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit) @@ -5014,24 +4996,16 @@ SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, N SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, NF4) SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, NF4) -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, 
const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ template void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ @@ -5040,7 +5014,10 @@ template void kOptimizerStatic8bit2StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2,sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, sycl_la_T stacc_T, sycl_la_float1 stacc_float1, sycl_la_float stacc_float2); \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2,const sycl_la &tacc, \ + const sycl::accessor &dacc_g, \ + const sycl::accessor &dacc_p, \ + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8) @@ -5055,7 +5032,10 @@ template void kOptimizerStatic8bit1StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1,sycl_la_T ltacc_T, sycl_la_T ltacc_T1,sycl_la_float ltacc_float1,sycl_la_T stacc_T, sycl_la_float stacc_float1); \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1,const sycl_la &tacc, \ + const sycl::accessor &dacc_g, \ + const sycl::accessor &dacc_p, \ + const sycl_dacc_uc &dacc_state1); \ SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 8060eaec3..0c343bc63 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -13,20 +13,21 @@ #pragma once -typedef sycl::local_accessor sycl_la_float; -typedef sycl::local_accessor sycl_la_T; -typedef sycl::local_accessor sycl_la_unsigned_char; -typedef sycl::local_accessor sycl_la_half; -typedef sycl::local_accessor sycl_la_unsigned; -typedef sycl::local_accessor sycl_la_char; 
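
The consolidated aliases added just below have lost their template arguments in this dump. For orientation, a hypothetical reconstruction consistent with how kernels.cpp uses them (tacc hands raw byte scratch to the dpct::group load/store helpers, and the dacc_* aliases are read-write device views of typed buffers); the exact argument lists are an assumption, not recovered from the patch:

#include <sycl/sycl.hpp>
#include <cstdint>

typedef sycl::local_accessor<uint8_t, 1> sycl_la;  // byte scratch for group load/store
typedef sycl::accessor<int, 1, sycl::access_mode::read_write> sycl_dacc;
typedef sycl::accessor<float, 1, sycl::access_mode::read_write> sycl_dacc_float;
typedef sycl::accessor<unsigned char, 1, sycl::access_mode::read_write> sycl_dacc_uc;
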
+//================typedefs=================================== + +typedef sycl::local_accessor sycl_la; +typedef sycl::accessor sycl_dacc; +typedef sycl::accessor sycl_dacc_float; +typedef sycl::accessor sycl_dacc_uc; //template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl::accessor &dacc_A); extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, - const sycl::nd_item<3> &item_ct1, float* smem_code, sycl_la_float ltacc, sycl_la_unsigned_char stacc); + const sycl::nd_item<3> &item_ct1, float* smem_code, const sycl_la &tacc, const sycl_dacc_float &dacc_A, + const sycl_dacc_uc &dacc_out); extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code); @@ -34,38 +35,38 @@ template &item_ct1, float *smem_code, float *smem_absmax_value, - sycl_la_T ltacc_T, sycl_la_float ltacc_float,sycl_la_unsigned_char stacc); -template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1, sycl_la_unsigned_char ltacc, sycl_la_T stacc); + const sycl_la &tacc,const sycl::accessor &dacc_A, + const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out); +template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out); template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g); template extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, - sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_float ltacc_float1); + const 
sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1); template extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, sycl_la_float ltacc_float1, - sycl_la_T stacc_T, sycl_la_float stacc_float1); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const + sycl_dacc_float &dacc_state1); template extern SYCL_EXTERNAL void @@ -79,7 +80,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, - sycl_la_T ltacc_T, sycl_la_float ltacc_float1); + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1); template @@ -92,9 +93,9 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, - sycl_la_T ltacc_T, sycl_la_T ltacc_T1, - sycl_la_float ltacc_float1, sycl_la_T stacc_T, - sycl_la_float stacc_float1); + const sycl_la &tacc, + const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state11); @@ -109,8 +110,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, - sycl_la_T ltacc_T, - sycl_la_float ltacc_float1, sycl_la_float ltacc_float2); + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); template @@ -123,9 +123,8 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, - float *smem_quantiles2, sycl_la_T ltacc_T, sycl_la_T ltacc_T1, - sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, - sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); + float *smem_quantiles2,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, @@ -136,9 +135,9 @@ template extern SYCL_EX sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2, - sycl_la_T ltacc_T, sycl_la_T ltacc_T1, - sycl_la_float ltacc_float1, sycl_la_float ltacc_float2, - sycl_la_T stacc_T, sycl_la_float stacc_float1, sycl_la_float stacc_float2); + const sycl_la &tacc, const sycl::accessor &dacc_g, + const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, @@ -151,12 +150,13 @@ template extern SYCL_EX const sycl::nd_item<3> &item_ct1, sycl::local_accessor 
smem_quantiles1, float *smem_exchange1, - sycl_la_T ltacc_T, sycl_la_T ltacc_T1, - sycl_la_float ltacc_float1,sycl_la_T stacc_T, - sycl_la_float stacc_float1); + const sycl_la &tacc, + const sycl::accessor &dacc_g, + const sycl::accessor &dacc_p, + const sycl_dacc_uc &dacc_state1); -template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1, sycl_la_T ltacc_T); +template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g); extern SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 19e3029c6..516cfd40a 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -29,6 +29,7 @@ #define NUM_ESTIMATE 8 #define BLOCK_ESTIMATE 4096 #define NUM_PER_THREAD 4 + using namespace dnnl; typedef sycl::ext::oneapi::bfloat16 bf16; @@ -42,18 +43,13 @@ void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *sr int threads = 512; int num_blocks = n/threads; num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; - /* - DPCT1049:53: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + dpct::get_in_order_queue().parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { kHistogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n, item_ct1); }); - /* - DPCT1010:229: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
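
Every launch in ops.cpp keeps the shape shown here: the CUDA-style <<<num_blocks, threads>>> pair becomes an nd_range<3> whose global size is the elementwise product of a group count and a work-group size, with get_group(2) and get_local_id(2) standing in for blockIdx.x and threadIdx.x. A minimal sketch of the correspondence; the kernel body and names are hypothetical:

#include <sycl/sycl.hpp>

// data is assumed to be a USM device allocation of num_blocks*threads floats.
void launch_1d(sycl::queue &q, float *data, int num_blocks, int threads) {
    q.parallel_for(
        sycl::nd_range<3>(
            sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads),
            sycl::range<3>(1, 1, threads)),
        [=](sycl::nd_item<3> item_ct1) {
            int i = item_ct1.get_group(2) * threads    // blockIdx.x * blockDim.x
                  + item_ct1.get_local_id(2);          // + threadIdx.x
            data[i] = 0.0f;                            // hypothetical kernel body
        });
}
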
- */ - //CUDA_CHECK_RETURN(0); + } @@ -77,10 +73,9 @@ template void estimateQuantiles(T *A, float *code, float offset, in [&](sycl::handler &cgh) { using group_load = dpct::group::workgroup_load>; - using group_radix_sort = dpct::group::radix_sort; size_t temp_storage_size = group_radix_sort::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(sycl::range<1>(temp_storage_size), cgh); - sycl::accessor dacc(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); auto std_numeric_limits_T_max_ct3 = std::numeric_limits::max(); @@ -88,7 +83,7 @@ template void estimateQuantiles(T *A, float *code, float offset, in cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kEstimateQuantiles(A, code, offset, std_numeric_limits_T_max_ct3, n, item_ct1, tacc, dacc); + kEstimateQuantiles(A, code, offset, std_numeric_limits_T_max_ct3, n, item_ct1, tacc, dacc_A); }); }); @@ -438,7 +433,7 @@ template void dequantizeBlockwise(float *code, unsign cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { - kDequantizeBlockwise(code, buff_A, absmax, buff_out, blocksize, n, item_ct1, tacc, dacc_A, dacc_out); + kDequantizeBlockwise(code, A, absmax, out, blocksize, n, item_ct1, tacc, dacc_A, dacc_out); }); }); } @@ -655,6 +650,11 @@ catch (sycl::exception const &exc) { //============================8 bit optimizer=============================== +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + + template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, float *unorm, float max_unorm, float param_norm, From 3a41f965ce09b4a87af46cdd52dec826b01a7a4c Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 14 May 2024 22:43:27 -0700 Subject: [PATCH 40/66] fix errors in k compress --- csrc/sycl/kernels.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index d4af83b12..3ed1d4a17 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -545,7 +545,7 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons using group_load = dpct::group::workgroup_load<8, dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, int, int *, sycl::nd_item<3>>; size_t temp_storage_size = group_load::get_local_memory_size(8*BLOCK_SIZE/32); - sycl::local_accessor tacc(temp_storage_size, h); + sycl::local_accessor tacc(temp_storage_size, cgh); sycl::accessor dacc_A(buff_A[(item_ct1.get_local_id(2)*BLOCK_SIZE)], cgh, sycl::read_write); // 1. load 8 values per thread @@ -556,7 +556,7 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons // 6. store with byte index cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item) { - auto *d_A = dacc_A.get_multi_ptr().get(); + auto *d_A = dacc_A.template get_multi_ptr().get(); auto *tmp = tacc.get_multi_ptr().get(); group_load(tmp).load(item_ct1, d_A, values); @@ -589,7 +589,7 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons { // 3. 
do warp reduction + broadcast back - output = sycl::reduce_over_group(item_ct1.get_sub_group(), max1, sycl::maximum<>()); + auto output = sycl::reduce_over_group(item_ct1.get_sub_group(), max1, sycl::maximum<>()); warp_max = item_ct1.get_sub_group().shuffle(warp_max, 0); // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest @@ -597,7 +597,7 @@ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, cons { hacc_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1; - hacc_indices_indices[warp_idx*8 + i] = max_idx1; + hacc_indices[warp_idx*8 + i] = max_idx1; sign1 = sign2; max1 = max2; @@ -943,9 +943,9 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; - using group_load_uc = dpct::group::workgroup_load>; + using group_load_uc = dpct::group::workgroup_load>; - using group_store = dpct::group::workgroup_store>; + using group_store = dpct::group::workgroup_store>; auto *d_A = dacc_A.template get_multi_ptr().get(); From 37897ee21e4bd675a1674a90563a6605a4634a17 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 15 May 2024 00:35:09 -0700 Subject: [PATCH 41/66] refine template types for k quants --- csrc/sycl/kernels.cpp | 323 ++++++++++++++++++++++-------------------- 1 file changed, 168 insertions(+), 155 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 3ed1d4a17..84dbac311 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -812,7 +812,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl using group_load = dpct::group::workgroup_load>; using group_load_float = dpct::group::workgroup_load>; - using group_store_uc = dpct::group::workgroup_store>; + using group_store_uc = dpct::group::workgroup_store<(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, dpct::group::store_algorithm::BLOCK_STORE_DIRECT, unsigned char, unsigned char *, sycl::nd_item<3>>; auto *d_A = dacc_A.template get_multi_ptr().get(); auto *d_rand = dacc_rand.get_multi_ptr().get(); @@ -945,7 +945,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * using group_load_uc = dpct::group::workgroup_load>; - using group_store = dpct::group::workgroup_store>; + using group_store = dpct::group::workgroup_store 0) ? 
2 : 1), dpct::group::store_algorithm::BLOCK_STORE_DIRECT, T, T *, sycl::nd_item<3>>; auto *d_A = dacc_A.template get_multi_ptr().get(); @@ -4795,66 +4795,73 @@ template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ co float *smem_row_stats, unsigned int *smem_nnz_row_idx); + +//==================supported template decls======================================================= + + template unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template unsigned char dQuantize<1>(float* smem_code, const float rand, float x); -template SYCL_EXTERNAL void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, +template SYCL_EXTERNAL void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A); -template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, +template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A); #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ -template void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ +template SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1); \ - -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(RMSPROP, sycl::half) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(LION, sycl::half) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(LION, float) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(LION, sycl::ext::oneapi::bfloat16) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(ADAGRAD, sycl::half) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1); \ + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, sycl::half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(LION, sycl::half) +MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, sycl::ext::oneapi::bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, sycl::half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ -template void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float 
gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1); \ - -SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) -SYCL_EXTERNAL MAKE_Optimizer32bit1State(MOMENTUM, float) -SYCL_EXTERNAL MAKE_Optimizer32bit1State(RMSPROP, sycl::half) -SYCL_EXTERNAL MAKE_Optimizer32bit1State(RMSPROP, float) -SYCL_EXTERNAL MAKE_Optimizer32bit1State(LION, sycl::half) -SYCL_EXTERNAL MAKE_Optimizer32bit1State(LION, float) -SYCL_EXTERNAL MAKE_Optimizer32bit1State(LION, sycl::ext::oneapi::bfloat16) -SYCL_EXTERNAL MAKE_Optimizer32bit1State(ADAGRAD, sycl::half) -SYCL_EXTERNAL MAKE_Optimizer32bit1State(ADAGRAD, float) +template SYCL_EXTERNAL void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1); \ + +MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) +MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(RMSPROP, sycl::half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(LION, sycl::half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, sycl::ext::oneapi::bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, sycl::half) +MAKE_Optimizer32bit1State(ADAGRAD, float) #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ -template void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ +template SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ float* state1, float* state2, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g); \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g); \ + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half) +MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::ext::oneapi::bfloat16) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, float) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half) -SYCL_EXTERNAL MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::ext::oneapi::bfloat16) template SYCL_EXTERNAL void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, const sycl_dacc_float &dacc_state1, const sycl_dacc_float 
&dacc_state2); + template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::half* g, sycl::half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); + template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::ext::oneapi::bfloat16* g, sycl::ext::oneapi::bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ -template void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ +template SYCL_EXTERNAL void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ float *unorm, \ const float beta1, \ const float beta2, \ @@ -4863,14 +4870,14 @@ template void kPreconditionOptimizerStatic8bit1State(gtype* p, gty float* max1, float* new_max1, \ const float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1); \ -SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(MOMENTUM, sycl::half) -SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(MOMENTUM, float) -SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(RMSPROP, sycl::half) -SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(RMSPROP, float) -SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(LION, sycl::half) -SYCL_EXTERNAL MAKE_PreconditionStatic8bit1State(LION, float) +MAKE_PreconditionStatic8bit1State(MOMENTUM, sycl::half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, sycl::half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, sycl::half) +MAKE_PreconditionStatic8bit1State(LION, float) #define MAKE_optimizerStatic8bit1State(oname, gtype) \ template void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ @@ -4883,15 +4890,15 @@ template void kOptimizerStatic8bit1State(gtype* p, gtype* const g, float weight_decay, \ const float gnorm_scale, \ const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, const sycl_la &tacc, \ - const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, \ + const 
sycl::accessor &dacc_g, const sycl::accessor &dacc_p, \ const sycl_dacc_uc &dacc_state1); \ -SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half) -SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(MOMENTUM, float) -SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(RMSPROP, sycl::half) -SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(RMSPROP, float) -SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(LION, sycl::half) -SYCL_EXTERNAL MAKE_optimizerStatic8bit1State(LION, float) +MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half) +MAKE_optimizerStatic8bit1State(MOMENTUM, float) +MAKE_optimizerStatic8bit1State(RMSPROP, sycl::half) +MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, sycl::half) +MAKE_optimizerStatic8bit1State(LION, float) #define MAKE_PreconditionStatic8bit2State(oname, gtype) \ template void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ @@ -4901,10 +4908,10 @@ template void kPreconditionOptimizerStatic8bit2State(gtype* p, gty float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ -SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, sycl::half) -SYCL_EXTERNAL MAKE_PreconditionStatic8bit2State(ADAM, float) +MAKE_PreconditionStatic8bit2State(ADAM, sycl::half) +MAKE_PreconditionStatic8bit2State(ADAM, float) #define MAKE_optimizerStatic8bit2State(oname, gtype) \ template void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ @@ -4915,97 +4922,103 @@ template void kOptimizerStatic8bit2State(gtype* p, gtype* const g, float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ -SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, sycl::half) -SYCL_EXTERNAL MAKE_optimizerStatic8bit2State(ADAM, float) +MAKE_optimizerStatic8bit2State(ADAM, sycl::half) +MAKE_optimizerStatic8bit2State(ADAM, float) template SYCL_EXTERNAL void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1, const sycl_la &tacc); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g); template SYCL_EXTERNAL void kPercentileClipping(sycl::half * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1, const sycl_la &tacc); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g); + #define MAKE_kQuantizeBlockwise(dtype, 
blocksize, num_per_thread, stochastic, data_type_name) \ -template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out); \ - -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) - -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, General8bit) -SYCL_EXTERNAL 
MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 1, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, General8bit) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, FP4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, NF4) -SYCL_EXTERNAL MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, NF4) - -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); - -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, sycl_dacc_uc &dacc_A, 
sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, sycl_dacc_uc &dacc_A, sycl::accessor &dacc_out); +template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out); + +MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::half, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, General8bit) 
+MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, NF4) + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); + + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void 
kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ template void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ @@ -5015,13 +5028,13 @@ template void kOptimizerStatic8bit2StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2,const sycl_la &tacc, \ - const sycl::accessor &dacc_g, \ - const sycl::accessor &dacc_p, \ + const sycl::accessor &dacc_g, \ + const sycl::accessor &dacc_p, \ const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ -SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::ext::oneapi::bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::ext::oneapi::bfloat16, 2048, 8) #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ @@ -5033,17 +5046,17 @@ template void kOptimizerStatic8bit1StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1,const sycl_la &tacc, \ - const sycl::accessor &dacc_g, \ - const sycl::accessor &dacc_p, \ + const sycl::accessor &dacc_g, \ + const sycl::accessor &dacc_p, \ const sycl_dacc_uc &dacc_state1); \ -SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, sycl::half, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::half, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::ext::oneapi::bfloat16, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) -SYCL_EXTERNAL MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::ext::oneapi::bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, sycl::half, 2048, 8) From 7ddce7446c701d2fcde9282bacedd7a74bb7f806 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 15 May 2024 00:59:37 -0700 Subject: [PATCH 42/66] remove mma header --- csrc/sycl/kernels.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 84dbac311..deb99ecb0 100644 --- 
a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -10,7 +10,7 @@ #include #include "kernels.h" #include -#include +//#include #include #define FLT_MAX std::numeric_limits::max() @@ -634,7 +634,7 @@ typedef sycl::accessor sycl_dacc_uc; template SYCL_EXTERNAL void kEstimateQuantiles(const T *A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_A) + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A) { const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; @@ -931,7 +931,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out ) + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out ) { const int n_load = (item_ct1.get_group_range(2) * TILE_SIZE); From c039aa1472883e5f4c8307b8a8d874e69f992a0d Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 15 May 2024 04:06:33 -0700 Subject: [PATCH 43/66] refine histogram --- csrc/sycl/kernels.cpp | 4 +++- csrc/sycl/ops.cpp | 15 +++++++++++++-- csrc/sycl/ops.h | 5 +++-- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index deb99ecb0..0c5273b80 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -504,7 +504,7 @@ __dpct_inline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict_ return pivot; } } - +//=====================================histogram 2d==================== SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, const sycl::nd_item<3> &item_ct1) { @@ -518,6 +518,8 @@ SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *in } } +//===========================k compress max========================== + template void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n, const sycl::nd_item<3> &item_ct1, int *smem_max_indices, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 516cfd40a..a5faf3bd3 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -38,21 +38,32 @@ using namespace BinSearch; using std::cout; using std::endl; +//================================histogram 2d============================================== + void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) { + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + int threads = 512; int num_blocks = n/threads; num_blocks = n % threads == 0 ?
num_blocks : num_blocks + 1; - dpct::get_in_order_queue().parallel_for( + { + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { kHistogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n, item_ct1); + }); }); + } } - //============================estimate quantiles=============================== template void estimateQuantiles(T *A, float *code, float offset, int n) { diff --git a/csrc/sycl/ops.h b/csrc/sycl/ops.h index 84850263b..3071b8456 100644 --- a/csrc/sycl/ops.h +++ b/csrc/sycl/ops.h @@ -3,7 +3,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. - +#pragma once #ifndef ops_H #define ops_H @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -164,7 +165,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); -extern SYCL_EXTERNAL void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, From b9e9a9c22cff90d6d0c8b2772731935cf7d1b7b6 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 15 May 2024 20:51:08 -0700 Subject: [PATCH 44/66] refine row stats --- csrc/sycl/kernels.cpp | 78 +++++++++++++++---------------------------- csrc/sycl/ops.cpp | 44 +++++++++--------------- 2 files changed, 42 insertions(+), 80 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 0c5273b80..38861ce71 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -2757,7 +2757,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char } } -template void kgetColRowStats(T * __restrict__ buff_A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, sycl_la_half ltacc_half, sycl_la_unsigned exacc) +//==========================k get row col stats========================================== + +template void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A) { // 0. reset stats to -FLT_MAX // 1. 
load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) @@ -2774,20 +2776,14 @@ template LoadT; - typedef sycl::group<3> BlockRowReduce; - typedef sycl::group<3> BlockRowSum; - typedef cub::BlockExchange BlockExchange; - - union type_ct12{ - typename BlockExchange::TempStorage exchange; - typename BlockRowReduce::TempStorage rowreduce; - typename BlockRowSum::TempStorage rowsum; - typename LoadT::TempStorage loadt; - }; - type_ct12 &temp_storage = *(type_ct12 *)temp_storage_ct1; - */ + + using group_load = dpct::group::workgroup_load>; + using group_exchange = dpct::group::exchange; + + + auto *d_A = dacc_A.template get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + sycl::half local_data[ITEMS_PER_THREAD]; float local_data_fp32[ITEMS_PER_THREAD]; @@ -2812,9 +2808,6 @@ template items_per_load ? items_per_load : cols - base_col; @@ -2827,14 +2820,9 @@ template(0.0f).convert()[0]); + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -2842,8 +2830,7 @@ template().get(); - group_load(tmp).load(item, &buff_A[0], local_data); + group_load(tmp).load(item_ct1, d_A, local_data); #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) @@ -2876,17 +2863,13 @@ template()); + row_absmax = (float)sycl::reduce_over_group(item_ct1.get_group(), local_data_fp32[0], sycl::maximum<>()); if(SPARSE_DECOMP) { - /* - DPCT1065:214: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); local_row_nnz_count = sycl::reduce_over_group(item_ct1.get_group(), local_row_nnz_count, sycl::plus<>()); } @@ -2900,9 +2883,7 @@ template().get(); - group_exchange(tmp).blocked_to_striped(item, local_col_absmax_values); + group_exchange(tmp).blocked_to_striped(item_ct1, local_col_absmax_values); #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) @@ -2951,12 +2925,9 @@ template(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, - const sycl::nd_item<3> &item_ct1, - sycl_la_half ltacc_half, sycl_la_unsigned exacc); -template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, - const sycl::nd_item<3> &item_ct1, - sycl_la_half ltacc_half, sycl_la_unsigned exacc); + +//========================================k dequant mm int32fp16=================== + #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) @@ -4801,6 +4772,9 @@ template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ co //==================supported template decls======================================================= +template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); +template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, 
const sycl::accessor &dacc_A); + template unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template unsigned char dQuantize<1>(float* smem_code, const float rand, float x); diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index a5faf3bd3..7954e8516 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1267,10 +1267,8 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c int num_blocks = row_tiles * col_tiles; int size = NUM_BLOCK; - sycl::half *buff_A; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); + sycl::buffer buff_A(A,sycl::range<1>(size)); if(nnz_threshold == 0.0) { @@ -1278,14 +1276,11 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_half = dpct::group::workgroup_load; - using group_exchange = dpct::group::exchange; - size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); - size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor exacc(exchange_temp_storage_size, cgh); - sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_row_absmax_values_acc_ct1(sycl::range<1>(256), cgh); @@ -1295,8 +1290,8 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kgetColRowStats(buff_A, rowStats, colStats, - nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), ltacc_half, exacc); + kgetColRowStats(A, rowStats, colStats, + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), tacc, dacc_A); }); }); } @@ -1305,14 +1300,11 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_half = dpct::group::workgroup_load; - using group_exchange = dpct::group::exchange; - size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); - size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor exacc(exchange_temp_storage_size, cgh); - sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_row_absmax_values_acc_ct1(sycl::range<1>(256), cgh); @@ -1322,21 +1314,17 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * 
sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kgetColRowStats(buff_A, rowStats, colStats, - nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), ltacc_half, exacc); + kgetColRowStats(A, rowStats, colStats, + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), tacc, dacc_A); }); }); } - - /* - DPCT1010:284: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - } + +//===================================double row col quant====================== + void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int *nnz_block_ptr, float threshold, int rows, int cols) { dpct::device_ext &dev_ct1 = dpct::get_current_device(); From c8473b53057011c54f4ed8b20176525eb9cafbac Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 16 May 2024 00:25:24 -0700 Subject: [PATCH 45/66] refine double row col quants --- csrc/sycl/kernels.cpp | 65 +++++++++++++++---------------------------- csrc/sycl/ops.cpp | 57 +++++++++++++------------------------ 2 files changed, 43 insertions(+), 79 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 38861ce71..126cc3f5c 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -630,6 +630,7 @@ typedef sycl::local_accessor sycl_la; typedef sycl::accessor sycl_dacc; typedef sycl::accessor sycl_dacc_float; typedef sycl::accessor sycl_dacc_uc; +typedef sycl::accessor sycl_dacc_char; //======================estimte quantiles===================================== @@ -3079,8 +3080,9 @@ template void kdequant_mm_i } } +//=====================================k double row col quant============================ -template void kDoubleRowColQuant(sycl::half *__restrict__ const buff_A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *buff_out_col_normed, char *buff_out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, sycl_la_half ltacc_half, sycl_la_char stacc_char1, sycl_la_char stacc_char2 ) +template void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed) { // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD // Each thread reads the same column but multiple rows @@ -3100,12 +3102,14 @@ template LoadHalf; + using group_load_half = dpct::group::workgroup_load>; + using group_store_char = dpct::group::workgroup_store>; - //typedef cub::BlockStore StoreInt8; + auto *d_A = dacc_A.get_multi_ptr().get(); + auto *d_out_col_normed = dacc_out_col_normed.get_multi_ptr().get(); + 
auto *d_out_row_normed = dacc_out_row_normed.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); - sycl::half local_data[ITEMS_PER_THREAD]; float local_col_stats[ITEMS_PER_THREAD]; char local_quantized_data[ITEMS_PER_THREAD]; @@ -3115,7 +3119,7 @@ template items_per_load ? items_per_load : cols - base_col; - - /* - DPCT1007:218: Migration of cub::BlockLoad::Load is not supported. - */ - //LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); - // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_half.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_A[0], local_data); + group_load_half(tmp).load(item_ct1, d_A, local_data); float row_stat = 127.0f / smem_row_stats[row]; @@ -3184,19 +3179,13 @@ template (local_data[j]).convert()[0]*row_stat)); } - /* - DPCT1007:219: Migration of cub::BlockStore::Store is not supported. - */ - //StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); - // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = stacc_char1.get_multi_ptr().get(); - group_store(tmp).store(item, &buff_out_row_normed[0], local_quantized_data); + group_store_char(tmp).store(item_ct1, d_out_row_normed, local_quantized_data); // 2. quantize data with row/col stats #pragma unroll ITEMS_PER_THREAD @@ -3207,13 +3196,9 @@ template (local_data[j]).convert()[0]*local_col_stats[j])); } - /* - DPCT1065:217: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - /* - DPCT1007:220: Migration of cub::BlockStore::Store is not supported. - */ + //StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) @@ -3221,13 +3206,13 @@ template ().get(); - group_store(tmp).store(item, &buff_out_col_normed[0], local_quantized_data); + group_store_char(tmp).store(item_ct1, d_out_col_normed, local_quantized_data); } } + +//================================================================================================= /* DPCT1110:14: The total declared local variable size in device function kTransformRowToFormat exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
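//============================editor's note: the dpct group load/store idiom==============================
// The kernel rewrites in patches 40-46 repeat one migration pattern: size the block-load/store
// scratch on the host, pass it to the kernel as a single uint8_t local_accessor, and recover raw
// pointers from accessors through get_multi_ptr -- prefixed with the `template` keyword wherever
// the accessor's type is dependent, as in dacc_A.template get_multi_ptr<...>() above. A condensed
// sketch; the template arguments here are illustrative, as the originals were lost in extraction:
//
//   using group_load = dpct::group::workgroup_load<NUM_PER_TH,
//       dpct::group::load_algorithm::BLOCK_LOAD_DIRECT,
//       sycl::half, const sycl::half *, sycl::nd_item<3>>;
//   size_t bytes = group_load::get_local_memory_size(THREADS);          // host side
//   sycl::local_accessor<uint8_t, 1> tacc(sycl::range<1>(bytes), cgh);  // one shared scratch buffer
//   ...
//   auto *tmp = tacc.get_multi_ptr().get();                             // kernel side
//   group_load(tmp).load(item_ct1, d_A, local_data);                    // blocked per-thread load
//
// The (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH arithmetic introduced in patch 41 follows from
// FP4/NF4 packing two 4-bit codes per byte: those stores move half as many bytes per thread as
// the General8bit path.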
*/ @@ -4759,17 +4744,13 @@ template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, flo const sycl::nd_item<3> &item_ct1, float *smem_rowStats); -template void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, - const sycl::nd_item<3> &item_ct1, - float *smem_row_stats, - unsigned int *smem_nnz_row_idx); -template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, - const sycl::nd_item<3> &item_ct1, - float *smem_row_stats, - unsigned int *smem_nnz_row_idx); +//==================supported template decls======================================================= + +template void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed); + +template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed); -//==================supported template decls======================================================= template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 7954e8516..2b2c1899c 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1332,14 +1332,10 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::context ctx = q_ct1.get_context(); int num_blocks = 0; int size = NUM_BLOCK; - sycl::half *buff_A; - char *buff_out_row_normed, *buff_out_col_normed; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_out_row_normed) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_out_col_normed) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); - q_ct1.memcpy((void*)(buff_out_row_normed), (void*)(out_row_normed), size); - q_ct1.memcpy((void*)(buff_out_col_normed), (void*)(out_col_normed), size); + + sycl::buffer buff_A(A,sycl::range<1>(size)); + 
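//============================editor's note: buffers instead of staged copies==============================
// This hunk trades explicit malloc_device + memcpy staging for sycl::buffer objects constructed
// directly over the caller's pointers. The runtime then owns both copy directions -- a buffer
// writes its contents back to the host pointer when destroyed -- which is why the trailing
// q_ct1.memcpy calls disappear below. A minimal before/after sketch (names assumed, not from
// the patch):
//
//   // before: explicit staging
//   sycl::half *staged = sycl::malloc_device<sycl::half>(n, q);
//   q.memcpy(staged, A, n * sizeof(sycl::half));
//   // ... launch the kernel on `staged` ...
//   q.memcpy(A, staged, n * sizeof(sycl::half)).wait();
//   sycl::free(staged, q);
//
//   // after: runtime-managed copies
//   {
//     sycl::buffer<sycl::half, 1> buf(A, sycl::range<1>(n));
//     q.submit([&](sycl::handler &cgh) {
//       sycl::accessor acc(buf, cgh, sycl::read_write);
//       // ... cgh.parallel_for(...) reading and writing through `acc` ...
//     });
//   } // buf destroyed here: data written back to A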
sycl::buffer buff_out_col_normed(out_col_normed,sycl::range<1>(size)); + sycl::buffer buff_out_row_normed(out_row_normed,sycl::range<1>(size)); int threads = 64; int items_per_thread = 4; @@ -1359,18 +1355,14 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_half = dpct::group::workgroup_load; - using group_store_char1 = dpct::group::workgroup_store; - using group_store_char2 = dpct::group::workgroup_store; - - size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_char1 = group_store_char1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_char2 = group_store_char2::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); - sycl::local_accessor stacc_char1(store_temp_storage_size_char1, cgh); - sycl::local_accessor stacc_char2(store_temp_storage_size_char2, cgh); + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out_col_normed(buff_out_col_normed, cgh, sycl::read_write); + sycl::accessor dacc_out_row_normed(buff_out_row_normed, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_row_stats_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_nnz_row_idx_acc_ct1(sycl::range<1>(256), cgh); @@ -1379,7 +1371,7 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kDoubleRowColQuant(buff_A, rowStats, colStats, buff_out_col_normed, buff_out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), ltacc_half, stacc_char1, stacc_char2); + kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out_col_normed, dacc_out_row_normed); }); }); } @@ -1389,17 +1381,14 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load_half = dpct::group::workgroup_load; - using group_store_char1 = dpct::group::workgroup_store; - using group_store_char2 = dpct::group::workgroup_store; - size_t load_temp_storage_size_half = group_load_half::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_char1 = group_store_char1::get_local_memory_size(NUM_BLOCK); - size_t store_temp_storage_size_char2 = group_store_char2::get_local_memory_size(NUM_BLOCK); + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out_col_normed(buff_out_col_normed, cgh, sycl::read_write); + sycl::accessor dacc_out_row_normed(buff_out_row_normed, cgh, sycl::read_write); - 
sycl::local_accessor ltacc_half(load_temp_storage_size_half, cgh); - sycl::local_accessor stacc_char1(store_temp_storage_size_char1, cgh); - sycl::local_accessor stacc_char2(store_temp_storage_size_char2, cgh); //__shared__ vars sycl::local_accessor smem_row_stats_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_nnz_row_idx_acc_ct1(sycl::range<1>(256), cgh); @@ -1409,20 +1398,14 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols,item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), ltacc_half, stacc_char1, stacc_char2); + kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols,item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out_col_normed, dacc_out_row_normed); }); }); } - /* - DPCT1010:285: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - q_ct1.memcpy((void*)(out_row_normed), (void*)(buff_out_row_normed), size); - q_ct1.memcpy((void*)(out_col_normed), (void*)(buff_out_col_normed), size); } +//======================================= transform row to format=============================================== template void transformRowToFormat(char * A, char *out, int rows, int cols) { From c85baf9778bfc67e628c06da8200c7e86b63b952 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 16 May 2024 03:14:02 -0700 Subject: [PATCH 46/66] refine helper functions --- csrc/sycl/kernels.cpp | 84 +++--- csrc/sycl/kernels.h | 19 +- csrc/sycl/ops.cpp | 645 +++++++++++++++++++++--------------------- 3 files changed, 372 insertions(+), 376 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 126cc3f5c..568a6506b 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -3212,11 +3212,10 @@ template SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data) +//============================================k transform row format===================================================== + +template SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, +const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out) { // 0. 
Load data into 32*32 shared memory tiles @@ -3267,9 +3266,6 @@ template BlockExchange; - - // we load row after row from the base_position // Load data row by row @@ -3294,7 +3290,7 @@ template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1) +//========================================k extract outliers====================== + +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out) { int local_colidx = idx[item_ct1.get_group(2)]; @@ -3733,10 +3732,10 @@ template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, cha offset += tile_offset_rows + tile_offset_cols; - char val = A[offset]; + char val = dacc_A[offset]; int out_idx = (row*idx_size) + item_ct1.get_group(2); - out[out_idx] = val; + dacc_out[out_idx] = val; } } else if(FORMAT == COL_AMPERE) @@ -3757,12 +3756,11 @@ template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, cha char val = A[offset]; int out_idx = (row*idx_size) + item_ct1.get_group(2); - out[out_idx] = val; + dacc_out[out_idx] = val; } } } - //template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) //{ //// element-wise kernel @@ -4696,12 +4694,7 @@ template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, float *quant_map); - -template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, - const sycl::nd_item<3> &item_ct1); -template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, - const sycl::nd_item<3> &item_ct1); - + template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, sycl::half *smem_dequant_stats); @@ -4721,24 +4714,6 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int const sycl::nd_item<3> &item_ct1, sycl::half *smem_dequant_stats); -template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, - const sycl::nd_item<3> &item_ct1, - char *smem_data); -template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, - const sycl::nd_item<3> &item_ct1, - char *smem_data); -template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, - const sycl::nd_item<3> &item_ct1, - char *smem_data); -template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, - const sycl::nd_item<3> &item_ct1, - char *smem_data); 
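// [Editor's note: illustrative sketch, not part of the patch.] The instantiation lists being
// rewritten in this hunk are needed because these kernel templates are defined in kernels.cpp
// rather than a header: host code can only link against the <TRANSPOSE, FORMAT> combinations
// that are explicitly instantiated here. In general form (kKernel is a hypothetical name):
//
//   template <int FORMAT> SYCL_EXTERNAL void kKernel(char *A, char *out);    // kernels.h
//   template <int FORMAT> SYCL_EXTERNAL void kKernel(char *A, char *out) {}  // kernels.cpp
//   template SYCL_EXTERNAL void kKernel<COL32>(char *A, char *out);          // emitted for linker
//   template SYCL_EXTERNAL void kKernel<COL_TURING>(char *A, char *out);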
-template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, - const sycl::nd_item<3> &item_ct1, - char *smem_data); -template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, - const sycl::nd_item<3> &item_ct1, - char *smem_data); template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, @@ -4747,6 +4722,27 @@ template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, flo //==================supported template decls======================================================= +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + + + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + +template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); + + + + template void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, 
sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed); template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed); diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 0c343bc63..9977f5303 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -19,7 +19,9 @@ typedef sycl::local_accessor sycl_la; typedef sycl::accessor sycl_dacc; typedef sycl::accessor sycl_dacc_float; typedef sycl::accessor sycl_dacc_uc; +typedef sycl::accessor sycl_dacc_char; +//=========================================================== //template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, @@ -180,10 +182,10 @@ extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( const sycl::nd_item<3> &item_ct1, float *smem_rowStats, sycl_la_T ltacc_T, sycl_la_float exacc); template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, - const sycl::nd_item<3> &item_ct1,float *smem_row_absmax_values,int *smem_row_nnz_values, sycl_la_half ltacc_half, sycl_la_unsigned exacc); + const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); + template -extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, + int SPARSE_DECOMP> extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, @@ -192,15 +194,14 @@ extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, - sycl_la_half ltacc_half, sycl_la_char stacc_char1, sycl_la_char stacc_char2); + const sycl_la &tacc, const sycl::accessor &dacc_A, + const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed); -template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, - const sycl::nd_item<3> &item_ct1, - char *smem_data); +template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, +const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); -template extern SYCL_EXTERNAL void kExtractOutliers(char 
*A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, - const sycl::nd_item<3> &item_ct1); +template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B); template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 2b2c1899c..bdd8d2aae 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1241,8 +1241,273 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half */ //CUDA_CHECK_RETURN(0); q_ct1.memcpy((void*)(A), (void*)(buff_A), size); + +} + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int size = NUM_BLOCK; + T *buff_A, *buff_B, *buff_out; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_A), (void*)(A), size); + *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_B), (void*)(B), size); + *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_out), (void*)(out), size); + + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + //if(bits == 32) + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + /* + DPCT1101:312: '8*16 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. + */ + //sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); + /* + DPCT1101:313: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. 
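[Editor's note: worked check, not generated by the tool.] Both replaced constants are consistent with batch_size_warps == 4: 8*16 + (2*16*(4-1)) = 128 + 96 = 224, and 2*4*16*32 + (2*16*(4-1)) = 4096 + 96 = 4192. A hedged alternative to the magic numbers, assuming batch_size_warps is a compile-time constant in scope, would be:
    constexpr int batch_size_warps = 4;                                            // assumption
    constexpr int smem_a = 8*16 + (2*16*(batch_size_warps-1));                     // 224
    constexpr int smem_b = 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)); // 4192
    sycl::local_accessor<sycl::half, 1> smem_A_acc_ct1(sycl::range<1>(smem_a), cgh);
    sycl::local_accessor<sycl::half, 1> smem_B_acc_ct1(sycl::range<1>(smem_b), cgh);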
+ */ + //__shared__ vars + sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); + sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 160), sycl::range<3>(1, 1, 160)), + [=](sycl::nd_item<3> item_ct1) { + gemm_device(m, n, k, buff_A, buff_B, buff_out, lda, ldb, ldc, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer()); + }); + }); + } + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //back memcpy + q_ct1.memcpy((void*)(A), (void*)(buff_A), size); + q_ct1.memcpy((void*)(B), (void*)(buff_B), size); + q_ct1.memcpy((void*)(out), (void*)(buff_out), size); + +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int size = NUM_BLOCK; + T *buff_A, *buff_out; + unsigned char *buff_B; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_A), (void*)(A), size); + *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_B), (void*)(B), size); + *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_out), (void*)(out), size); + + + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + /* + DPCT1101:314: '8*16 + (16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. + */ + + /* + DPCT1101:315: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. 
+ */ + + //__shared__ vars + sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(176/*8*16 + (16*(batch_size_warps-1))*/), cgh); + sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); + sycl::local_accessor smem_C_acc_ct1(sycl::range<1>(8*32), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 96), sycl::range<3>(1, 1, 96)), + [=](sycl::nd_item<3> item_ct1) { + kgemm_4bit_inference(m, n, k, buff_A, buff_B, absmax, buff_out, lda, ldb, ldc, blocksize, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), smem_C_acc_ct1.get_pointer()); + }); + }); + } + //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //back memcpy + q_ct1.memcpy((void*)(A), (void*)(buff_A), size); + q_ct1.memcpy((void*)(B), (void*)(buff_B), size); + q_ct1.memcpy((void*)(out), (void*)(buff_out), size); + +} + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+3)/4; + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + int size = NUM_BLOCK; + T *buff_A, *buff_out; + unsigned char *buff_B; + *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_A), (void*)(A), size); + *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_B), (void*)(B), size); + *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); + q_ct1.memcpy((void*)(buff_out), (void*)(out), size); + + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + sycl::local_accessor quant_map_acc_ct1(sycl::range<1>(16), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + kgemm_4bit_inference_naive(m, n, k, buff_A, buff_B, absmax, datatype, buff_out, lda, ldb, ldc, blocksize, item_ct1, quant_map_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:291: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
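[Editor's note: sketch of the rewrite this note asks for; standard SYCL, not code from this patch, and it assumes the queue was created with an async_handler so asynchronous errors are reportable.]
    try {
        q_ct1.wait_and_throw();                  // surfaces pending asynchronous errors
    } catch (const sycl::exception &e) {
        std::cerr << "SYCL error: " << e.what() << std::endl;
    }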
+ */ + //CUDA_CHECK_RETURN(0); + q_ct1.memcpy((void*)(A), (void*)(buff_A), size); + q_ct1.memcpy((void*)(B), (void*)(buff_B), size); + q_ct1.memcpy((void*)(out), (void*)(buff_out), size); + +} + +//================================spm coo================================== + +void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, sycl::half *B, int ldc, sycl::half* C, bool transposed_B) +{ + + try{ + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + +//#ifdef NO_CUBLASLT +//#else + + + dpct::sparse::sparse_matrix_desc_t descA; + std::shared_ptr descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + /* + DPCT1007:287: Migration of cusparseCreateCoo is not supported. + */ + //CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + // A_rowidx, A_colidx, A_vals, + // dpct::library_data_t::real_int32, + // oneapi::mkl::index_base::zero, dpct::library_data_t::real_half) ); + // Create dense matrix C + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); + descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); + descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); + // allocate an external buffer if needed + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(bufferSize = 0) ); + bufferSize = 0; + //CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1)) ); + dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1); + + // execute SpMM + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float))); + dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? 
oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float); + // destroy matrix/vector descriptors + descA.reset(); + descB.reset(); + descC.reset(); + sycl::free(dBuffer, q_ct1); + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descA.reset()) ); + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB.reset()) ); + //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC.reset()) ); + //CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(sycl::free(dBuffer, q_ct1)) ); +//#endif + } + catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); + } + } +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ + + { + dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + /* + DPCT1101:311: 'SMEM_SIZE' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. + */ + sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nnz_rows) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), + [=](sycl::nd_item<3> item_ct1) { + kspmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB, item_ct1, smem_dequant_stats_acc_ct1.get_pointer()); + }); + }); + } + /* + DPCT1010:289: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. + */ + //CUDA_CHECK_RETURN(0); +} + + +//======================================non gemm 2d quants============================================ //===========================Row col stats================================= @@ -1405,8 +1670,7 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o } } -//======================================= transform row to format=============================================== - +//========================== transform row to format================================ template void transformRowToFormat(char * A, char *out, int rows, int cols) { @@ -1415,11 +1679,9 @@ template void transformRowToFormat(char * A, char *o sycl::context ctx = q_ct1.get_context(); int num_blocks = 0; int size = NUM_BLOCK; - char *buff_A, *buff_out; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); - q_ct1.memcpy((void*)(buff_out), (void*)(out), size); + + sycl::buffer buff_A(A,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); int threads = 256; @@ -1435,150 +1697,52 @@ template void transformRowToFormat(char * A, char *o col_tiles = col_tiles > 0 ? 
col_tiles : 1; num_blocks = row_tiles * col_tiles; - int outCols = fill_up_to_nearest_multiple(cols, 32); - int outRows = fill_up_to_nearest_multiple(rows, 32); - if(FORMAT == COL_TURING) - { - if(TRANSPOSE) - outRows = fill_up_to_nearest_multiple(cols, 8); - else - outRows = fill_up_to_nearest_multiple(rows, 8); - } - else if(FORMAT == COL_AMPERE) - { - if(TRANSPOSE) - outRows = fill_up_to_nearest_multiple(cols, 32); - else - outRows = fill_up_to_nearest_multiple(rows, 32); - } - else - { - if(TRANSPOSE) - { - outCols = fill_up_to_nearest_multiple(rows, 32); - outRows = cols; - } - } - - /* - DPCT1049:69: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - dpct::get_in_order_queue().submit( - [&](sycl::handler &cgh) { - - - - //__shared__ vars - sycl::local_accessor smem_data_acc_ct1(sycl::range<1>(32*33*8), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), - [=](sycl::nd_item<3> item_ct1) { - kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT>(buff_A, buff_out, rows, cols, tiledCols, outRows, outCols, item_ct1, smem_data_acc_ct1.get_pointer()); - }); - }); - /* - DPCT1010:286: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - - //CUDA_CHECK_RETURN(0); - - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - q_ct1.memcpy((void*)(out), (void*)(buff_out), size); - -} - -void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, sycl::half *B, int ldc, sycl::half* C, bool transposed_B) -{ - - try{ - dpct::device_ext &dev_ct1 = dpct::get_current_device(); - sycl::queue &q_ct1 = dev_ct1.in_order_queue(); - -//#ifdef NO_CUBLASLT -//#else - - - dpct::sparse::sparse_matrix_desc_t descA; - std::shared_ptr descB, descC; - - float alpha = 1.0f; - float beta = 0.0f; - void *dBuffer = NULL; - size_t bufferSize = 0; - - /* - DPCT1007:287: Migration of cusparseCreateCoo is not supported. - */ - //CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, - // A_rowidx, A_colidx, A_vals, - // dpct::library_data_t::real_int32, - // oneapi::mkl::index_base::zero, dpct::library_data_t::real_half) ); - // Create dense matrix C - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); - descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); - // Create dense matrix B - if(transposed_B) - { - int tmp = A_cols; - A_cols = B_cols; - B_cols = tmp; - } - - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); - descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); - // allocate an external buffer if needed - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(bufferSize = 0) ); - bufferSize = 0; - //CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1)) ); - dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1); - - // execute SpMM - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? 
oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float))); - dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float); - // destroy matrix/vector descriptors - descA.reset(); - descB.reset(); - descC.reset(); - sycl::free(dBuffer, q_ct1); - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descA.reset()) ); - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB.reset()) ); - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC.reset()) ); - //CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(sycl::free(dBuffer, q_ct1)) ); -//#endif + int outCols = fill_up_to_nearest_multiple(cols, 32); + int outRows = fill_up_to_nearest_multiple(rows, 32); + if(FORMAT == COL_TURING) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 8); + else + outRows = fill_up_to_nearest_multiple(rows, 8); } - catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; - std::exit(1); + else if(FORMAT == COL_AMPERE) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 32); + else + outRows = fill_up_to_nearest_multiple(rows, 32); } - -} - -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) -{ - + else { - dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); - dpct::get_in_order_queue().submit( - [&](sycl::handler &cgh) { - /* - DPCT1101:311: 'SMEM_SIZE' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. - */ - sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nnz_rows) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item_ct1) { - kspmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB, item_ct1, smem_dequant_stats_acc_ct1.get_pointer()); - }); - }); + if(TRANSPOSE) + { + outCols = fill_up_to_nearest_multiple(rows, 32); + outRows = cols; + } } - /* - DPCT1010:289: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
- */ - //CUDA_CHECK_RETURN(0); + + + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + + + //__shared__ vars + sycl::local_accessor smem_data_acc_ct1(sycl::range<1>(32*33*8), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), + [=](sycl::nd_item<3> item_ct1) { + kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT>(A, out, rows, cols, tiledCols, outRows, outCols, item_ct1, smem_data_acc_ct1.get_pointer(), dacc_A, dacc_out); + }); + }); + } +//===========================extract outliers=========================== template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) { @@ -1598,191 +1762,23 @@ template void extractOutliers(char * A, int *idx, char *out, int id tiledRows = fill_up_to_nearest_multiple(rows, 32); } - /* - DPCT1049:70: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ - dpct::get_in_order_queue().parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), - [=](sycl::nd_item<3> item_ct1) { - kExtractOutliers(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols, item_ct1); - }); - /* - DPCT1010:290: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); -} - - - - -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) -{ - - int num_blocks = (m+31)/32; - dpct::device_ext &dev_ct1 = dpct::get_current_device(); sycl::queue &q_ct1 = dev_ct1.in_order_queue(); sycl::context ctx = q_ct1.get_context(); - - int size = NUM_BLOCK; - T *buff_A, *buff_B, *buff_out; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); - *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_B), (void*)(B), size); - *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_out), (void*)(out), size); - - - //cout << num_blocks << endl; - //cout << lda << endl; - //cout << ldb << endl; - //cout << ldc << endl; - - //cout << m << endl; - //cout << n << endl; - //cout << k << endl; - //if(bits == 32) - //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - if(bits == 16) - //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - { - dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); - dpct::get_in_order_queue().submit( - [&](sycl::handler &cgh) { - /* - DPCT1101:312: '8*16 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. - */ - //sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); - /* - DPCT1101:313: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. 
- */ - //__shared__ vars - sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); - sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 160), sycl::range<3>(1, 1, 160)), - [=](sycl::nd_item<3> item_ct1) { - gemm_device(m, n, k, buff_A, buff_B, buff_out, lda, ldb, ldc, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer()); - }); - }); - } - //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //back memcpy - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - q_ct1.memcpy((void*)(B), (void*)(buff_B), size); - q_ct1.memcpy((void*)(out), (void*)(buff_out), size); -} - -template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) -{ - - int num_blocks = (m+31)/32; - - //cout << num_blocks << endl; - //cout << lda << endl; - //cout << ldb << endl; - //cout << ldc << endl; - - //cout << m << endl; - //cout << n << endl; - //cout << k << endl; - - dpct::device_ext &dev_ct1 = dpct::get_current_device(); - sycl::queue &q_ct1 = dev_ct1.in_order_queue(); - sycl::context ctx = q_ct1.get_context(); + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { - int size = NUM_BLOCK; - T *buff_A, *buff_out; - unsigned char *buff_B; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); - *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_B), (void*)(B), size); - *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_out), (void*)(out), size); - + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), + [=](sycl::nd_item<3> item_ct1) { + kExtractOutliers(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols, item_ct1); + }); + }); - { - dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); - dpct::get_in_order_queue().submit( - [&](sycl::handler &cgh) { - /* - DPCT1101:314: '8*16 + (16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. - */ - - /* - DPCT1101:315: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. 
- */ - - //__shared__ vars - sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(176/*8*16 + (16*(batch_size_warps-1))*/), cgh); - sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); - sycl::local_accessor smem_C_acc_ct1(sycl::range<1>(8*32), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 96), sycl::range<3>(1, 1, 96)), - [=](sycl::nd_item<3> item_ct1) { - kgemm_4bit_inference(m, n, k, buff_A, buff_B, absmax, buff_out, lda, ldb, ldc, blocksize, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), smem_C_acc_ct1.get_pointer()); - }); - }); - } - //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); - //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); - //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); - //back memcpy - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - q_ct1.memcpy((void*)(B), (void*)(buff_B), size); - q_ct1.memcpy((void*)(out), (void*)(buff_out), size); - } -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) -{ - - int num_blocks = (m+3)/4; - dpct::device_ext &dev_ct1 = dpct::get_current_device(); - sycl::queue &q_ct1 = dev_ct1.in_order_queue(); - sycl::context ctx = q_ct1.get_context(); - - int size = NUM_BLOCK; - T *buff_A, *buff_out; - unsigned char *buff_B; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); - *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_B), (void*)(B), size); - *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_out), (void*)(out), size); - - { - dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); - dpct::get_in_order_queue().submit( - [&](sycl::handler &cgh) { - sycl::local_accessor quant_map_acc_ct1(sycl::range<1>(16), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - kgemm_4bit_inference_naive(m, n, k, buff_A, buff_B, absmax, datatype, buff_out, lda, ldb, ldc, blocksize, item_ct1, quant_map_acc_ct1.get_pointer()); - }); - }); - } - /* - DPCT1010:291: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - q_ct1.memcpy((void*)(B), (void*)(buff_B), size); - q_ct1.memcpy((void*)(out), (void*)(buff_out), size); - -} +//==================================func=========================== template void func(T *A, T *B, T value, long n) { @@ -1793,7 +1789,10 @@ template void func(T *A, T *B, T value, long n) /* DPCT1049:71: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
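[Editor's note: the query named above, in standard SYCL.] For example:
    size_t max_wg = dpct::get_in_order_queue().get_device().get_info<sycl::info::device::max_work_group_size>();
    size_t wg = std::min<size_t>(512, max_wg); // clamp the hard-coded 512 used in the launch below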
*/ - dpct::get_in_order_queue().parallel_for( + dpct::get_in_order_queue().submit( + [&](sycl::handler &cgh) { + + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { kfunc(A, B, value, n, item_ct1); From 922786e01029daef4957895013ddc7de62719de3 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 16 May 2024 04:23:02 -0700 Subject: [PATCH 47/66] complete refine non gemm kernels --- csrc/sycl/kernels.cpp | 321 ++++++++++++++---------------------------- csrc/sycl/kernels.h | 2 +- csrc/sycl/ops.cpp | 114 +++++++-------- 3 files changed, 158 insertions(+), 279 deletions(-) diff --git a/csrc/sycl/kernels.cpp index 568a6506b..4d191bb3d 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -504,6 +504,8 @@ __dpct_inline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict_ return pivot; } } +//=====================================================NON GEMMS================================ + //=====================================histogram 2d==================== SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, const sycl::nd_item<3> &item_ct1) @@ -2932,7 +2934,7 @@ template void kdequant_mm_int32_fp16(int *__restrict__ const buff_A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, sycl_la_T ltacc_T, sycl_la_float exacc) +template void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A) { // Strategy: To dequantize we need to load col/row statistics. This can be very expensive @@ -2985,11 +2987,12 @@ template void kdequant_mm_int32_fp16 sycl::half local_output[ITEMS_PER_THREAD]; float local_rowStats[ITEMS_PER_THREAD]; - - //typedef cub::BlockLoad LoadInt32; - //typedef cub::BlockExchange ExchangeInt32; - + using group_load_int = dpct::group::workgroup_load>; + using group_exchange = dpct::group::exchange; + auto *d_A = dacc_A.get_multi_ptr().get(); + auto *tmp = tacc.get_multi_ptr().get(); + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. float colStat = col >= numCols ? 0.0f : colStats[col]; @@ -3004,9 +3007,7 @@ // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consecutive elements smem_rowStats[j] = rowStats[row]; } - /* - DPCT1065:205: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); @@ -3026,9 +3027,7 @@ break; // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) - /* - DPCT1007:206: Migration of cub::BlockLoad::Load is not supported. - */ //LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); // 1.
load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = ltacc_T.get_multi_ptr().get(); - group_load(tmp).load(item, &buff_A[0], local_values); + auto *tmp = tacc.get_multi_ptr().get(); + group_load_int(tmp).load(item_ct1, d_A, local_values); - /* - DPCT1007:207: Migration of cub::BlockExchange::BlockedToWarpStriped is not supported. - */ //ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index - auto *tmp = exacc.get_multi_ptr().get(); - group_exchange(tmp).blocked_to_warpstriped(item, local_values); + group_exchange(tmp).blocked_to_striped(item_ct1, local_values); #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) @@ -3541,6 +3536,96 @@ } } + +//========================================k extract outliers====================== + +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out) +{ + int local_colidx = idx[item_ct1.get_group(2)]; + + if(FORMAT==COL_TURING) + { + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) + // the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // columns are grouped in increments of 4, meaning that one has the following rows and columns + // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + // cols: [0 1 2 3, 0 1 2 3, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] + + // each thread reads 1 element = 1 row + for(int row = item_ct1.get_local_id(2); row < rowsA; row+= item_ct1.get_local_range(2)) + { + int offset_per_col_tile = ((rowsA+7)/8)*32*8; + int tile_offset_rows = (row/8)*32*8; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int offset = 0; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 8; + if(row % 2 == 1) + offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); + else + // even + offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); + + offset += tile_offset_rows + tile_offset_cols; + + char val = dacc_A[offset]; + + int out_idx = (row*idx_size) + item_ct1.get_group(2); + dacc_out[out_idx] = val; + } + } + else if(FORMAT == COL_AMPERE) + { + + for(int row = item_ct1.get_local_id(2); row < rowsA; row+= item_ct1.get_local_range(2)) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile.
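// [Editor's note: worked example for the COL_TURING branch above, not part of the patch.]
// For row = 3, local_colidx = 5, rowsA = 16: subtile_row_idx = 3 (odd), subtile_col_idx = 5,
// tile_offset_rows = (3/8)*32*8 = 0 and tile_offset_cols = (5/32)*((16+7)/8)*32*8 = 0, so
// offset = 128 + (5/4)*16 + (5%4) + ((3-1)*2) = 128 + 16 + 1 + 4 = 149: element (3,5) lands
// at byte 149 of the first 8x32 tile.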
+ int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = A[offset]; + int out_idx = (row*idx_size) + item_ct1.get_group(2); + dacc_out[out_idx] = val; + } + } +} + +//=======================kfunc====================== + +template SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, const sycl::nd_item<3> &item_ct1) +{ + for(long i = (item_ct1.get_local_range(2)*item_ct1.get_group(2)) + item_ct1.get_local_id(2); i < n; i+=(item_ct1.get_local_range(2)*item_ct1.get_group_range(2))) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + +//=============================GEMMS=============================================================== + + //============================================k spmm sparse coo=============================================== #define DENORM 1.0f/127.0f #define MAX_SPARSE_COUNT 32 @@ -3696,70 +3781,6 @@ SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int } } -//========================================k extract outliers====================== - -template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out) -{ - int local_colidx = idx[item_ct1.get_group(2)]; - - if(FORMAT==COL_TURING) - { - // TURING FORMAT: - // 8*32 tiles with 4*4 subtiles - // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) - // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero - // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) - // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column - // index increases by 32 - // columns are grouped in increments of 4, meaning that one has the following rows and columns - // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] - // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] 
- - // each thread reads 1 element = 1 row - for(int row = item_ct1.get_local_id(2); row < rowsA; row+= item_ct1.get_local_range(2)) - { - int offset_per_col_tile = ((rowsA+7)/8)*32*8; - int tile_offset_rows = (row/8)*32*8; - int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; - int offset = 0; - int subtile_col_idx = local_colidx%32; - int subtile_row_idx = row % 8; - if(row % 2 == 1) - offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); - else - // even - offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); - - offset += tile_offset_rows + tile_offset_cols; - - char val = dacc_A[offset]; - - int out_idx = (row*idx_size) + item_ct1.get_group(2); - dacc_out[out_idx] = val; - } - } - else if(FORMAT == COL_AMPERE) - { - - for(int row = item_ct1.get_local_id(2); row < rowsA; row+= item_ct1.get_local_range(2)) - { - // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element - // within each tile. - int offset_per_col_tile = ((rowsA+31)/32)*32*32; - int tile_offset_rows = (row/32)*32*32; - int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; - int subtile_col_idx = local_colidx%32; - int subtile_row_idx = row % 32; - // this magic is taken from the cublasLt doc (search for COL32) - int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; - offset += tile_offset_cols + tile_offset_rows; - - char val = A[offset]; - int out_idx = (row*idx_size) + item_ct1.get_group(2); - dacc_out[out_idx] = val; - } - } -} //template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) //{ @@ -4469,137 +4490,6 @@ template SYCL_EXTERNAL void kgemm_4bit_infer } -//#define ROWS 2 -//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) -//{ -//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp -//// 1. Load dataB into register -//// 2. Dequantize B -//// 3. Fetch data from A and multiply -// -// typedef cub::BlockLoad LoadA; -// //__shared__ typename LoadA::TempStorage loada; -// typedef cub::BlockLoad LoadB; -// //__shared__ typename LoadB::TempStorage loadb; -// typedef cub::BlockReduce BlockReduce; -// // Allocate shared memory for BlockReduce -// //__shared__ typename BlockReduce::TempStorage reduce; -// -// __shared__ union { -// typename BlockReduce::TempStorage reduce; -// typename LoadB::TempStorage loadb; -// typename LoadA::TempStorage loada; -// } temp_storage; -// -// -// T dataA[ITEMS]; -// T local_B[ITEMS]; -// T local_accC[ROWS]; -// int valid_items = 0; -// const int col_offset = blockIdx.x * 8; -// -// __shared__ T tileA[ROWS*THREADS*ITEMS]; -// __shared__ T accumulatorC[ROWS*8]; -// -// //#pragma unroll 8 -// //for(int i = 0; i < 8; i++) -// // tileA[threadIdx.x + (i*256)] = 0.0f; -// //__syncthreads(); -// if(threadIdx.x < 64) -// accumulatorC[threadIdx.x] = 0.0f; -// __syncthreads(); -// -// -// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) -// { -// valid_items = K - inner_idx > THREADS*ITEMS ? 
THREADS*ITEMS : K - inner_idx; -// int baserow = 0; -// for(int row = baserow; row < (baserow+ROWS) && row < N; row++) -// { -// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); -// -// #pragma unroll ITEMS -// for(int k = 0; k < ITEMS; k++) -// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; -// -// __syncthreads(); -// } -// baserow += ROWS; -// -// // load 16 columns from B at a time. B is transposed, so its like loading rows -// // each warp loads one row -// // each thread loads 128 byte -// -// // col: inner_idx + warp_lane -// // row: ldb*(offset + warp_id) -// for(int col = 0; col < 8 && (col_offset + col) < M; col++) -// { -// int colB = col_offset + col; -// -// for(int k = 0; k < ROWS; k++) -// local_accC[k] = 0.0f; -// -// int base_idxB = ldb*colB; -// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; -// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); -// __syncthreads(); -// -// for(int row = 0; row < ROWS && row < N; row++) -// { -// #pragma unroll ITEMS -// for(int k = 0; k < ITEMS; k++) -// { -// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); -// local_accC[row] += tileA[idxA]*local_B[k]; -// } -// -// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); -// if(threadIdx.x == 0) -// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); -// } -// } -// } -// -// for(int row = 0; row < ROWS && row < N; row++) -// { -// int out_idx = ldc*row + col_offset; -// -// //if(threadIdx.x < 8) -// // if(accumulatorC[row*8 + threadIdx.x] != 0.0) -// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); -// -// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) -// { -// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); -// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; -// } -// } -// -// -// -//} - - -template SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, - const sycl::nd_item<3> &item_ct1) -{ - for(long i = (item_ct1.get_local_range(2)*item_ct1.get_group(2)) + item_ct1.get_local_id(2); i < n; i+=(item_ct1.get_local_range(2)*item_ct1.get_group_range(2))) - { - switch(FUNC) - { - case FILL: - A[i] = (T)value; - break; - case ARANGE: - A[i] = (T)i; - break; - case _MUL: - A[i] = A[i]*B[i]; - break; - } - } -} - //============================================================== // TEMPLATE DEFINITIONS @@ -4715,12 +4605,11 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int sycl::half *smem_dequant_stats); -template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, - const sycl::nd_item<3> &item_ct1, - float *smem_rowStats); //==================supported template decls======================================================= +template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float 
*smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A); + template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 9977f5303..a6515aff6 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -179,7 +179,7 @@ extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( float *__restrict__ const colStats, sycl::half *out, float *newRowStats, float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_rowStats, sycl_la_T ltacc_T, sycl_la_float exacc); + const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A); template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index bdd8d2aae..d5093de30 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -38,6 +38,11 @@ using namespace BinSearch; using std::cout; using std::endl; +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +} + //================================histogram 2d============================================== void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) @@ -1005,6 +1010,50 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, } +//==========================dequant mm int 32 fp16========================== + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *bias, int numRows, int numCols) +{ + int threads = 512; + int tileCols = fill_up_to_nearest_multiple(numCols, 32); + int n = numRows*tileCols; + int subtile_rows = 128; + int tilesize = 32*subtile_rows; + int num_blocks = numRows/subtile_rows; + num_blocks += (numRows % subtile_rows == 0) ? 
0 : 1; + num_blocks = num_blocks*(tileCols/32); + assert(threads <= tilesize); + int size = NUM_BLOCK; + + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + + + sycl::buffer buff_A (A, sycl::range<1>(size)); + + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + q_ct1.submit( + [&](sycl::handler &cgh) { + + using group_load = dpct::group::workgroup_load>; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + sycl::local_accessor tacc(temp_storage_size, cgh); + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + + //__shared__ vars + sycl::local_accessor smem_rowStats_acc_ct1(sycl::range<1>(256), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), + [=](sycl::nd_item<3> item_ct1) { + kdequant_mm_int32_fp16<4, 128, 512>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n, item_ct1,smem_rowStats_acc_ct1.get_pointer(), tacc, dacc_A ); + }); + + }); + +} + //========================GEMM============================ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) @@ -1188,61 +1237,6 @@ catch (sycl::exception const &exc) { std::exit(1); } -int fill_up_to_nearest_multiple(int value, int multiple) -{ - return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); -} - -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *bias, int numRows, int numCols) -{ - int threads = 512; - int tileCols = fill_up_to_nearest_multiple(numCols, 32); - int n = numRows*tileCols; - int subtile_rows = 128; - int tilesize = 32*subtile_rows; - int num_blocks = numRows/subtile_rows; - num_blocks += (numRows % subtile_rows == 0) ? 
0 : 1; - num_blocks = num_blocks*(tileCols/32); - assert(threads <= tilesize); - - dpct::device_ext &dev_ct1 = dpct::get_current_device(); - sycl::queue &q_ct1 = dev_ct1.in_order_queue(); - sycl::context ctx = q_ct1.get_context(); - - int size= NUM_BLOCK; - int *buff_A; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); - q_ct1.submit( - [&](sycl::handler &cgh) { - - using group_load_T = dpct::group::workgroup_load; - using group_exchange = dpct::group::exchange; - size_t load_temp_storage_size_T = group_load_T::get_local_memory_size(NUM_BLOCK); - size_t exchange_temp_storage_size = group_exchange::get_local_memory_size(NUM_BLOCK); - - sycl::local_accessor ltacc_T(load_temp_storage_size_T, cgh); - sycl::local_accessor exacc(exchange_temp_storage_size, cgh); - - - //__shared__ vars - sycl::local_accessor smem_rowStats_acc_ct1(sycl::range<1>(256), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), - [=](sycl::nd_item<3> item_ct1) { - kdequant_mm_int32_fp16<4, 128, 512>(buff_A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n, item_ct1,smem_rowStats_acc_ct1.get_pointer(), ltacc_T, exacc ); - }); - - }); - /* - DPCT1010:283: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - -} template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { @@ -1786,9 +1780,7 @@ template void func(T *A, T *B, T value, long n) int blocks = n/threads; blocks = n % threads == 0 ? blocks : blocks + 1; blocks = blocks > 65535 ? 65535 : blocks; - /* - DPCT1049:71: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. - */ + dpct::get_in_order_queue().submit( [&](sycl::handler &cgh) { @@ -1797,10 +1789,8 @@ template void func(T *A, T *B, T value, long n) [=](sycl::nd_item<3> item_ct1) { kfunc(A, B, value, n, item_ct1); }); - /* - DPCT1010:292: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
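For context, the func wrapper above sizes a one-dimensional launch by ceil-dividing n over the thread-group size and capping the group count at 65535. Hypothetical call sites, assuming device pointers a and b of length n obtained via sycl::malloc_device and the FILL/ARANGE/_MUL instantiations used elsewhere in this file:

func<float, FILL>(a, nullptr, 42.0f, n);  // a[i] = 42.0f; B is ignored for FILL
func<float, ARANGE>(a, nullptr, 0.0f, n); // a[i] = (float)i
func<float, _MUL>(a, b, 0.0f, n);         // a[i] *= b[i]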
- */
- //CUDA_CHECK_RETURN(0);
+ });
+
 }

//==============================================================

From ee8225ed49c01cf09b281b2cf562b2077d71074d Mon Sep 17 00:00:00 2001
From: abhilash1910 
Date: Thu, 16 May 2024 05:19:05 -0700
Subject: [PATCH 48/66] fix extract function

---
 csrc/sycl/ops.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp
index d5093de30..4e76db7fe 100644
--- a/csrc/sycl/ops.cpp
+++ b/csrc/sycl/ops.cpp
@@ -1744,7 +1744,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 // we load 128 column values per warp
 int tiledCols = fill_up_to_nearest_multiple(cols, 32);
 int tiledRows = 0;
-
+ int size = NUM_BLOCK;
 int num_blocks = idx_size;

 if(FORMAT == COL_TURING)
@@ -1760,13 +1760,19 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 sycl::queue &q_ct1 = dev_ct1.in_order_queue();
 sycl::context ctx = q_ct1.get_context();

+ sycl::buffer<char, 1> buff_A(A, sycl::range<1>(size));
+ sycl::buffer<char, 1> buff_out(out, sycl::range<1>(size));
+
+
 dpct::get_in_order_queue().submit(
 [&](sycl::handler &cgh) {
+ sycl::accessor dacc_A(buff_A, cgh, sycl::read_write);
+ sycl::accessor dacc_out(buff_out, cgh, sycl::read_write);
 cgh.parallel_for(
 sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)),
 [=](sycl::nd_item<3> item_ct1) {
- kExtractOutliers<FORMAT>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols, item_ct1);
+ kExtractOutliers<FORMAT>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols, item_ct1, dacc_A, dacc_out);
 });
 });

From 49ca1d7997fd54f9dca77970c44760455a10d68c Mon Sep 17 00:00:00 2001
From: abhilash1910 
Date: Mon, 20 May 2024 04:53:04 -0700
Subject: [PATCH 49/66] refine igemmlt kernels

---
 csrc/sycl/ops.cpp | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp
index 4e76db7fe..57c0589de 100644
--- a/csrc/sycl/ops.cpp
+++ b/csrc/sycl/ops.cpp
@@ -16,7 +16,9 @@
 #include
 #include
-#include "oneapi/dnnl/dnnl.hpp"
+#include <oneapi/dnnl/dnnl.hpp>
+#include <oneapi/dnnl/dnnl_sycl.hpp>
+

 #define ERR_NOT_IMPLEMENTED 100

@@ -1132,6 +1134,8 @@ template int get_leading_dim(int dim1, int dim2);
 template int get_leading_dim(int dim1, int dim2);
 template int get_leading_dim(int dim1, int dim2);

+//=================================transform GEMM==============================
+
 template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(T *A, T *out, int dim1, int dim2)
 {
@@ -1145,7 +1149,7 @@ template void trans
 int ldOut = get_leading_dim(dim1, dim2);
 int ldAOut = get_leading_dim(dim1, dim2);

- dnnl::engine engine = dnnl::sycl_interop::make_engine(dev, ctx);
+ dnnl::engine engine = sycl_interop::make_engine(dev, ctx);
 // column major
 const memory::dims a_strides = memory::dims {1, ldA};
 const auto a_md = DTYPE == 32 ? 
memory::desc({dim1, dim2}, dt::s32, a_strides) : memory::desc({dim1, dim2}, dt::s8, a_strides); @@ -1161,7 +1165,7 @@ template void trans //create dnnl stream auto q_ct1 = sycl::queue(ctx, dev); - dnnl::stream stream = sycl_interop::make_stream(q_ct1); + dnnl::stream stream = sycl_interop::make_stream(engine, q_ct1); primitive_attr attr; @@ -1186,6 +1190,9 @@ template void transform( int8_t *A, int8_t *o template void transform( int8_t *A, int8_t *out, int dim1, int dim2); template void transform( int32_t *A, int32_t *out, int dim1, int dim2); + +//========================igemmlt============================================ + template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) try { @@ -1194,7 +1201,7 @@ template int igemmlt( int m, int n, auto dev = sycl::device(sycl::gpu_selector_v); auto ctx = sycl::context(dev); - dnnl::engine engine = dnnl::sycl_interop::make_engine(dev, ctx); + dnnl::engine engine = sycl_interop::make_engine(dev, ctx); // column major const memory::dims a_strides = memory::dims {1, lda}; const auto a_md = memory::desc({m, k}, dt::s8, a_strides); @@ -1204,14 +1211,14 @@ template int igemmlt( int m, int n, const auto c_md = DTYPE_OUT == 32 ? memory::desc({m, n}, dt::s32, c_strides) : memory::desc({m, n}, dt::s8, c_strides); //memory align - memory a_mem(a_md, engine, A); - memory b_mem(b_md, engine, B); - memory c_mem(c_md, engine, C); + memory a_mem(a_md, engine); + memory b_mem(b_md, engine); + memory c_mem(c_md, engine); memory scales_C_mem({{1}, dt::f32, {1}}, engine, row_scale); //create dnnl stream auto q_ct1 = sycl::queue(ctx, dev); - dnnl::stream stream = dnnl::sycl_interop::make_stream(q_ct1); + dnnl::stream stream = sycl_interop::make_stream(engine, q_ct1); primitive_attr attr; if (SCALE_ROWS) { @@ -1237,6 +1244,8 @@ catch (sycl::exception const &exc) { std::exit(1); } +//===========================gemm_host============================================ + template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { From 342f95c52ef2b6fa1689c55d7c99ee1a504b06db Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 21 May 2024 04:31:23 -0700 Subject: [PATCH 50/66] refine dnn and gemm 4 bit kernels --- csrc/sycl/kernels.cpp | 259 +++++++++++++++++++----------------------- csrc/sycl/kernels.h | 16 +-- csrc/sycl/ops.cpp | 107 +++++++---------- 3 files changed, 164 insertions(+), 218 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 4d191bb3d..a7e03a251 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -3820,11 +3820,11 @@ template inline void vector_load(T *loca } } +//=======================================gemm_device=================== + #define WARPS 3 -/* -DPCT1110:15: The total declared local variable size in device function gemm_device exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
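The DPCT note above suggests tuning the sub-group size to relieve register pressure; the sizes a device actually supports can be queried at runtime. A short sketch, assuming <sycl/sycl.hpp> and <iostream> are included:

// Any [[intel::reqd_sub_group_size(N)]] must use one of these values.
sycl::queue q;
for (std::size_t s : q.get_device().get_info<sycl::info::device::sub_group_sizes>())
    std::cout << "supported sub-group size: " << s << "\n";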
-*/ -template SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, T *smem_A, T *smem_B) + +template SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, T *smem_A, T *smem_B, const sycl::accessor &dacc_A, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out) { #if DPCT_COMPATIBILITY_TEMP >= 750 @@ -3842,22 +3842,14 @@ template SYCL_EXTERNAL void gemm_device(int const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); + auto d_A = dacc_A.template get_multi_ptr(); + auto d_B = dacc_B.template get_multi_ptr(); + auto sg_size = item_ct1.get_sub_group(); - - //__shared__ T smem_C[8*32]; - - /* - DPCT1082:16: Migration of nvcuda::wmma::fragment type is not supported. - */ - /* - DPCT1082:17: Migration of nvcuda::wmma::matrix_a type is not supported. - */ - /* - DPCT1082:18: Migration of nvcuda::wmma::row_major type is not supported. - */ - sycl::ext::oneapi::experimental::matrix::joint_matrix a_frag; - sycl::ext::oneapi::experimental::matrix::joint_matrix b_frag; - sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag; + + sycl::ext::oneapi::experimental::matrix::joint_matrix a_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix b_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag{}; sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(item_ct1.get_sub_group(), c_frag, 0.0f); //wmma::fragment a_frag; @@ -3873,18 +3865,18 @@ template SYCL_EXTERNAL void gemm_device(int { if(loaded_values == 0) { - local_A[0] = A[idx]; - local_A[1] = A[idx+(1*val_per_iter)]; - local_A[2] = A[idx+(2*val_per_iter)]; - local_A[3] = A[idx+(3*val_per_iter)]; + local_A[0] = dacc_A[idx]; + local_A[1] = dacc_A[idx+(1*val_per_iter)]; + local_A[2] = dacc_A[idx+(2*val_per_iter)]; + local_A[3] = dacc_A[idx+(3*val_per_iter)]; #pragma unroll 32 for(int col = 0; col < 32; col++) { - local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; - local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; - local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + local_B[col] = dacc_B[(col_offset+col)*ldb+idx]; + local_B[col+32] = dacc_B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = dacc_B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = dacc_B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; } loaded_values = 3; } @@ -3951,18 +3943,18 @@ template SYCL_EXTERNAL void gemm_device(int // local_B[col] = B[(col_offset+col)*ldb+idx]; if(loaded_values == 0) { - local_A[0] = A[idx]; - local_A[1] = A[idx+(1*val_per_iter)]; - local_A[2] = A[idx+(2*val_per_iter)]; - local_A[3] = A[idx+(3*val_per_iter)]; + local_A[0] = dacc_A[idx]; + local_A[1] = dacc_A[idx+(1*val_per_iter)]; + local_A[2] = dacc_A[idx+(2*val_per_iter)]; + local_A[3] = dacc_A[idx+(3*val_per_iter)]; #pragma unroll 32 for(int col = 0; col < 32; col++) { - local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; - local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; - local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + local_B[col] = dacc_B[(col_offset+col)*ldb+idx]; + local_B[col+32] = dacc_B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = dacc_B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = 
dacc_B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; } loaded_values = 3; @@ -4014,33 +4006,31 @@ template SYCL_EXTERNAL void gemm_device(int smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; } ticktock = ticktock == 0 ? 1 : 0; - + + if(warp_id == (WARPS-1)) for(int k = 0; k < batch_size_warps; k++) { - /* - DPCT1007:25: Migration of nvcuda::wmma::load_matrix_sync is not supported. - */ + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); + dacc_A[(ticktock*batch_size_warps + k)*a_tile_offset] = smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]; + dacc_B[(ticktock*batch_size_warps + k)*b_tile_offset] = smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]; + d_A = dacc_A.template get_multi_ptr(); + d_B = dacc_B.template get_multi_ptr(); - /* - DPCT1007:26: Migration of nvcuda::wmma::load_matrix_sync is not supported. - */ + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_A, 16); + + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_B, 16); - /* - DPCT1007:27: Migration of nvcuda::wmma::mma_sync is not supported. - */ //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sg_size, c_frag, a_frag, b_frag, c_frag); } + } - } - } - + item_ct1.barrier(sycl::access::fence_space::local_space); if(warp_id != (WARPS-1)){ return; } // only warp_id == (WARPS-1) from here @@ -4049,38 +4039,33 @@ template SYCL_EXTERNAL void gemm_device(int ticktock = ticktock == 0 ? 1 : 0; for(int k = 0; k < batch_size_warps; k++) { - /* - DPCT1007:28: Migration of nvcuda::wmma::load_matrix_sync is not supported. - */ - //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); - - /* - DPCT1007:29: Migration of nvcuda::wmma::load_matrix_sync is not supported. - */ - //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); - - /* - DPCT1007:30: Migration of nvcuda::wmma::mma_sync is not supported. 
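The joint_matrix calls that replace the wmma intrinsics in this kernel all follow the same fill / load / mad / store shape over a sub-group. A minimal sketch; the 8x16x16 tile shape, the row-major B layout, and the multi_ptrs pA/pB/pC (into tiles with a 16-element stride) are illustrative and device-dependent:

namespace xmx = sycl::ext::oneapi::experimental::matrix;
// sg is item_ct1.get_sub_group(); pA/pB point at half tiles, pC at a float tile.
xmx::joint_matrix<sycl::sub_group, sycl::half, xmx::use::a, 8, 16, xmx::layout::row_major> a;
xmx::joint_matrix<sycl::sub_group, sycl::half, xmx::use::b, 16, 16, xmx::layout::row_major> b;
xmx::joint_matrix<sycl::sub_group, float, xmx::use::accumulator, 8, 16> c;
xmx::joint_matrix_fill(sg, c, 0.0f);                 // C = 0
xmx::joint_matrix_load(sg, a, pA, 16);               // stride in elements
xmx::joint_matrix_load(sg, b, pB, 16);
xmx::joint_matrix_mad(sg, c, a, b, c);               // C += A * B
xmx::joint_matrix_store(sg, c, pC, 16, xmx::layout::row_major);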
- */ - //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), c_frag, a_frag, b_frag, c_frag); + dacc_A[(ticktock*batch_size_warps + k)*a_tile_offset] = smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]; + dacc_B[(ticktock*batch_size_warps + k)*b_tile_offset] = smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]; + d_A = dacc_A.template get_multi_ptr(); + d_B = dacc_B.template get_multi_ptr(); + + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_A, 16); + + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, b_frag, d_B, 16); + + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sg_size, c_frag, a_frag, b_frag, c_frag); } // 129 mu if(warp_id == (WARPS-1)) - /* - DPCT1007:31: Migration of nvcuda::wmma::store_matrix_sync is not supported. - */ + //wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); - sycl::ext::oneapi::experimental::matrix::joint_matrix_store(item_ct1.get_sub_group(), c_frag, sycl::address_space_cast&(smem_A[0]), 32, sycl::ext::oneapi::experimental::matrix::layout::row_major); + sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sg_size, c_frag, d_A, (size_t)32, sycl::ext::oneapi::experimental::matrix::layout::row_major); if(col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_A[warp_lane]; + dacc_out[col_offset + warp_lane] = dacc_A[warp_lane]; #endif } +//===============================print=========================================== template void printnonzero(T *A, int num_values, const char * strval, const sycl::stream &stream_ct1) @@ -4098,15 +4083,14 @@ template void printnonzero(float *A, int num_values, const char*strval, template void printnonzero(sycl::half *A, int num_values, const char*strval, const sycl::stream &stream_ct1); -static dpct::global_memory nf4_data(sycl::range<1>(16), {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}); -/* -DPCT1110:32: The total declared local variable size in device function kgemm_4bit_inference exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
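The 16-entry nf4_data table introduced just below maps each 4-bit code to a normalized value in [-1, 1], so dequantization is a table lookup scaled by the block's absmax. A sketch (nf4_dequant_pair is an illustrative helper, not part of the patch):

// Dequantize one packed NF4 byte into two halves, given the per-block absmax.
inline void nf4_dequant_pair(unsigned char packed, float absmax_block,
                             sycl::half &first, sycl::half &second)
{
    first  = sycl::half(nf4_data[packed >> 4]   * absmax_block); // high nibble
    second = sycl::half(nf4_data[packed & 0x0F] * absmax_block); // low nibble
}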
-*/ -template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, - const sycl::nd_item<3> &item_ct1, - T *smem_A, - unsigned char *smem_B, - T *smem_C) +//=======================================4 bit gemm=============================== + +const float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; + + + +template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *smem_A, unsigned char *smem_B, T *smem_C, +const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out) { #if DPCT_COMPATIBILITY_TEMP >= 750 @@ -4133,20 +4117,17 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); - /* - DPCT1082:33: Migration of nvcuda::wmma::fragment type is not supported. - */ - /* - DPCT1082:34: Migration of nvcuda::wmma::matrix_a type is not supported. - */ - /* - DPCT1082:35: Migration of nvcuda::wmma::row_major type is not supported. - */ + auto d_A = dacc_A.template get_multi_ptr(); + auto d_B = dacc_B.get_multi_ptr(); + auto sg_size = item_ct1.get_sub_group(); + - sycl::ext::oneapi::experimental::matrix::joint_matrix a_frag; - sycl::ext::oneapi::experimental::matrix::joint_matrix b_frag; - sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag; - sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(item_ct1.get_sub_group(), c_frag, 0.0f); + sycl::ext::oneapi::experimental::matrix::joint_matrix a_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix b_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix c_frag{}; + sycl::ext::oneapi::experimental::matrix::joint_matrix_fill(sg_size, c_frag, 0.0f); + + //wmma::fragment a_frag; //wmma::fragment b_frag; @@ -4250,10 +4231,8 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M loaded_values--; int absidx = (idx + col_offset)/blocksize; - /* - DPCT1098:222: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues. - */ - sycl::half local_absmax = sycl::ext::oneapi::experimental::cuda::ldg(&absmax[absidx]); + + sycl::half local_absmax = absmax[absidx]; #pragma unroll 64 for(int col = 0; col < 64; col+=2) @@ -4295,23 +4274,22 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M if(warp_id == (WARPS-1)) for(int k = 0; k < batch_size_warps; k++) { - /* - DPCT1007:42: Migration of nvcuda::wmma::load_matrix_sync is not supported. 
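Note that the staging above round-trips the shared-memory tiles through the global accessors before each load. joint_matrix_load also accepts a local-space multi_ptr, which would avoid that copy; a sketch, assuming the compiler accepts a decorated local-space pointer for smem_A here:

// Possible alternative to the global-memory staging (illustrative):
auto pA = sycl::address_space_cast<sycl::access::address_space::local_space,
                                   sycl::access::decorated::yes>(
              &smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]);
sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, pA, 16);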
- */ + + dacc_A[(ticktock*batch_size_warps + k)*a_tile_offset] = smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]; + dacc_B[(ticktock*batch_size_warps + k)*b_tile_offset] = smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]; + d_A = dacc_A.template get_multi_ptr(); + d_B = dacc_B.get_multi_ptr(); + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_A, 16); - /* - DPCT1007:43: Migration of nvcuda::wmma::load_matrix_sync is not supported. - */ + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, b_frag, d_B, 16); - /* - DPCT1007:44: Migration of nvcuda::wmma::mma_sync is not supported. - */ + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), c_frag, a_frag, b_frag, c_frag); + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sg_size, c_frag, a_frag, b_frag, c_frag); } } @@ -4331,46 +4309,45 @@ template SYCL_EXTERNAL void kgemm_4bit_inference(int M { //if(warp_lane == 0) //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); - /* - DPCT1007:45: Migration of nvcuda::wmma::load_matrix_sync is not supported. - */ - sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); - - /* - DPCT1007:46: Migration of nvcuda::wmma::load_matrix_sync is not supported. - */ + + dacc_A[(ticktock*batch_size_warps + k)*a_tile_offset] = smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]; + dacc_B[(ticktock*batch_size_warps + k)*b_tile_offset] = smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]; + d_A = dacc_A.template get_multi_ptr(); + d_B = dacc_B.get_multi_ptr(); + + //wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_A, 16); + //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - sycl::ext::oneapi::experimental::matrix::joint_matrix_load(item_ct1.get_sub_group(), a_frag, sycl::address_space_cast&(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, b_frag, d_B, 16); - /* - DPCT1007:47: Migration of nvcuda::wmma::mma_sync is not supported. - */ + //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(item_ct1.get_sub_group(), c_frag, a_frag, b_frag, c_frag); -wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sg_size, c_frag, a_frag, b_frag, c_frag); } // 129 mu if(warp_id == (WARPS-1)) - /* - DPCT1007:48: Migration of nvcuda::wmma::store_matrix_sync is not supported. 
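The ticktock flag in this kernel implements double buffering: one half of shared memory feeds the MMA loop while the other half is refilled for the next iteration. In outline (load_tile and mma_on_tile are hypothetical stand-ins for the BlockLoad and joint_matrix sections above):

void load_tile(sycl::half *smem, int slot);         // fill one shared-memory slot
void mma_on_tile(const sycl::half *smem, int slot); // run the MMA loop on a slot

void k_loop_skeleton(sycl::half *smem, int K, int TILE_K)
{
    int ticktock = 0;
    load_tile(smem, 0);                    // prologue: fill slot 0
    for (int base = 0; base < K; base += TILE_K)
    {
        load_tile(smem, 1 - ticktock);     // prefetch into the idle slot
        // group barrier here in the real kernel
        mma_on_tile(smem, ticktock);       // consume the ready slot
        ticktock = 1 - ticktock;
    }
}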
- */ + //wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); - sycl::ext::oneapi::experimental::matrix::joint_matrix_store(item_ct1.get_sub_group(), c_frag, sycl::address_space_cast&(smem_A[0]), 32, sycl::ext::oneapi::experimental::matrix::layout::row_major); + sycl::ext::oneapi::experimental::matrix::joint_matrix_store(sg_size, c_frag, d_A, 32, sycl::ext::oneapi::experimental::matrix::layout::row_major); //printnonzero(smem_C, 32, ""); if(col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_C[warp_lane]; + // use smem_A itself + dacc_out[col_offset + warp_lane] = dacc_A[warp_lane]; #endif } + +//=========================================4 bit gemm naive=============== + + #define num_values_4bit 32 -/* -DPCT1110:49: The total declared local variable size in device function kgemm_4bit_inference_naive exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. -*/ -template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map) + +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out) { // per threadblock: @@ -4401,10 +4378,8 @@ template SYCL_EXTERNAL void kgemm_4bit_infer int inner_idx_halved = inner_idx/2; int offset_B = ldb*row_B; int absidx = ((2*offset_B)+inner_idx)/blocksize; - /* - DPCT1098:223: The '*' expression is used instead of the __ldg call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues. 
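In the naive kernel each 32-thread warp owns one output row and consumes num_values_4bit (32) weights per step, fetched as 16 packed bytes. The unpacking in the loop that follows amounts to (names as in the kernel; the explicit loop is an illustrative rewrite):

// Expand 16 packed NF4 bytes into 32 scaled values of type T.
for(int j = 0; j < 16; j++)
{
    unsigned char b = local_B_4bit[j];
    local_B[j*2 + 0] = T(quant_map[b >> 4])   * local_absmax; // high nibble first
    local_B[j*2 + 1] = T(quant_map[b & 0x0F]) * local_absmax; // then low nibble
}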
- */ - local_absmax = sycl::ext::oneapi::experimental::cuda::ldg(&absmax[absidx]); + + local_absmax = absmax[absidx]; if(row_B < M) { @@ -4418,7 +4393,7 @@ template SYCL_EXTERNAL void kgemm_4bit_infer #pragma unroll for(int j = 0; j < (num_values_8bit); j++) if((inner_idx_halved) + j < (K/2)) - local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; + local_B_4bit[j] = dacc_B[offset_B+inner_idx_halved + j]; else local_B_4bit[j] = 0b01110111; } @@ -4463,7 +4438,7 @@ template SYCL_EXTERNAL void kgemm_4bit_infer #pragma unroll for(int k = 0; k < num_values_4bit/4; k++) if(inner_idx + (i*num_values_4bit/4) + k < K) - local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; + local_A[k] = dacc_A[inner_idx + k + (i*num_values_4bit/4)]; else local_A[k] = T(0.0f); @@ -4485,12 +4460,10 @@ template SYCL_EXTERNAL void kgemm_4bit_infer local_C = sycl::reduce_over_group(item_ct1.get_sub_group(), local_C, sycl::plus<>()); if(row_B < M && warp_lane == 0) - out[row_B] = T(local_C); + dacc_out[row_B] = T(local_C); } - - //============================================================== // TEMPLATE DEFINITIONS //============================================================== diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index a6515aff6..da4aad226 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -203,14 +203,14 @@ const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); -template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B); -template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, - sycl::half *smem_B, - sycl::half *smem_C); -template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, - const sycl::nd_item<3> &item_ct1, - T *quant_map); +template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, T *smem_A, T *smem_B, const sycl::accessor &dacc_A, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out); + + +template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *smem_A, unsigned char *smem_B, + T *smem_C, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out); + + +template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out); template extern SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, const sycl::nd_item<3> &item_ct1); diff --git 
a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 57c0589de..229a5efa7 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1066,7 +1066,7 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in const void * beta = &fbeta; int status; - DPCT_CHECK_ERROR(dpct::gemm(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, B, dpct::library_data_t::real_int8, ldb, beta, C, dpct::library_data_t::real_int32, ldc, dpct::library_data_t::real_int32)); + dpct::gemm(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, B, dpct::library_data_t::real_int8, ldb, beta, C, dpct::library_data_t::real_int32, ldc, dpct::library_data_t::real_int32); } @@ -1090,7 +1090,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i //printf("%i %i %i\n", strideA, strideB, strideC); //printf("%i\n", batchCount); - DPCT_CHECK_ERROR(dpct::gemm_batch(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, (long long int)strideA, B, dpct::library_data_t::real_int8, ldb, (long long int)strideB, beta, C, dpct::library_data_t::real_int32, ldc, (long long int)strideC, batchCount, dpct::library_data_t::real_int32)); + dpct::gemm_batch(*context->m_handle, transposeA ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, transposeB ? 
oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, m, n, k, alpha, A, dpct::library_data_t::real_int8, lda, (long long int)strideA, B, dpct::library_data_t::real_int8, ldb, (long long int)strideB, beta, C, dpct::library_data_t::real_int32, ldc, (long long int)strideC, batchCount, dpct::library_data_t::real_int32); } catch (sycl::exception const &exc) { @@ -1246,7 +1246,6 @@ catch (sycl::exception const &exc) { //===========================gemm_host============================================ - template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { @@ -1257,15 +1256,11 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - T *buff_A, *buff_B, *buff_out; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); - *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_B), (void*)(B), size); - *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_out), (void*)(out), size); - - + + sycl::buffer buff_A (A, sycl::range<1>(size)); + sycl::buffer buff_B (B, sycl::range<1>(size)); + sycl::buffer buff_out (out, sycl::range<1>(size)); + //cout << num_blocks << endl; //cout << lda << endl; //cout << ldb << endl; @@ -1283,13 +1278,11 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); dpct::get_in_order_queue().submit( [&](sycl::handler &cgh) { - /* - DPCT1101:312: '8*16 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. - */ - //sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); - /* - DPCT1101:313: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. 
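The literals 224 and 4192 that DPCT inlined for the shared-memory tiles can be kept self-documenting with constexpr arithmetic; for WARPS == 3, with batch_size_warps taken as (WARPS-1)*2 to match the expressions preserved in the comments, they come out exactly:

constexpr int WARPS = 3;
constexpr int batch_size_warps = (WARPS - 1) * 2;                  // 4
constexpr int SMEM_A = 8*16 + 2*16*(batch_size_warps - 1);         // 224
constexpr int SMEM_B = 2*batch_size_warps*16*32
                     + 2*16*(batch_size_warps - 1);                // 4192
static_assert(SMEM_A == 224 && SMEM_B == 4192, "must match the literals above");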
- */ + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_B(buff_B, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(224/*8*16 + (2*16*(batch_size_warps-1))*/), cgh); sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); @@ -1297,7 +1290,8 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 160), sycl::range<3>(1, 1, 160)), [=](sycl::nd_item<3> item_ct1) { - gemm_device(m, n, k, buff_A, buff_B, buff_out, lda, ldb, ldc, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer()); + gemm_device(m, n, k, A, B, out, lda, ldb, ldc, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), + dacc_A, dacc_B, dacc_out); }); }); } @@ -1305,13 +1299,11 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //back memcpy - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - q_ct1.memcpy((void*)(B), (void*)(buff_B), size); - q_ct1.memcpy((void*)(out), (void*)(buff_out), size); - + } +//============================gemm 4bit inference ================================ + template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { @@ -1331,49 +1323,35 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - T *buff_A, *buff_out; - unsigned char *buff_B; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); - *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_B), (void*)(B), size); - *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_out), (void*)(out), size); - + + sycl::buffer buff_A (A, sycl::range<1>(size)); + sycl::buffer buff_B (B, sycl::range<1>(size)); + sycl::buffer buff_out (out, sycl::range<1>(size)); { dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); dpct::get_in_order_queue().submit( [&](sycl::handler &cgh) { - /* - DPCT1101:314: '8*16 + (16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. - */ - /* - DPCT1101:315: '2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. 
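The recurring migration pattern in these hunks replaces malloc_device/memcpy pairs with sycl::buffer objects wrapping the host pointers: an accessor hands the data to the kernel, and the buffer's destructor writes any modification back, which is why the explicit back-copies disappear. The pattern in isolation (scale_in_place is an illustrative example, not code from the patch):

#include <sycl/sycl.hpp>

void scale_in_place(sycl::queue &q, float *host_data, size_t n, float s)
{
    sycl::buffer<float, 1> buf(host_data, sycl::range<1>(n)); // wraps host memory
    q.submit([&](sycl::handler &cgh) {
        sycl::accessor acc(buf, cgh, sycl::read_write);
        cgh.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) { acc[i] *= s; });
    });
}   // buffer destruction synchronizes and copies the result back to host_data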
- */ + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_B(buff_B, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_A_acc_ct1(sycl::range<1>(176/*8*16 + (16*(batch_size_warps-1))*/), cgh); - sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); + sycl::local_accessor smem_B_acc_ct1(sycl::range<1>(4192/*2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))*/), cgh); sycl::local_accessor smem_C_acc_ct1(sycl::range<1>(8*32), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 96), sycl::range<3>(1, 1, 96)), [=](sycl::nd_item<3> item_ct1) { - kgemm_4bit_inference(m, n, k, buff_A, buff_B, absmax, buff_out, lda, ldb, ldc, blocksize, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), smem_C_acc_ct1.get_pointer()); + kgemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize, item_ct1, smem_A_acc_ct1.get_pointer(), smem_B_acc_ct1.get_pointer(), smem_C_acc_ct1.get_pointer(), dacc_A, dacc_B, dacc_out); }); }); } - //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); - //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); - //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); - //back memcpy - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - q_ct1.memcpy((void*)(B), (void*)(buff_B), size); - q_ct1.memcpy((void*)(out), (void*)(buff_out), size); } +//============================gemm 4 bit inference naive ================= template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { @@ -1384,36 +1362,31 @@ template void gemm_4bit_inference_naive(int m, int n, int sycl::context ctx = q_ct1.get_context(); int size = NUM_BLOCK; - T *buff_A, *buff_out; - unsigned char *buff_B; - *((void **)&buff_A) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_A), (void*)(A), size); - *(( void**)&buff_B) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_B), (void*)(B), size); - *((void **)&buff_out) = sycl::malloc_device(size, dev_ct1, ctx); - q_ct1.memcpy((void*)(buff_out), (void*)(out), size); - + + sycl::buffer buff_A (A, sycl::range<1>(size)); + sycl::buffer buff_B (B, sycl::range<1>(size)); + sycl::buffer buff_out (out, sycl::range<1>(size)); + + { dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); dpct::get_in_order_queue().submit( [&](sycl::handler &cgh) { + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_B(buff_B, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::local_accessor quant_map_acc_ct1(sycl::range<1>(16), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - kgemm_4bit_inference_naive(m, n, k, buff_A, buff_B, absmax, datatype, buff_out, lda, ldb, ldc, blocksize, item_ct1, quant_map_acc_ct1.get_pointer()); + kgemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, item_ct1, quant_map_acc_ct1.get_pointer(), dacc_A, dacc_B, dacc_out); }); }); } - /* - 
DPCT1010:291: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); - q_ct1.memcpy((void*)(A), (void*)(buff_A), size); - q_ct1.memcpy((void*)(B), (void*)(buff_B), size); - q_ct1.memcpy((void*)(out), (void*)(buff_out), size); - + } //================================spm coo================================== From 5e2611a1ff38806ee756a83fb2ccdf90b9d79c38 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 21 May 2024 05:26:09 -0700 Subject: [PATCH 51/66] refine spm experimental --- csrc/sycl/kernels.cpp | 4 +--- csrc/sycl/ops.cpp | 46 +++++++++++++------------------------------ 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index a7e03a251..89d287886 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -3631,9 +3631,7 @@ template SYCL_EXTERNAL void kfunc(T *A, T *B, T value, lo #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 template -/* -DPCT1110:13: The total declared local variable size in device function kspmm_coo_very_sparse_naive exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. -*/ + SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, sycl::half *smem_dequant_stats) { diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 229a5efa7..a5c036ed2 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1391,17 +1391,13 @@ template void gemm_4bit_inference_naive(int m, int n, int //================================spm coo================================== -void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, sycl::half *B, int ldc, sycl::half* C, bool transposed_B) +void spmm_coo(int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, sycl::half *B, int ldc, sycl::half* C, bool transposed_B) { try{ dpct::device_ext &dev_ct1 = dpct::get_current_device(); sycl::queue &q_ct1 = dev_ct1.in_order_queue(); - -//#ifdef NO_CUBLASLT -//#else - - + dpct::sparse::sparse_matrix_desc_t descA; std::shared_ptr descB, descC; @@ -1410,15 +1406,9 @@ void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_v void *dBuffer = NULL; size_t bufferSize = 0; - /* - DPCT1007:287: Migration of cusparseCreateCoo is not supported. 
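Note that once the cusparseCreateCoo call is dropped, descA is declared but never constructed, so the dpct::sparse::spmm call below receives an empty descriptor. Until a COO constructor is wired up, a guard makes that explicit (sketch; assumes sparse_matrix_desc_t is the shared_ptr-like handle its .reset() usage below suggests, and that <cassert> is included):

assert(descA && "spmm_coo: COO descriptor for A was never created");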
- */ - //CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, - // A_rowidx, A_colidx, A_vals, - // dpct::library_data_t::real_int32, - // oneapi::mkl::index_base::zero, dpct::library_data_t::real_half) ); + // Create dense matrix C - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); + descC = std::make_shared(A_rows, B_cols, ldc, C, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); // Create dense matrix B if(transposed_B) @@ -1428,27 +1418,22 @@ void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_v B_cols = tmp; } - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major)) ); + descB = std::make_shared(A_cols, B_cols, ldb, B, dpct::library_data_t::real_half, oneapi::mkl::layout::row_major); // allocate an external buffer if needed - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(bufferSize = 0) ); + bufferSize = 0; - //CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1)) ); + dBuffer = (void *)sycl::malloc_device(bufferSize, q_ct1); - // execute SpMM - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float))); - dpct::sparse::spmm(*handle, oneapi::mkl::transpose::nontrans, transposed_B ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float); + + dpct::sparse::spmm(q_ct1, oneapi::mkl::transpose::nontrans, transposed_B ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans, &alpha, descA, descB, &beta, descC, dpct::library_data_t::real_float); // destroy matrix/vector descriptors descA.reset(); descB.reset(); descC.reset(); sycl::free(dBuffer, q_ct1); - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descA.reset()) ); - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descB.reset()) ); - //CHECK_CUSPARSE( DPCT_CHECK_ERROR(descC.reset()) ); - //CUDA_CHECK_RETURN( DPCT_CHECK_ERROR(sycl::free(dBuffer, q_ct1)) ); -//#endif + } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; @@ -1457,6 +1442,8 @@ void spmm_coo(sycl::queue* handle, int *A_rowidx, int *A_colidx, sycl::half *A_v } +//===============================spm _coo _very _sparse========================= + template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { @@ -1464,9 +1451,7 @@ template void spmm_coo_very_sparse_naive(int *max_count, dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); dpct::get_in_order_queue().submit( [&](sycl::handler &cgh) { - /* - DPCT1101:311: 'SMEM_SIZE' expression was replaced with a value. Modify the code to use the original expression, provided in comments, if it is correct. - */ + sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh); cgh.parallel_for( @@ -1476,10 +1461,7 @@ template void spmm_coo_very_sparse_naive(int *max_count, }); }); } - /* - DPCT1010:289: SYCL uses exceptions to report errors and does not use the error codes. 
The call was replaced with 0. You need to rewrite this code. - */ - //CUDA_CHECK_RETURN(0); + } From 6af46e2f428e1e9df3cad34ed5f587981ce3a92b Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Sun, 26 May 2024 01:45:50 -0700 Subject: [PATCH 52/66] fix host dereference issue on nv dequant & kquant --- csrc/sycl/kernels.cpp | 33 ++++++++++++++++----------------- csrc/sycl/kernels.h | 7 +++++-- csrc/sycl/ops.cpp | 25 +++++++++++++++++++------ 3 files changed, 40 insertions(+), 25 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 89d287886..4595fb62b 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -730,7 +730,7 @@ void kEstimateQuantiles(const T *A, float *code, const float offset, const T max SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, const sycl::nd_item<3> &item_ct1, float* smem_code, const sycl_la &tacc, const sycl_dacc_float &dacc_A, - const sycl_dacc_uc &dacc_out) + const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code) { const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*NUM_BLOCK) : NUM_BLOCK; @@ -740,15 +740,15 @@ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, c unsigned char qvals[NUM]; //const int lane_id = threadIdx.x % 2; - using group_load_float = dpct::group::workgroup_load>; - using group_store_uc = dpct::group::workgroup_store>; + using group_load_float = dpct_::group::workgroup_load>; + using group_store_uc = dpct_::group::workgroup_store>; auto *d_A = dacc_A.template get_multi_ptr().get(); auto *d_out = dacc_out.get_multi_ptr().get(); if(item_ct1.get_local_id(2) < 256) { - smem_code[item_ct1.get_local_id(2)] = code[item_ct1.get_local_id(2)]; + smem_code[item_ct1.get_local_id(2)] = dacc_code[item_ct1.get_local_id(2)]; //smem_code[0][threadIdx.x] = code[threadIdx.x]; //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; } @@ -788,11 +788,11 @@ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, c // 5. Repeat (3) 8 times for top 8 values in 256 // 6. store with byte index group_store_uc(tmp).store(item_ct1, d_out, qvals); - } - + } } + //===========================k quantize blockwise================================ template @@ -2934,7 +2934,7 @@ templatevoid kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A) +template void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl::accessor &dacc_out, const sycl::accessor &dacc_bias ) { // Strategy: To dequantize we need to load col/row statistics. 
This can be very expensive @@ -2987,16 +2987,16 @@ template void kdequant_mm_i sycl::half local_output[ITEMS_PER_THREAD]; float local_rowStats[ITEMS_PER_THREAD]; - using group_load_int = dpct::group::workgroup_load>; - using group_exchange = dpct::group::exchange; + using group_load_int = dpct_::group::workgroup_load>; + using group_exchange = exchange; auto *d_A = dacc_A.get_multi_ptr().get(); auto *tmp = tacc.get_multi_ptr().get(); - + //dacc_colStats //dacc_bias //dacc_rowStats // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - float colStat = col >= numCols ? 0.0f : colStats[col]; - float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : sycl::vec(bias[col]).convert()[0]; + float colStat = col >= numCols ? 0.0f : dacc_colStats[col]; + float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : sycl::vec(dacc_bias[col]).convert()[0]; // no block loads for rows for now -- keep it simple for(int j = item_ct1.get_local_id(2); j < SUBTILE_ROWS; j+=item_ct1.get_local_range(2)) { @@ -3005,7 +3005,7 @@ template void kdequant_mm_i // each warp accesses the same element, for four consequitive elements // todo: update description about striped shared memory, it is not needed // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements - smem_rowStats[j] = rowStats[row]; + smem_rowStats[j] = dacc_rowStats[row]; } item_ct1.barrier(sycl::access::fence_space::local_space); @@ -3068,13 +3068,12 @@ template void kdequant_mm_i { int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); if(outIdx< n_out && col < numCols) - out[outIdx] = local_output[j]; + dacc_out[outIdx] = local_output[j]; } row_offset += rows_per_load; } } - //=====================================k double row col quant============================ template void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed) @@ -4579,7 +4578,7 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int //==================supported template decls======================================================= -template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A); +template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl::accessor &dacc_out, const sycl::accessor &dacc_bias); template 
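// A scalar sketch of the element-wise result the kdequant_mm_int32_fp16 hunks
// above compute, per the LLM.int8() scheme this kernel implements: the int32
// accumulator is undone by the row and column scales (each quantized against
// 127), then the optional bias is added. Names are illustrative:
#include <cstdint>

static inline float dequant_mm_elem(int32_t acc, float row_stat, float col_stat, float bias)
{
    // 1/(127*127) undoes the two int8 quantization scalings
    return acc * row_stat * col_stat * (1.0f / (127.0f * 127.0f)) + bias;
}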
SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); @@ -4747,7 +4746,7 @@ template SYCL_EXTERNAL void kPercentileClipping(sycl::half #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ -template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out); +template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code); MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index da4aad226..0c68cfcae 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -29,7 +29,7 @@ template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict_ extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, const sycl::nd_item<3> &item_ct1, float* smem_code, const sycl_la &tacc, const sycl_dacc_float &dacc_A, - const sycl_dacc_uc &dacc_out); + const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code); extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code); @@ -179,7 +179,10 @@ extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( float *__restrict__ const colStats, sycl::half *out, float *newRowStats, float *newcolStats, sycl::half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, - const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A); + const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A, + const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl::accessor &dacc_out, + const sycl::accessor &dacc_bias +); template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index a5c036ed2..457ba3dab 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -121,33 +121,34 @@ void quantize(float *code, float *A, unsigned char *out, int n) sycl::buffer buff_A(A,sycl::range<1>(size)); sycl::buffer buff_out(out,sycl::range<1>(size)); - + sycl::buffer buff_code(code,sycl::range<1>(size)); { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load>; + using group_load = 
dpct_::group::workgroup_load>; size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kQuantize(code, A, out, n, item_ct1, smem_code_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out); + kQuantize(code, A, out, n, item_ct1, smem_code_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out, dacc_code); }); }); } } - //============================k dequantize=============================== void dequantize(float *code, unsigned char *A, float *out, int n) { @@ -1033,15 +1034,27 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half sycl::buffer buff_A (A, sycl::range<1>(size)); + sycl::buffer buff_rowStats (rowStats, sycl::range<1>(size)); + sycl::buffer buff_colStats (colStats, sycl::range<1>(size)); + sycl::buffer buff_out (out, sycl::range<1>(size)); + sycl::buffer buff_bias (bias, sycl::range<1>(size)); + + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load>; + using group_load = dpct_::group::workgroup_load>; size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + + sycl::accessor dacc_rowStats(buff_rowStats, cgh, sycl::read_write); + sycl::accessor dacc_colStats(buff_colStats, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_bias(buff_bias, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_rowStats_acc_ct1(sycl::range<1>(256), cgh); @@ -1049,7 +1062,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), [=](sycl::nd_item<3> item_ct1) { - kdequant_mm_int32_fp16<4, 128, 512>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n, item_ct1,smem_rowStats_acc_ct1.get_pointer(), tacc, dacc_A ); + kdequant_mm_int32_fp16<4, 128, 512>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n, item_ct1,smem_rowStats_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rowStats, dacc_colStats, dacc_out, dacc_bias); }); }); From bcad0ea79f4955159c64fa671887773994ea7eaa Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 6 Jun 2024 06:32:50 -0700 Subject: [PATCH 53/66] fix header name --- csrc/sycl/kernels.cpp | 6 +++--- csrc/sycl/ops.cpp | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 4595fb62b..dd35f87f7 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -740,8 +740,8 @@ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, c unsigned char qvals[NUM]; //const int lane_id = threadIdx.x % 2; - using group_load_float = dpct_::group::workgroup_load>; - using group_store_uc = dpct_::group::workgroup_store>; + using 
group_load_float = dpct::group::workgroup_load>; + using group_store_uc = dpct::group::workgroup_store>; auto *d_A = dacc_A.template get_multi_ptr().get(); auto *d_out = dacc_out.get_multi_ptr().get(); @@ -2987,7 +2987,7 @@ template void kdequant_mm_i sycl::half local_output[ITEMS_PER_THREAD]; float local_rowStats[ITEMS_PER_THREAD]; - using group_load_int = dpct_::group::workgroup_load>; + using group_load_int = dpct::group::workgroup_load>; using group_exchange = exchange; auto *d_A = dacc_A.get_multi_ptr().get(); diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 457ba3dab..38e558243 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -127,7 +127,7 @@ void quantize(float *code, float *A, unsigned char *out, int n) dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load = dpct_::group::workgroup_load>; + using group_load = dpct::group::workgroup_load>; size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); @@ -1045,7 +1045,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, sycl::half q_ct1.submit( [&](sycl::handler &cgh) { - using group_load = dpct_::group::workgroup_load>; + using group_load = dpct::group::workgroup_load>; size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); From 6d3ed2664f044513f220910153bcf4b3c467176b Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 6 Jun 2024 08:04:54 -0700 Subject: [PATCH 54/66] fix nv host dereference issue on 8 bit --- csrc/sycl/kernels.cpp | 65 +++++++++++++++++++++++-------------------- csrc/sycl/kernels.h | 15 +++++++--- csrc/sycl/ops.cpp | 28 +++++++++++++------ 3 files changed, 65 insertions(+), 43 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index dd35f87f7..d4ed4d5bb 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -2277,7 +2277,9 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char float *smem_exchange1, float *smem_exchange2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, - const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2) + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_absmax1, const sycl_dacc_float &dacc_absmax2) { //const int n_full = n + (n%BLOCK_SIZE); @@ -2316,13 +2318,13 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char auto *d_state1 = dacc_state1.get_multi_ptr().get(); auto *d_state2 = dacc_state2.get_multi_ptr().get(); - + //quantiles1 //quantiles2 //absmax1 //absmax2 // init: 0.2 -> 0.23 // 0.23 -> 0.23 - smem_quantiles1[0][item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; - smem_quantiles2[0][item_ct1.get_local_id(2)] = quantiles2[item_ct1.get_local_id(2)]; + smem_quantiles1[0][item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles2[0][item_ct1.get_local_id(2)] = dacc_quantiles2[item_ct1.get_local_id(2)]; # pragma unroll for(unsigned int j = 1; j < LANES; j++) { @@ -2389,7 +2391,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char { if(!sycl::isnan((float)g_vals[j]) && !sycl::isinf((float)g_vals[j])) { - s2_vals[j] = 
smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*dacc_absmax2[i/BLOCK_SIZE]; g_val = g_vals[j]; //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; @@ -2397,7 +2399,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*dacc_absmax1[i/BLOCK_SIZE]; s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); } else @@ -2426,8 +2428,8 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char if(item_ct1.get_local_id(2) == 0) { - absmax1[i/BLOCK_SIZE] = new_local_abs_max1; - absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + dacc_absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + dacc_absmax2[i/BLOCK_SIZE] = new_local_abs_max2; } else { @@ -2477,9 +2479,6 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - //c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], s1_vals[j] / new_local_abs_max1); - //c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], s2_vals[j] / new_local_abs_max2); - c1s[j] = quantize_2D<1>(quadrants1, s1_vals[j] / new_local_abs_max1); c2s[j] = quantize_2D<0>(quadrants2, s2_vals[j] / new_local_abs_max2); @@ -2540,7 +2539,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, - const sycl_dacc_uc &dacc_state1 + const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_absmax1 ) { @@ -2574,7 +2575,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char // init: 0.2 -> 0.23 // 0.23 -> 0.23 - smem_quantiles1[0][item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles1[0][item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; # pragma unroll for(unsigned int j = 1; j < LANES; j++) smem_quantiles1[j][item_ct1.get_local_id(2)] = smem_quantiles1[0][item_ct1.get_local_id(2)]; @@ -2638,36 +2639,36 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char { if(weight_decay > 0.0f) { switch(OPTIMIZER) { - case MOMENTUM: - case ADAGRAD: - case RMSPROP: + case 1: + case 3: + case 2: g_val += ((float)p_vals[j])*weight_decay; break; - case LION: + case 4: p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); break; } } - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*dacc_absmax1[i/BLOCK_SIZE]; switch(OPTIMIZER) { - case MOMENTUM: + case 1: if(step == 1) s1_vals[j] = g_val; else s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; - case LION: + case 4: // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; - case RMSPROP: + case 2: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; - case ADAGRAD: + case 3: s1_vals[j] = s1_vals[j] + (g_val*g_val); break; } @@ -2687,7 +2688,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char 
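// A scalar sketch of the dequantize -> EMA -> requantize round trip that the
// hunks above route through accessors; quantiles2/absmax2 stand in for the
// dacc_quantiles2/dacc_absmax2 buffers, new_absmax2 for the result of the
// block-wide max reduction, and the linear scan stands in for the
// quantize_2D tree search. Names are illustrative:
#include <cmath>
#include <cstdint>

static float second_moment_step(const float quantiles2[256], float* absmax2, int block,
                                uint8_t& c2, float g, float beta2, float new_absmax2)
{
    float s2 = quantiles2[c2] * absmax2[block];   // dequantize the packed state
    s2 = s2 * beta2 + (1.0f - beta2) * g * g;     // EMA of squared gradients
    absmax2[block] = new_absmax2;                 // thread 0 publishes the new scale
    float v = s2 / new_absmax2;                   // requantize against the new scale
    int best = 0;
    float err = std::fabs(quantiles2[0] - v);
    for (int i = 1; i < 256; ++i) {               // nearest code-book entry
        float e = std::fabs(quantiles2[i] - v);
        if (e < err) { err = e; best = i; }
    }
    c2 = (uint8_t)best;
    return s2;
}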
item_ct1.barrier(sycl::access::fence_space::local_space); if(item_ct1.get_local_id(2) == 0) - absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + dacc_absmax1[i/BLOCK_SIZE] = new_local_abs_max1; else new_local_abs_max1 = smem_exchange1[0]; @@ -2699,17 +2700,17 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char { switch(OPTIMIZER) { - case MOMENTUM: + case 1: p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); break; - case LION: + case 4: p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); break; - case RMSPROP: + case 2: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(g_val / (sycl::sqrt(s1_vals[j])+eps)); break; - case ADAGRAD: + case 3: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(g_val / (sycl::sqrt(s1_vals[j])+eps)); break; @@ -2733,8 +2734,8 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - //c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], s1_vals[j] / new_local_abs_max1); c1s[j] = quantize_2D<1>(quadrants1, s1_vals[j] / new_local_abs_max1); + // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) if(sycl::signbit(smem_quantiles1[lane_id][c1s[j]]) != sycl::signbit(s1_vals[j])) @@ -4842,7 +4843,9 @@ template void kOptimizerStatic8bit2StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2,const sycl_la &tacc, \ const sycl::accessor &dacc_g, \ const sycl::accessor &dacc_p, \ - const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, \ + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, \ + const sycl_dacc_float &dacc_absmax1, const sycl_dacc_float &dacc_absmax2); \ MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, sycl::half, 2048, 8) @@ -4860,7 +4863,9 @@ template void kOptimizerStatic8bit1StateBlockwise &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1,const sycl_la &tacc, \ const sycl::accessor &dacc_g, \ const sycl::accessor &dacc_p, \ - const sycl_dacc_uc &dacc_state1); \ + const sycl_dacc_uc &dacc_state1, \ + const sycl_dacc_float &dacc_quantiles1, \ + const sycl_dacc_float &dacc_absmax1); \ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 0c68cfcae..43bd6cd4f 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -130,16 +130,21 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, - const float beta1, const float beta2, const float eps, const int step, const float lr, + const float beta1, const float beta2, + const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n, + float* absmax1, float* absmax2, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, 
sycl::local_accessor smem_quantiles2, float *smem_exchange1, float *smem_exchange2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, - const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_absmax1, const sycl_dacc_float &dacc_absmax2); template extern SYCL_EXTERNAL void kOptimizerStatic8bit1StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, @@ -155,7 +160,9 @@ template extern SYCL_EX const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, - const sycl_dacc_uc &dacc_state1); + const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_absmax1); template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g); diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 38e558243..dd442886e 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -893,17 +893,21 @@ template void optimizerStatic8bitBlockwise(T* p, T* g sycl::queue &q_ct1 = dev_ct1.in_order_queue(); sycl::context ctx = q_ct1.get_context(); int num_blocks = 0; - int size = BLOCKSIZE_2STATE; + int size = NUM_BLOCK; sycl::buffer buff_g(g,sycl::range<1>(size)); sycl::buffer buff_p(p,sycl::range<1>(size)); sycl::buffer buff_state1(state1,sycl::range<1>(size)); sycl::buffer buff_state2(state2,sycl::range<1>(size)); + sycl::buffer buff_quantiles1(quantiles1,sycl::range<1>(size)); + sycl::buffer buff_quantiles2(quantiles2,sycl::range<1>(size)); + sycl::buffer buff_absmax1(absmax1,sycl::range<1>(size)); + sycl::buffer buff_absmax2(absmax2,sycl::range<1>(size)); switch(OPTIMIZER) { - case ADAM: + case 0: num_blocks = n/BLOCKSIZE_2STATE; num_blocks = n % BLOCKSIZE_2STATE == 0 ? 
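// The launch-grid arithmetic repeated across these hunks,
//   num_blocks = n/BLOCKSIZE; num_blocks = n % BLOCKSIZE == 0 ? num_blocks : num_blocks + 1;
// is a ceiling division; an equivalent one-liner, shown only for clarity:
inline int ceil_div(int n, int block) { return (n + block - 1) / block; }
// e.g. num_blocks = ceil_div(n, BLOCKSIZE_2STATE);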
num_blocks : num_blocks + 1; { @@ -919,6 +923,10 @@ template void optimizerStatic8bitBlockwise(T* p, T* g sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_quantiles2(buff_quantiles2, cgh, sycl::read_write); + sycl::accessor dacc_absmax1(buff_absmax1, cgh, sycl::read_write); + sycl::accessor dacc_absmax2(buff_absmax2, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); @@ -930,16 +938,16 @@ template void optimizerStatic8bitBlockwise(T* p, T* g cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE), sycl::range<3>(1, 1, BLOCKSIZE_2STATE/NUM_2STATE)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit2StateBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n,item_ct1, smem_quantiles1_acc_ct1, smem_quantiles2_acc_ct1,smem_exchange1_acc_ct1.get_pointer(), smem_exchange2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_state2); + kOptimizerStatic8bit2StateBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n,item_ct1, smem_quantiles1_acc_ct1, smem_quantiles2_acc_ct1,smem_exchange1_acc_ct1.get_pointer(), smem_exchange2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_state2, dacc_quantiles1, dacc_quantiles2, dacc_absmax1, dacc_absmax2); }); }); } break; - case MOMENTUM: - case RMSPROP: - case ADAGRAD: - case LION: + case 1: + case 2: + case 3: + case 4: num_blocks = n/BLOCKSIZE_1STATE; num_blocks = n % BLOCKSIZE_1STATE == 0 ? 
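// The literal case labels introduced here replace the named optimizer labels
// positionally (ADAM -> 0, MOMENTUM -> 1, RMSPROP -> 2, ADAGRAD -> 3,
// LION -> 4); the next patch restores the names. If the literals were kept, a
// matching enum (hypothetical, inferred only from these substitutions) would
// read:
enum Optimizer_t { ADAM = 0, MOMENTUM = 1, RMSPROP = 2, ADAGRAD = 3, LION = 4 };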
num_blocks : num_blocks + 1; { @@ -954,6 +962,9 @@ template void optimizerStatic8bitBlockwise(T* p, T* g sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_absmax1(buff_absmax1, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<2>(2/*LANES*/, 257), cgh); @@ -962,7 +973,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE), sycl::range<3>(1, 1, BLOCKSIZE_1STATE/NUM_1STATE)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1StateBlockwise(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1); + kOptimizerStatic8bit1StateBlockwise(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n, item_ct1, smem_quantiles1_acc_ct1, smem_exchange1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_quantiles1, dacc_absmax1); }); }); } @@ -976,7 +987,6 @@ catch (sycl::exception const &exc) { std::exit(1); } - //============================percentile clipping=============================== template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) From 4f19d41541214e98a0e5def8aaad1f5c64cfd3ed Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 6 Jun 2024 08:38:32 -0700 Subject: [PATCH 55/66] fix nv issue k quants --- csrc/sycl/kernels.cpp | 35 ++++++++++++++-------------- csrc/sycl/kernels.h | 3 ++- csrc/sycl/ops.cpp | 53 +++++++++++++++++++++++++++++-------------- 3 files changed, 56 insertions(+), 35 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index d4ed4d5bb..39114deef 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -800,7 +800,8 @@ template &item_ct1, float *smem_code, float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, - const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out) + const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out, + const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax) { @@ -823,11 +824,11 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl auto *d_rand = dacc_rand.get_multi_ptr().get(); auto *d_out = dacc_out.get_multi_ptr().get(); - + //code //absmax if(DATA_TYPE == General8bit) for(int i = item_ct1.get_local_id(2); i < 256; i+=item_ct1.get_local_range(2)) - smem_code[i] = code[i]; + smem_code[i] = dacc_code[i]; for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { @@ -863,7 +864,7 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl item_ct1.barrier(sycl::access::fence_space::local_space); if(item_ct1.get_local_id(2) == 0) - absmax[i/BLOCK_SIZE] = local_abs_max; + dacc_absmax[i/BLOCK_SIZE] = local_abs_max; else local_abs_max = smem_absmax_value[0]; @@ -2639,12 +2640,12 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char { if(weight_decay > 0.0f) { switch(OPTIMIZER) { - case 1: - case 3: - case 2: + case MOMENTUM: + case ADAGRAD: + case RMSPROP: g_val += ((float)p_vals[j])*weight_decay; break; - case 4: + case LION: 
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); break; } @@ -2654,21 +2655,21 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char switch(OPTIMIZER) { - case 1: + case MOMENTUM: if(step == 1) s1_vals[j] = g_val; else s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; - case 4: + case LION: // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; - case 2: + case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; - case 3: + case ADAGRAD: s1_vals[j] = s1_vals[j] + (g_val*g_val); break; } @@ -2700,17 +2701,17 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char { switch(OPTIMIZER) { - case 1: + case MOMENTUM: p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); break; - case 4: + case LION: p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); break; - case 2: + case RMSPROP: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(g_val / (sycl::sqrt(s1_vals[j])+eps)); break; - case 3: + case ADAGRAD: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(g_val / (sycl::sqrt(s1_vals[j])+eps)); break; @@ -4747,7 +4748,7 @@ template SYCL_EXTERNAL void kPercentileClipping(sycl::half #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ -template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code); +template void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code, float *smem_absmax_value,const sycl_la &tacc,const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(sycl::half, 4096, 4, 1, General8bit) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 43bd6cd4f..9a6b10153 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -38,7 +38,8 @@ template &dacc_A, - const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out); + const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out, + const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out); template diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index dd442886e..ae5d5378c 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -202,9 +202,14 @@ template void quantizeBlockwise(floa int size= NUM_BLOCK; for(int i=0; i< NUM_BLOCK; i++){ out[i]=out[(DATA_TYPE > 0) ? 
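// A self-contained sketch, with illustrative names, of the per-block absmax
// computation that kQuantizeBlockwise performs and now publishes through
// dacc_absmax. The kernel itself broadcasts via local memory and a
// BlockReduce-style exchange; this sketch uses a group reduction for brevity
// and assumes n is a multiple of the work-group size:
#include <sycl/sycl.hpp>

void block_absmax_sketch(sycl::queue &q, float *A, float *absmax, int n, int block)
{
    sycl::buffer<float, 1> buff_A(A, sycl::range<1>(n));
    sycl::buffer<float, 1> buff_absmax(absmax, sycl::range<1>(n / block));
    q.submit([&](sycl::handler &cgh) {
        sycl::accessor dacc_A(buff_A, cgh, sycl::read_only);
        sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::write_only);
        cgh.parallel_for(
            sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(block)),
            [=](sycl::nd_item<1> it) {
                float v = sycl::fabs(dacc_A[it.get_global_id(0)]);
                float m = sycl::reduce_over_group(it.get_group(), v,
                                                  sycl::maximum<float>());
                if (it.get_local_id(0) == 0)
                    dacc_absmax[it.get_group(0)] = m;   // one scale per block
            });
    });
}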
i/2 : i];}; + sycl::buffer buff_A(A,sycl::range<1>(size)); sycl::buffer buff_rand(rand,sycl::range<1>(size)); sycl::buffer buff_out(out,sycl::range<1>(size)); + sycl::buffer buff_code(code,sycl::range<1>(size)); + sycl::buffer buff_absmax(absmax,sycl::range<1>(size)); + + if(blocksize == 4096) @@ -223,6 +228,8 @@ template void quantizeBlockwise(floa sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); //__shared__ vars for funtions @@ -233,7 +240,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); }); }); } @@ -252,6 +259,8 @@ template void quantizeBlockwise(floa sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); @@ -260,7 +269,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); }); }); } @@ -278,7 +287,9 @@ template void quantizeBlockwise(floa sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); - + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + //__shared__vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); @@ -286,7 +297,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); }); }); } @@ -305,6 +316,8 @@ template void 
quantizeBlockwise(floa sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); @@ -313,7 +326,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); }); }); } @@ -331,7 +344,9 @@ template void quantizeBlockwise(floa sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); - + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); @@ -340,7 +355,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); }); }); } @@ -358,14 +373,16 @@ template void quantizeBlockwise(floa sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); - + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); }); }); } @@ -383,7 +400,9 @@ template void quantizeBlockwise(floa sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); sycl::accessor dacc_rand(buff_rand, cgh, sycl::read_write); - + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor 
dacc_absmax(buff_absmax, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_code_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_absmax_value_acc_ct1(sycl::range<1>(1), cgh); @@ -392,7 +411,7 @@ template void quantizeBlockwise(floa cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { - kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out); + kQuantizeBlockwise(code, A, absmax, out, rand, rand_offset, n, item_ct1, smem_code_acc_ct1.get_pointer(), smem_absmax_value_acc_ct1.get_pointer(), tacc, dacc_A, dacc_rand, dacc_out, dacc_code, dacc_absmax); }); }); } @@ -893,7 +912,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g sycl::queue &q_ct1 = dev_ct1.in_order_queue(); sycl::context ctx = q_ct1.get_context(); int num_blocks = 0; - int size = NUM_BLOCK; + int size = BLOCKSIZE_2STATE; sycl::buffer buff_g(g,sycl::range<1>(size)); sycl::buffer buff_p(p,sycl::range<1>(size)); @@ -907,7 +926,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g switch(OPTIMIZER) { - case 0: + case ADAM: num_blocks = n/BLOCKSIZE_2STATE; num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; { @@ -944,10 +963,10 @@ template void optimizerStatic8bitBlockwise(T* p, T* g } break; - case 1: - case 2: - case 3: - case 4: + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: num_blocks = n/BLOCKSIZE_1STATE; num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; { From dd709caf465f314ecf2ddb982afb0a25dba158e6 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Thu, 6 Jun 2024 10:33:48 -0700 Subject: [PATCH 56/66] fix nv issue on row col quant --- csrc/sycl/kernels.cpp | 20 +++++++++++--------- csrc/sycl/kernels.h | 5 ++++- csrc/sycl/ops.cpp | 29 +++++++++++++++++++++++++++-- 3 files changed, 42 insertions(+), 12 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 39114deef..877dd225a 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -3078,7 +3078,7 @@ template void kdequant_mm_i } //=====================================k double row col quant============================ -template void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed) +template void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const 
sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr) { // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD // Each thread reads the same column but multiple rows @@ -3097,6 +3097,7 @@ template >; using group_store_char = dpct::group::workgroup_store>; @@ -3117,15 +3118,15 @@ template (&smem_nnz_row_idx[row], UINT_MAX); - rowidx[old_idx] = base_row+row; - colidx[old_idx] = base_col+(item_ct1.get_local_id(2)*ITEMS_PER_THREAD)+j; - val[old_idx] = local_data[j]; + dacc_rowidx[old_idx] = base_row+row; + dacc_colidx[old_idx] = base_col+(item_ct1.get_local_id(2)*ITEMS_PER_THREAD)+j; + dacc_val[old_idx] = local_data[j]; } else { @@ -3208,6 +3209,7 @@ template SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, @@ -4604,9 +4606,9 @@ template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPER -template void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed); +template void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); -template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed); +template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const 
sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 9a6b10153..e681af880 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -206,7 +206,10 @@ template &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, - const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed); + const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, + const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index ae5d5378c..9a8cd1bbb 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1602,6 +1602,15 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::buffer buff_out_col_normed(out_col_normed,sycl::range<1>(size)); sycl::buffer buff_out_row_normed(out_row_normed,sycl::range<1>(size)); + sycl::buffer buff_rowStats(rowStats,sycl::range<1>(size)); + sycl::buffer buff_colStats(colStats,sycl::range<1>(size)); + sycl::buffer buff_rowidx(rowidx,sycl::range<1>(size)); + sycl::buffer buff_colidx(colidx,sycl::range<1>(size)); + sycl::buffer buff_val(val,sycl::range<1>(size)); + sycl::buffer buff_nnz_block_ptr(nnz_block_ptr,sycl::range<1>(size)); + + + int threads = 64; int items_per_thread = 4; int tile_cols = threads*items_per_thread; @@ -1628,6 +1637,14 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::accessor dacc_out_col_normed(buff_out_col_normed, cgh, sycl::read_write); sycl::accessor dacc_out_row_normed(buff_out_row_normed, cgh, sycl::read_write); + sycl::accessor dacc_rowStats(buff_rowStats, cgh, sycl::read_write); + sycl::accessor dacc_colStats(buff_colStats, cgh, sycl::read_write); + sycl::accessor dacc_rowidx(buff_rowidx, cgh, sycl::read_write); + sycl::accessor dacc_colidx(buff_colidx, cgh, sycl::read_write); + sycl::accessor dacc_val(buff_val, cgh, sycl::read_write); + sycl::accessor dacc_nnz_block_ptr(buff_nnz_block_ptr, cgh, sycl::read_write); + + //__shared__ vars sycl::local_accessor smem_row_stats_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_nnz_row_idx_acc_ct1(sycl::range<1>(256), cgh); @@ -1636,7 +1653,7 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out_col_normed, dacc_out_row_normed); + kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols, item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out_col_normed, dacc_out_row_normed, dacc_rowStats, dacc_colStats, dacc_rowidx, dacc_colidx, dacc_val, dacc_nnz_block_ptr); }); }); } @@ -1654,6 
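// A hypothetical scalar sketch of the outlier split kDoubleRowColQuant
// performs with the rowidx/colidx/val buffers added here: entries whose
// magnitude reaches `threshold` go into COO triplets for the fp16 path and
// are zeroed in the int8 row-major output; the rest are scaled by the row
// statistic (the row absmax, so the scaled value stays within [-127, 127]):
#include <cmath>
#include <cstdint>

void split_one(float a, float row_stat, float threshold,
               int row, int col, int &nnz,
               int *rowidx, int *colidx, float *val, int8_t &out_row)
{
    if (std::fabs(a) >= threshold) {
        rowidx[nnz] = row;
        colidx[nnz] = col;
        val[nnz] = a;
        ++nnz;            // the kernel reserves slots atomically via nnz_block_ptr
        out_row = 0;      // outlier is handled in the sparse fp16 path
    } else {
        out_row = (int8_t)std::lrintf(a * (127.0f / row_stat));
    }
}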
+1671,14 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::accessor dacc_out_col_normed(buff_out_col_normed, cgh, sycl::read_write); sycl::accessor dacc_out_row_normed(buff_out_row_normed, cgh, sycl::read_write); + sycl::accessor dacc_rowStats(buff_rowStats, cgh, sycl::read_write); + sycl::accessor dacc_colStats(buff_colStats, cgh, sycl::read_write); + sycl::accessor dacc_rowidx(buff_rowidx, cgh, sycl::read_write); + sycl::accessor dacc_colidx(buff_colidx, cgh, sycl::read_write); + sycl::accessor dacc_val(buff_val, cgh, sycl::read_write); + sycl::accessor dacc_nnz_block_ptr(buff_nnz_block_ptr, cgh, sycl::read_write); + + //__shared__ vars sycl::local_accessor smem_row_stats_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_nnz_row_idx_acc_ct1(sycl::range<1>(256), cgh); @@ -1663,7 +1688,7 @@ void doubleRowColQuant(sycl::half * A, float *rowStats, float *colStats, char *o sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols,item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out_col_normed, dacc_out_row_normed); + kDoubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols,item_ct1, smem_row_stats_acc_ct1.get_pointer(), smem_nnz_row_idx_acc_ct1.get_pointer(), tacc, dacc_A, dacc_out_col_normed, dacc_out_row_normed, dacc_rowStats, dacc_colStats, dacc_rowidx, dacc_colidx, dacc_val, dacc_nnz_block_ptr); }); }); From 0fbedfea1c0a64a0a2dc74908f6c7efe96a3858b Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 7 Jun 2024 01:15:19 -0700 Subject: [PATCH 57/66] fix nv issue on 32 bit --- csrc/sycl/kernels.cpp | 45 +++++++++++++++++++++++-------------------- csrc/sycl/kernels.h | 13 +++++++++---- csrc/sycl/ops.cpp | 35 +++++++++++++++++++++------------ 3 files changed, 56 insertions(+), 37 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 877dd225a..84dc76810 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -1053,17 +1053,16 @@ SYCL_EXTERNAL void kDequantize(float *code, unsigned char *buff_A, float *buff_o //===================32 bit optimizer======================== - -template /* DPCT1110:1: The total declared local variable size in device function kPreconditionOptimizer32bit2State exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. */ +template SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g, const sycl_dacc_float &dacc_unorm) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 
0 : BLOCK_SIZE); @@ -1085,7 +1084,10 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, auto *d_g = dacc_g.template get_multi_ptr().get(); auto *d_state1 = dacc_state1.get_multi_ptr().get(); auto *d_state2 = dacc_state2.get_multi_ptr().get(); - + + + + for (unsigned int i = base_idx; i < n_full; i += item_ct1.get_group_range(2)*BLOCK_SIZE) { valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; @@ -1134,7 +1136,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, { switch(OPTIMIZER) { - case 1: + case ADAM: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); s1_vals[j] *= correction1; @@ -1154,7 +1156,7 @@ void kPreconditionOptimizer32bit2State(T* g, T* p, s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>()); if(item_ct1.get_local_id(2) == 0) - dpct::atomic_fetch_add(&unorm[0], s1_vals[0]); + dpct::atomic_fetch_add(&dacc_unorm[0], s1_vals[0]); sycl::group_barrier(item_ct1.get_sub_group()); } @@ -1168,7 +1170,7 @@ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -1187,7 +1189,7 @@ void kOptimizer32bit2State(T* g, T* p, if(max_unorm > 0.0f) { - update_scale = max_unorm > 0.0f ? sycl::sqrt(unorm[0]) : 1.0f; + update_scale = max_unorm > 0.0f ? sycl::sqrt(dacc_unorm[0]) : 1.0f; if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } else{ update_scale = 1.0f; } } @@ -1270,7 +1272,7 @@ void kOptimizer32bit2State(T* g, T* p, { switch(OPTIMIZER) { - case 1: + case ADAM: if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); @@ -1329,7 +1331,8 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, float* buff_state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1, + const sycl_dacc_float &dacc_unorm) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 
0 : BLOCK_SIZE); @@ -1424,7 +1427,7 @@ void kPreconditionOptimizer32bit1State(T* g, T* p, s1_vals[0] = sycl::reduce_over_group(item_ct1.get_group(), s1_vals[0], sycl::plus<>()); if(item_ct1.get_local_id(2) == 0) - dpct::atomic_fetch_add(&unorm[0], s1_vals[0]); + dpct::atomic_fetch_add(&dacc_unorm[0], s1_vals[0]); sycl::group_barrier(item_ct1.get_sub_group()); } @@ -1438,7 +1441,7 @@ void kOptimizer32bit1State(T *g, T *p, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const - sycl_dacc_float &dacc_state1) + sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_unorm) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -1448,7 +1451,7 @@ void kOptimizer32bit1State(T *g, T *p, if(max_unorm > 0.0f) { - update_scale = max_unorm > 0.0f ? sycl::sqrt(unorm[0]) : 1.0f; + update_scale = max_unorm > 0.0f ? sycl::sqrt(dacc_unorm[0]) : 1.0f; if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } else{ update_scale = 1.0f; } } @@ -1575,7 +1578,6 @@ void kOptimizer32bit1State(T *g, T *p, } } - //===================8 bit optimizer======================== @@ -4608,7 +4610,8 @@ template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPER template void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); -template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); +template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char 
&dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); @@ -4627,7 +4630,7 @@ template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restric template SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1); \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_unorm); \ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, sycl::half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) @@ -4641,7 +4644,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template SYCL_EXTERNAL void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1); \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_unorm); \ MAKE_Optimizer32bit1State(MOMENTUM, sycl::half) MAKE_Optimizer32bit1State(MOMENTUM, float) @@ -4657,7 +4660,7 @@ MAKE_Optimizer32bit1State(ADAGRAD, float) template SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ float* state1, float* state2, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g); \ + const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g, const sycl_dacc_float &dacc_unorm); \ MAKE_PreconditionOptimizer32bit2State(ADAM, float) MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::half) @@ -4666,15 +4669,15 @@ MAKE_PreconditionOptimizer32bit2State(ADAM, sycl::ext::oneapi::bfloat16) template SYCL_EXTERNAL void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, const
sycl::accessor &dacc_p, const sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_state2); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, const sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm); template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::half* g, sycl::half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm); template SYCL_EXTERNAL void kOptimizer32bit2State(sycl::ext::oneapi::bfloat16* g, sycl::ext::oneapi::bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template SYCL_EXTERNAL void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index e681af880..50a009967 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -42,26 +42,29 @@ template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out); +//====================32 bit headers============================= + template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2,const sycl::accessor &dacc_g, const sycl_dacc_float &dacc_unorm); template extern SYCL_EXTERNAL void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, - const sycl::nd_item<3> &item_ct1,const 
sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const sycl_dacc_float &dacc_state1,const sycl_dacc_float &dacc_state2, const sycl_dacc_float &dacc_unorm); template extern SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1); + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1, + const sycl_dacc_float &dacc_unorm); template extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, @@ -69,8 +72,10 @@ extern SYCL_EXTERNAL void kOptimizer32bit1State(T* g, T* p, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl::accessor &dacc_p,const - sycl_dacc_float &dacc_state1); + sycl_dacc_float &dacc_state1, const sycl_dacc_float &dacc_unorm); + +//==============================8 bit headers========================== template extern SYCL_EXTERNAL void kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 9a8cd1bbb..c8daadbac 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -505,14 +505,14 @@ template void optimizer32bit(T* g, T* p, sycl::buffer buff_p(p,sycl::range<1>(size)); sycl::buffer buff_state1(state1,sycl::range<1>(size)); sycl::buffer buff_state2(state2,sycl::range<1>(size)); - + sycl::buffer buff_unorm(unorm, sycl::range<1>(size)); switch(OPTIMIZER) { case ADAM: if(max_unorm > 0.0f) { - + std::memset(unorm, 0, 1*sizeof(float)); //DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); { @@ -530,11 +530,13 @@ template void optimizer32bit(T* g, T* p, sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizer32bit2State(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_state1, dacc_state2, dacc_g); + kPreconditionOptimizer32bit2State(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_state1, dacc_state2, dacc_g, dacc_unorm); }); }); } @@ -556,12 +558,14 @@ template void optimizer32bit(T* g, T* p, sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); - sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + cgh.parallel_for( 
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizer32bit2State(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1, dacc_state2); + kOptimizer32bit2State(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1, dacc_state2, dacc_unorm); }); }); } @@ -572,6 +576,7 @@ template void optimizer32bit(T* g, T* p, case ADAGRAD: if(max_unorm > 0.0f) { + std::memset(unorm, 0, 1*sizeof(float)); //DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); { @@ -586,19 +591,19 @@ template void optimizer32bit(T* g, T* p, sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); - + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_g, dacc_state1); + kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_g, dacc_state1, dacc_unorm); }); }); } } - { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( @@ -612,11 +617,13 @@ template void optimizer32bit(T* g, T* p, sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1); + kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1, dacc_unorm); }); }); } @@ -639,11 +646,13 @@ template void optimizer32bit(T* g, T* p, sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1); + kOptimizer32bit1State(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n, item_ct1, tacc, dacc_g, dacc_p, dacc_state1, dacc_unorm); }); }); } @@ -651,6 +660,7 @@ template void optimizer32bit(T* g, T* p, if(max_unorm > 0.0f) { + std::memset(unorm, 0, 1*sizeof(float)); //DPCT_CHECK_ERROR(q_ct1.memset(unorm, 0, 1*sizeof(float)).wait()); { @@ -664,11 +674,13 @@ template void optimizer32bit(T* g, T* p, sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); sycl::accessor 
dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_g, dacc_state1); + kPreconditionOptimizer32bit1State(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n, item_ct1, tacc, dacc_g, dacc_state1, dacc_unorm); }); }); } @@ -685,7 +697,6 @@ catch (sycl::exception const &exc) { - //============================8 bit optimizer=============================== #define NUM8BIT 16 From 38312cba697fb001e15a8c639084e2f82b5d8df9 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 7 Jun 2024 03:15:41 -0700 Subject: [PATCH 58/66] fix nv issue on k dequant blockwise --- csrc/sycl/kernels.cpp | 37 ++++++++++++++++++++++++------------- csrc/sycl/kernels.h | 14 +++++++++----- csrc/sycl/ops.cpp | 42 +++++++++++++++--------------------------- 3 files changed, 48 insertions(+), 45 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 84dc76810..649f5e1f9 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -933,11 +933,12 @@ SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, fl } } -//===========================k dequantize================================ +//===========================k dequantize blockwise================================ + template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, - const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out ) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax ) { const int n_load = (item_ct1.get_group_range(2) * TILE_SIZE); @@ -956,7 +957,12 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * auto *d_A = dacc_A.template get_multi_ptr().get(); auto *d_out = dacc_out.template get_multi_ptr().get(); + //A //out //code //absmax + //typedef cub::BlockLoad LoadChar; + //typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + for (unsigned int i = base_idx; i < n_load; i += item_ct1.get_group_range(2)*TILE_SIZE) { if(DATA_TYPE > 0) @@ -970,7 +976,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * valid_items_store = n - i > TILE_SIZE ? 
TILE_SIZE : n - i; } - local_abs_max = sycl::ext::oneapi::experimental::cuda::ldg(&absmax[(i+item_ct1.get_local_id(2)*NUM_PER_TH)/(blocksize)]); + local_abs_max = dacc_absmax[(i+item_ct1.get_local_id(2)*NUM_PER_TH)/(blocksize)];//sycl::ext::oneapi::experimental::cuda::ldg(&absmax[(i+item_ct1.get_local_id(2)*NUM_PER_TH)/(blocksize)]); item_ct1.barrier(sycl::access::fence_space::local_space); @@ -994,7 +1000,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = sycl::ext::oneapi::experimental::cuda::ldg(&code[qvals[j]]*local_abs_max); + vals[j] = dacc_code[qvals[j]]*local_abs_max;//sycl::ext::oneapi::experimental::cuda::ldg(&code[qvals[j]]*local_abs_max); break; case FP4: #pragma unroll NUM_PER_TH @@ -1017,6 +1023,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * item_ct1.barrier(); + // 1. load 8 values per thread // 2. compute 2-max in registers (64 max per warp) // 3. do warp reduction + broadcast back @@ -1028,6 +1035,7 @@ SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * } } +//=========================k dequantize====================== SYCL_EXTERNAL void kDequantize(float *code, unsigned char *buff_A, float *buff_out, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code) @@ -4823,21 +4831,24 @@ MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(sycl::ext::oneapi::bfloat16, 64, 2, 0, NF4) -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); + +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::half *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float 
* absmax, sycl::half *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); -template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out); +template SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, sycl::ext::oneapi::bfloat16 *out, const int blocksize, const int n,const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl_dacc_uc &dacc_A, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); #define 
MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ template void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 50a009967..b5f53ece4 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -34,13 +34,15 @@ extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const sycl::nd_item<3> &item_ct1, float *smem_code); template extern SYCL_EXTERNAL void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n, - const sycl::nd_item<3> &item_ct1, - float *smem_code, - float *smem_absmax_value, - const sycl_la &tacc,const sycl::accessor &dacc_A, + const sycl::nd_item<3> &item_ct1, + float *smem_code, + float *smem_absmax_value, + const sycl_la &tacc,const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rand, const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); -template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out); + +//=========================k-dequant blockwise ====================== +template extern SYCL_EXTERNAL void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl_dacc_uc &dacc_A,const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_code, const sycl_dacc_float &dacc_absmax); //====================32 bit headers============================= @@ -133,6 +135,8 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); + +//====================8 bit blockwise========================= template extern SYCL_EXTERNAL void kOptimizerStatic8bit2StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index c8daadbac..4a3323051 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -432,30 +432,12 @@ template void dequantizeBlockwise(float *code, unsign sycl::buffer buff_A(A,sycl::range<1>(size)); sycl::buffer buff_out(out,sycl::range<1>(size)); + sycl::buffer buff_code(code,sycl::range<1>(size)); + sycl::buffer buff_absmax(absmax,sycl::range<1>(size)); - if(DATA_TYPE > 0) - { - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); - q_ct1.submit( - [&](sycl::handler &cgh){ - - using group_load = dpct::group::workgroup_load>; - - size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); - sycl::local_accessor tacc(temp_storage_size, cgh); - - sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); - sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); - - q_ct1.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), - [=](sycl::nd_item<3> item_ct1) { - kDequantizeBlockwise(code, A, absmax, out, blocksize/2, n, item_ct1, tacc, dacc_A, dacc_out); - }); - }); - } - 
else - { + + + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh){ @@ -467,15 +449,21 @@ template void dequantizeBlockwise(float *code, unsign sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); - + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, (n+tile_size-1)/tile_size) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { - kDequantizeBlockwise(code, A, absmax, out, blocksize, n, item_ct1, tacc, dacc_A, dacc_out); + if(DATA_TYPE > 0){ + kDequantizeBlockwise(code, A, absmax, out, blocksize/2, n, item_ct1, tacc, dacc_A, dacc_out, dacc_code, dacc_absmax); } + else{ + kDequantizeBlockwise(code, A, absmax, out, blocksize, n, item_ct1, tacc, dacc_A, dacc_out, dacc_code, dacc_absmax); + } }); + }); - } - + } From 1d6f56b120833fecf1f59cc9700f9da010d6a6b0 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 7 Jun 2024 05:46:13 -0700 Subject: [PATCH 59/66] fix nv issue on gemms and some quant kernels --- csrc/sycl/kernels.cpp | 251 +++++++++++++++++++++++++++++------------- csrc/sycl/kernels.h | 33 ++++-- csrc/sycl/ops.cpp | 147 +++++++++++++++++++------ 3 files changed, 315 insertions(+), 116 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 649f5e1f9..a0731c065 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -508,18 +508,21 @@ __dpct_inline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict_ //=====================================histogram 2d==================== SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, - const sycl::nd_item<3> &item_ct1) + const sycl::nd_item<3> &item_ct1, const sycl_dacc_float &dacc_histogram, const sycl_dacc &dacc_index1, + const sycl_dacc &dacc_index2, const sycl_dacc_float &dacc_src) { const int tid = item_ct1.get_local_id(2) + (item_ct1.get_local_range(2)*item_ct1.get_group(2)); const int numThreads = item_ct1.get_local_range(2)*item_ct1.get_group_range(2); for(int i = tid; i < n; i+=numThreads) { - int idx = (index1[i]*maxidx1) + index2[i]; - dpct::atomic_fetch_add(&histogram[idx], src[i]); + int idx = (dacc_index1[i]*maxidx1) + dacc_index2[i]; + dpct::atomic_fetch_add(&dacc_histogram[idx], dacc_src[i]); } } + + //===========================k compress max========================== template @@ -639,7 +642,7 @@ typedef sycl::accessor sycl_dacc_char; template SYCL_EXTERNAL void kEstimateQuantiles(const T *A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A) + const sycl::nd_item<3> &item_ct1, sycl_la tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code) { const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; @@ -647,7 +650,7 @@ void kEstimateQuantiles(const T *A, float *code, const float offset, const T max const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 
1.0f : (n/BLOCK_ESTIMATE)); using group_load = dpct::group::workgroup_load>; - using group_radix_sort = dpct::group::radix_sort; + //using group_radix_sort = dpct::group::radix_sort; T vals[NUM_ESTIMATE]; auto *d_A = dacc_A.template get_multi_ptr().get(); @@ -719,7 +722,7 @@ void kEstimateQuantiles(const T *A, float *code, const float offset, const T max for(int i = item_ct1.get_local_id(2); i < BLOCK_ESTIMATE; i+=item_ct1.get_local_range(2)) { if(smem_qidx[i] != -1) - dpct::atomic_fetch_add(&code[smem_qidx[i]], vals[i/THREADS_ESTIMATE]); + dpct::atomic_fetch_add(&dacc_code[smem_qidx[i]], vals[i/THREADS_ESTIMATE]); } } @@ -1606,7 +1609,10 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float* smem_quantiles1, float* smem_quantiles2, - const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2) + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, + const sycl_dacc_float &dacc_new_max2) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); @@ -1631,8 +1637,8 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c if(item_ct1.get_local_id(2) < 256) { - smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; - smem_quantiles2[item_ct1.get_local_id(2)] = quantiles2[item_ct1.get_local_id(2)]; + smem_quantiles1[item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles2[item_ct1.get_local_id(2)] = dacc_quantiles2[item_ct1.get_local_id(2)]; } @@ -1681,7 +1687,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c { g_val = g_vals[j]; g_val *= gnorm_scale; - s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; + s1_vals[j] = smem_quantiles1[m_c1[j]]*dacc_max1[0]*beta1; s1_vals[j] += (1.0f-beta1)*g_val; local_max_s1 = sycl::fmax(local_max_s1, sycl::fabs(s1_vals[j])); } @@ -1691,7 +1697,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c { g_val = g_vals[j]; g_val *= gnorm_scale; - s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; + s2_vals[j] = smem_quantiles2[r_c2[j]]*dacc_max2[0]*beta2; s2_vals[j] += (1.0f-beta2)*g_val*g_val; local_max_s2 = sycl::fmax(local_max_s2, sycl::fabs(s2_vals[j])); } @@ -1729,9 +1735,9 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c if(item_ct1.get_local_id(2) == 0) { - atomicMax(&new_max1[0], local_max_s1); - atomicMax(&new_max2[0], local_max_s2); - if(unorm != NULL){ dpct::atomic_fetch_add(&unorm[0], local_unorm); } + atomicMax(&dacc_new_max1[0], local_max_s1); + atomicMax(&dacc_new_max2[0], local_max_s2); + if(unorm != NULL){ dpct::atomic_fetch_add(&dacc_unorm[0], local_unorm); } } } @@ -1739,6 +1745,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c #define NUM_THREADS2 1024 #define NUM_PER_BLOCK2 4096 + template SYCL_EXTERNAL void @@ -1752,7 +1759,10 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha const float gnorm_scale, const int n, const sycl::nd_item<3> 
&item_ct1, float* smem_quantiles1, float* smem_quantiles2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, - const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2 + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, const sycl_dacc_float &dacc_unorm, + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, + const sycl_dacc_float &dacc_new_max2 ) { @@ -1766,13 +1776,13 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha const float correction2 = sycl::sqrt(1.0f - dpct::pow(beta2, step)); const float step_size = -lr*correction2/correction1; //const float step_size = -lr*correction2/correction1; - float new_max_val1 = 1.0f/new_max1[0]; - float new_max_val2 = 1.0f/new_max2[0]; + float new_max_val1 = 1.0f/dacc_new_max1[0]; + float new_max_val2 = 1.0f/dacc_new_max2[0]; float update_scale = 1.0f; if(max_unorm > 0.0f) { - update_scale = max_unorm > 0.0f ? sycl::sqrt((float)(unorm[0])) : 1.0f; + update_scale = max_unorm > 0.0f ? sycl::sqrt((float)(dacc_unorm[0])) : 1.0f; if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } else{ update_scale = 1.0f; } } @@ -1799,9 +1809,9 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha if(item_ct1.get_local_id(2) < 512) { if(item_ct1.get_local_id(2) < 256) - smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles1[item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; else - smem_quantiles2[item_ct1.get_local_id(2)-256] = quantiles2[item_ct1.get_local_id(2)-256]; + smem_quantiles2[item_ct1.get_local_id(2)-256] = dacc_quantiles2[item_ct1.get_local_id(2)-256]; } @@ -1863,7 +1873,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha g_val = float(g_vals[j]); g_val *= gnorm_scale; s1_vals[j] = smem_quantiles1[c1s[j]]; - s1_vals[j] = s1_vals[j]*max1[0]; + s1_vals[j] = s1_vals[j]*dacc_max1[0]; s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); @@ -1880,7 +1890,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha } s2_vals[j] = smem_quantiles2[c2s[j]]; - s2_vals[j] = s2_vals[j]*max2[0]; + s2_vals[j] = s2_vals[j]*dacc_max2[0]; s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); } @@ -1940,7 +1950,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float* smem_quantiles1, - const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1) + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1) { const int n_full = item_ct1.get_group_range(2) * NUM_PER_BLOCK; const int base_idx = (item_ct1.get_group(2) * item_ct1.get_local_range(2) * NUM_PER_THREAD); @@ -1960,7 +1972,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c auto *d_state1 = dacc_state1.get_multi_ptr().get(); if(item_ct1.get_local_id(2) < 256) - smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles1[item_ct1.get_local_id(2)] = 
dacc_quantiles1[item_ct1.get_local_id(2)]; item_ct1.barrier(sycl::access::fence_space::local_space); @@ -1996,7 +2008,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c { g_val = g_vals[j]; g_val *= gnorm_scale; - s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; + s1_vals[j] = smem_quantiles1[m_c1[j]]*dacc_max1[0]; switch(OPTIMIZER) { case MOMENTUM: @@ -2023,14 +2035,14 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c item_ct1.barrier(sycl::access::fence_space::local_space); local_max_s1 = sycl::reduce_over_group(item_ct1.get_group(), local_max_s1, sycl::maximum<>()); - if(item_ct1.get_local_id(2) == 0){ atomicMax(&new_max1[0], local_max_s1); } + if(item_ct1.get_local_id(2) == 0){ atomicMax(&dacc_new_max1[0], local_max_s1); } if(unorm != NULL) { item_ct1.barrier(sycl::access::fence_space::local_space); local_unorm = sycl::reduce_over_group(item_ct1.get_group(), local_unorm, sycl::plus<>()); - if(item_ct1.get_local_id(2) == 0){ dpct::atomic_fetch_add(&unorm[0], local_unorm); } + if(item_ct1.get_local_id(2) == 0){ dpct::atomic_fetch_add(&dacc_unorm[0], local_unorm); } } } @@ -2046,9 +2058,11 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, const int n, - const sycl::nd_item<3> &item_ct1,float *smem_quantiles1, const sycl_la &tacc, + const sycl::nd_item<3> &item_ct1,float *smem_quantiles1, const sycl_la tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, - const sycl_dacc_uc &dacc_state1) + const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1) { const int n_full = (item_ct1.get_local_range(2) * item_ct1.get_group_range(2))*NUM_PER_THREAD2; @@ -2056,12 +2070,12 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, int valid_items = 0; float g_val = 0.0f; float s1_vals[NUM_PER_THREAD2]; - float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val1 = 1.0f/dacc_new_max1[0]; float update_scale = 1.0f; if(max_unorm > 0.0f) { - update_scale = max_unorm > 0.0f ? sycl::sqrt((float)(unorm[0])) : 1.0f; + update_scale = max_unorm > 0.0f ? 
sycl::sqrt((float)(dacc_unorm[0])) : 1.0f; if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } else{ update_scale = 1.0f; } } @@ -2087,7 +2101,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, if(item_ct1.get_local_id(2) < 256) - smem_quantiles1[item_ct1.get_local_id(2)] = quantiles1[item_ct1.get_local_id(2)]; + smem_quantiles1[item_ct1.get_local_id(2)] = dacc_quantiles1[item_ct1.get_local_id(2)]; item_ct1.barrier(sycl::access::fence_space::local_space); @@ -2149,7 +2163,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, } } - s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + s1_vals[j] = smem_quantiles1[c1s[j]]*dacc_max1[0]; switch(OPTIMIZER) { @@ -3835,6 +3849,7 @@ template inline void vector_load(T *loca #define WARPS 3 + template SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, T *smem_A, T *smem_B, const sycl::accessor &dacc_A, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out) { @@ -4034,7 +4049,7 @@ template SYCL_EXTERNAL void gemm_device(int //wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, a_frag, d_B, 16); + sycl::ext::oneapi::experimental::matrix::joint_matrix_load(sg_size, b_frag, d_B, 16); //wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); sycl::ext::oneapi::experimental::matrix::joint_matrix_mad(sg_size, c_frag, a_frag, b_frag, c_frag); @@ -4099,7 +4114,6 @@ template void printnonzero(sycl::half *A, int num_values, const char const float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; - template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *smem_A, unsigned char *smem_B, T *smem_C, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out) { @@ -4158,12 +4172,12 @@ const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::acce { if(loaded_values == 0) { - local_A[0] = A[idx]; - local_A[1] = A[idx+item_ct1.get_local_range(2)-32]; + local_A[0] = dacc_A[idx]; + local_A[1] = dacc_A[idx+item_ct1.get_local_range(2)-32]; #pragma unroll 32 for(int col = 0; col < 32; col++) - local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col] = dacc_B[(col_offset+col)*ldb+idx]; loaded_values = 1; } @@ -4224,14 +4238,14 @@ const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::acce { if(loaded_values == 0) { - local_A[0] = A[idx]; - local_A[1] = A[idx+item_ct1.get_local_range(2)-32]; + local_A[0] = dacc_A[idx]; + local_A[1] = dacc_A[idx+item_ct1.get_local_range(2)-32]; #pragma unroll 32 for(int col = 0; col < 32; col++) { - local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; - local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col] = dacc_B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = dacc_B[(col_offset+col)*ldb+idx]; } loaded_values = 1; @@ -4353,19 +4367,21 @@ const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::acce } + + 
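The hunks in this commit keep repeating one pattern: every raw device-pointer read (A[idx], B[(col_offset+col)*ldb+idx], absmax[absidx], datatype[i]) and every sycl::ext::oneapi::experimental::cuda::ldg() intrinsic is replaced by a read through a sycl::accessor (dacc_A, dacc_B, dacc_absmax, dacc_datatype) that the host side creates from a sycl::buffer and threads into the kernel as an extra parameter, which is what resolves the NVIDIA-backend ("nv") issue named in the commit subjects. Below is a minimal, self-contained sketch of that buffer/accessor plumbing, assuming a SYCL 2020 compiler; the kernel body, names, and sizes are illustrative only and are not taken from the patch:

#include <sycl/sycl.hpp>
#include <vector>

int main() {
    std::vector<float> code(256, 0.5f);   // illustrative lookup table
    std::vector<float> out(1024, 0.0f);   // illustrative output
    sycl::queue q;
    {
        // Host memory wrapped in buffers; in the kernel, accessor indexing
        // replaces raw-pointer dereferences and ldg()-style loads, which is
        // the portable access path across SYCL backends.
        sycl::buffer<float, 1> buff_code(code.data(), sycl::range<1>(code.size()));
        sycl::buffer<float, 1> buff_out(out.data(), sycl::range<1>(out.size()));
        q.submit([&](sycl::handler &cgh) {
            sycl::accessor dacc_code(buff_code, cgh, sycl::read_only);
            sycl::accessor dacc_out(buff_out, cgh, sycl::write_only);
            cgh.parallel_for(sycl::range<1>(out.size()), [=](sycl::id<1> i) {
                // dacc_code[...] stands in for the former ldg(&code[...]) load
                dacc_out[i] = dacc_code[i % 256] * 2.0f;
            });
        });
    } // buffer destruction waits for the kernel and copies results back to 'out'
    return 0;
}

The accessor captures in the parallel_for lambda are what carry the data dependency, so each kernel here gains one extra dacc_* parameter per array it touches, exactly as in the signature changes above.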
//=========================================4 bit gemm naive=============== #define num_values_4bit 32 -template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out) +template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_absmax, const sycl_dacc_float &dacc_datatype) { // per threadblock: // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1x32 * 32x4 -> 1x4 outputs per thread block - + // datatype absmax const int warp_idx = item_ct1.get_local_id(2) / 32; const int warp_lane = item_ct1.get_local_id(2) % 32; const int row_B = (THREADS/32)*item_ct1.get_group(2) + warp_idx; @@ -4379,7 +4395,7 @@ template SYCL_EXTERNAL void kgemm_4bit_infer T local_absmax = T(0.0f); for(int i = item_ct1.get_local_id(2); i < 16; i++) - quant_map[i] = T(datatype[i]); + quant_map[i] = T(dacc_datatype[i]); item_ct1.barrier(sycl::access::fence_space::local_space); // A: [1, K] @@ -4390,7 +4406,7 @@ template SYCL_EXTERNAL void kgemm_4bit_infer int offset_B = ldb*row_B; int absidx = ((2*offset_B)+inner_idx)/blocksize; - local_absmax = absmax[absidx]; + local_absmax = dacc_absmax[absidx]; if(row_B < M) { @@ -4492,82 +4508,162 @@ template SYCL_EXTERNAL void kfunc(float *A, float *B, float value, //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, 
sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A,
sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); + template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - sycl::half *smem_C); + sycl::half *smem_C, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out); + template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - sycl::half *smem_C); + sycl::half *smem_C, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out); + template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - sycl::half *smem_C); + sycl::half *smem_C, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out); + template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - sycl::half *smem_C); + sycl::half *smem_C, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, - sycl::half *quant_map); + sycl::half *quant_map, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out, + const sycl_dacc_float &dacc_absmax, + const sycl_dacc_float &dacc_datatype); + template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, sycl::ext::oneapi::bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, - const sycl::nd_item<3> &item_ct1, - sycl::ext::oneapi::bfloat16 *quant_map); + const sycl::nd_item<3> &item_ct1, + sycl::ext::oneapi::bfloat16 *quant_map, + const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out, + const sycl_dacc_float &dacc_absmax, + const sycl_dacc_float &dacc_datatype); + template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, - float *quant_map); + float *quant_map, + const 
sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, + const sycl::accessor &dacc_out, + const sycl_dacc_float &dacc_absmax, + const sycl_dacc_float &dacc_datatype); + template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, @@ -4629,10 +4725,8 @@ template void kgetColRowStats(sycl::half * __res template unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template unsigned char dQuantize<1>(float* smem_code, const float rand, float x); -template SYCL_EXTERNAL void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, - const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A); -template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, - const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A); +template SYCL_EXTERNAL void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); +template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ @@ -4697,7 +4791,8 @@ template SYCL_EXTERNAL void kPreconditionOptimizerStatic8bit1State float* max1, float* new_max1, \ const float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, \ + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1); \ MAKE_PreconditionStatic8bit1State(MOMENTUM, sycl::half) MAKE_PreconditionStatic8bit1State(MOMENTUM, float) @@ -4718,7 +4813,8 @@ template void kOptimizerStatic8bit1State(gtype* p, gtype* const g, const float gnorm_scale, \ const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, const sycl_la &tacc, \ const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, \ - const sycl_dacc_uc &dacc_state1); \ + const sycl_dacc_uc &dacc_state1, const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, \ + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1); \ MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half) MAKE_optimizerStatic8bit1State(MOMENTUM, float) @@ -4735,7 +4831,9 @@ template void kPreconditionOptimizerStatic8bit2State(gtype* p, gty float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float 
*smem_quantiles2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2,const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, \ + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, \ + const sycl_dacc_float &dacc_new_max2); \ MAKE_PreconditionStatic8bit2State(ADAM, sycl::half) MAKE_PreconditionStatic8bit2State(ADAM, float) @@ -4749,7 +4847,10 @@ template void kOptimizerStatic8bit2State(gtype* p, gtype* const g, float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, \ const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); \ + const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, const sycl_dacc_float &dacc_unorm, \ + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, \ + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, \ + const sycl_dacc_float &dacc_new_max2); \ MAKE_optimizerStatic8bit2State(ADAM, sycl::half) MAKE_optimizerStatic8bit2State(ADAM, float) diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index b5f53ece4..4f5a878c0 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -24,12 +24,12 @@ typedef sycl::accessor sycl_dacc_char; //=========================================================== //template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); -template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl::accessor &dacc_A); +template extern SYCL_EXTERNAL void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); extern SYCL_EXTERNAL void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n, const sycl::nd_item<3> &item_ct1, float* smem_code, const sycl_la &tacc, const sycl_dacc_float &dacc_A, const sycl_dacc_uc &dacc_out, const sycl_dacc_float &dacc_code); + extern SYCL_EXTERNAL void kDequantize(float *code, unsigned char *A, float *out, const int n, const sycl::nd_item<3> &item_ct1, float *smem_code); @@ -90,7 +90,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, - const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1); + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float 
&dacc_quantiles1, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1); template @@ -105,7 +107,9 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, - const sycl_dacc_uc &dacc_state11); + const sycl_dacc_uc &dacc_state1, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1); @@ -120,7 +124,10 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2, - const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); + const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, + const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, + const sycl_dacc_float &dacc_new_max2); template @@ -134,7 +141,10 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha float weight_decay, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, float *smem_quantiles2,const sycl_la &tacc, const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, - const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2); + const sycl_dacc_uc &dacc_state1, const sycl_dacc_uc &dacc_state2, const sycl_dacc_float &dacc_unorm, + const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_quantiles2, + const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_max2, const sycl_dacc_float &dacc_new_max1, + const sycl_dacc_float &dacc_new_max2); //====================8 bit blockwise========================= @@ -178,8 +188,11 @@ template extern SYCL_EX template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g); +//===============histogram======================== + extern SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n, - const sycl::nd_item<3> &item_ct1); + const sycl::nd_item<3> &item_ct1, const sycl_dacc_float &dacc_histogram, const sycl_dacc &dacc_index1, + const sycl_dacc &dacc_index2, const sycl_dacc_float &dacc_src); template extern SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, @@ -204,6 +217,7 @@ extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); +//===========================double row col quant=================== template extern SYCL_EXTERNAL void kDoubleRowColQuant(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, @@ -230,10 +244,11 @@ template extern SYCL_EXTERNAL void gemm_devi template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int 
N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *smem_A, unsigned char *smem_B, - T *smem_C, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out); + T *smem_C, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out, const sycl::accessor &dacc_A, + const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out); -template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out); +template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_absmax, const sycl_dacc_float &dacc_datatype); template extern SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, const sycl::nd_item<3> &item_ct1); diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 4a3323051..a8a44626e 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -55,22 +55,35 @@ void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *sr int threads = 512; int num_blocks = n/threads; num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + int size = NUM_BLOCK; + + + sycl::buffer buff_histogram(histogram,sycl::range<1>(size)); + sycl::buffer buff_index1(index1,sycl::range<1>(size)); + sycl::buffer buff_index2(index2,sycl::range<1>(size)); + sycl::buffer buff_src(src,sycl::range<1>(size)); + { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { + sycl::accessor dacc_histogram(buff_histogram, cgh, sycl::read_write); + sycl::accessor dacc_index1(buff_index1, cgh, sycl::read_write); + sycl::accessor dacc_index2(buff_index2, cgh, sycl::read_write); + sycl::accessor dacc_src(buff_src, cgh, sycl::read_write); + + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kHistogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n, item_ct1); + kHistogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n, item_ct1, dacc_histogram, dacc_index1, dacc_index2, dacc_src); }); }); } } - //============================estimate quantiles=============================== template void estimateQuantiles(T *A, float *code, float offset, int n) { @@ -78,22 +91,25 @@ template void estimateQuantiles(T *A, float *code, float offset, in sycl::queue &q_ct1 = dev_ct1.in_order_queue(); int num_blocks = n/4096; num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; + std::memset(code, 0, 256*sizeof(float)); //DPCT_CHECK_ERROR(q_ct1.memset(code, 0, 256*sizeof(float)).wait()); sycl::context ctx = q_ct1.get_context(); - int size = NUM_BLOCK; + int size = 512; sycl::buffer buff_A(A,sycl::range<1>(size)); - + sycl::buffer buff_code(code,sycl::range<1>(size)); { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { using group_load = dpct::group::workgroup_load>; - size_t temp_storage_size = group_radix_sort::get_local_memory_size(THREADS_ESTIMATE); + //using group_radix_sort = dpct::group::radix_sort; + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(sycl::range<1>(temp_storage_size), cgh); sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + sycl::accessor dacc_code(buff_code, cgh, sycl::read_write); auto std_numeric_limits_T_max_ct3 = std::numeric_limits::max(); @@ -101,7 +117,7 @@ template void estimateQuantiles(T *A, float *code, float offset, in cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kEstimateQuantiles(A, code, offset, std_numeric_limits_T_max_ct3, n, item_ct1, tacc, dacc_A); + kEstimateQuantiles(A, code, offset, std_numeric_limits_T_max_ct3, n, item_ct1, tacc, dacc_A, dacc_code); }); }); @@ -713,18 +729,29 @@ template void optimizerStatic8bit(T* p, T* g, sycl::buffer buff_p(p,sycl::range<1>(size)); sycl::buffer buff_state1(state1,sycl::range<1>(size)); sycl::buffer buff_state2(state2,sycl::range<1>(size)); + + sycl::buffer buff_quantiles1(quantiles1,sycl::range<1>(size)); + sycl::buffer buff_quantiles2(quantiles2,sycl::range<1>(size)); + sycl::buffer buff_max1(max1,sycl::range<1>(size)); + sycl::buffer buff_max2(max2,sycl::range<1>(size)); + sycl::buffer buff_new_max1(new_max1,sycl::range<1>(size)); + sycl::buffer buff_new_max2(new_max2,sycl::range<1>(size)); + sycl::buffer buff_unorm(unorm,sycl::range<1>(size)); + if(max_unorm > 0.0f){ - q_ct1.memset(unorm, 0, 1*sizeof(float)).wait(); } + std::memset(unorm, 0, 1*sizeof(float)); } switch(OPTIMIZER) { case ADAM: + std::memset(new_max1, 0, 1*sizeof(float)); + std::memset(new_max2, 0, 1*sizeof(float)); //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); //DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait()); { - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -738,6 +765,15 @@ template void optimizerStatic8bit(T* p, T* g, sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_quantiles2(buff_quantiles2, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_max2(buff_max2, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max2(buff_new_max2, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_quantiles2_acc_ct1(sycl::range<1>(256), cgh); @@ -746,13 +782,13 @@ template void optimizerStatic8bit(T* p, T* g, cgh.parallel_for( 
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit2State(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1, dacc_state2); + kPreconditionOptimizerStatic8bit2State(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1, dacc_state2, dacc_unorm, dacc_quantiles1, dacc_quantiles2, dacc_max1, dacc_max2, dacc_new_max1 , dacc_new_max2); }); }); } { - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -765,6 +801,14 @@ template void optimizerStatic8bit(T* p, T* g, sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); sycl::accessor dacc_state2(buff_state2, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_quantiles2(buff_quantiles2, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_max2(buff_max2, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max2(buff_new_max2, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); @@ -773,7 +817,7 @@ template void optimizerStatic8bit(T* p, T* g, cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit2State(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_state2); + kOptimizerStatic8bit2State(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), smem_quantiles2_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_state2, dacc_unorm, dacc_quantiles1, dacc_quantiles2, dacc_max1, dacc_max2, dacc_new_max1 , dacc_new_max2); }); }); } @@ -783,9 +827,10 @@ template void optimizerStatic8bit(T* p, T* g, case RMSPROP: case ADAGRAD: + std::memset(new_max1, 0, 1*sizeof(float)); //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); { - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -796,20 +841,25 @@ template void optimizerStatic8bit(T* p, T* g, sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, 
cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1); + kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1, dacc_unorm, dacc_quantiles1, dacc_max1, dacc_new_max1); }); }); } { - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -820,13 +870,19 @@ template void optimizerStatic8bit(T* p, T* g, sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1); + kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1,smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_unorm, dacc_quantiles1, dacc_max1, dacc_new_max1); }); }); } @@ -835,7 +891,7 @@ template void optimizerStatic8bit(T* p, T* g, case LION: { - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -845,7 +901,13 @@ template void optimizerStatic8bit(T* p, T* g, sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); sycl::accessor dacc_p(buff_p, cgh, sycl::read_write); - sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); @@ -853,30 +915,37 @@ template void optimizerStatic8bit(T* p, T* g, cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 1024), sycl::range<3>(1, 
1, 1024)), [=](sycl::nd_item<3> item_ct1) { - kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1); + kOptimizerStatic8bit1State(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_p, dacc_state1, dacc_unorm, dacc_quantiles1, dacc_max1, dacc_new_max1); }); }); } - - DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); + std::memset(new_max1, 0, 1*sizeof(float)); + //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); { - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { using group_load = dpct::group::workgroup_load>; - size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); + size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); sycl::accessor dacc_state1(buff_state1, cgh, sycl::read_write); + + sycl::accessor dacc_quantiles1(buff_quantiles1, cgh, sycl::read_write); + sycl::accessor dacc_max1(buff_max1, cgh, sycl::read_write); + sycl::accessor dacc_new_max1(buff_new_max1, cgh, sycl::read_write); + sycl::accessor dacc_unorm(buff_unorm, cgh, sycl::read_write); + + //__shared__ vars sycl::local_accessor smem_quantiles1_acc_ct1(sycl::range<1>(256), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1); + kPreconditionOptimizerStatic8bit1State(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n, item_ct1, smem_quantiles1_acc_ct1.get_pointer(), tacc, dacc_g, dacc_state1, dacc_unorm, dacc_quantiles1, dacc_max1, dacc_new_max1); }); }); } @@ -886,15 +955,12 @@ template void optimizerStatic8bit(T* p, T* g, break; } -} -catch (sycl::exception const &exc) { +}catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } - - //============================8 bit blockwise optimizer=============================== #define BLOCKSIZE_2STATE 2048 @@ -1285,6 +1351,18 @@ catch (sycl::exception const &exc) { std::exit(1); } +/* +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template 
int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +*/ + + + + //===========================gemm_host============================================ template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) @@ -1340,9 +1418,11 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - + + } + //============================gemm 4bit inference ================================ template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) @@ -1392,6 +1472,8 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi } } + + //============================gemm 4 bit inference naive ================= template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) @@ -1407,7 +1489,8 @@ template void gemm_4bit_inference_naive(int m, int n, int sycl::buffer buff_A (A, sycl::range<1>(size)); sycl::buffer buff_B (B, sycl::range<1>(size)); sycl::buffer buff_out (out, sycl::range<1>(size)); - + sycl::buffer buff_absmax(absmax, sycl::range<1>(size)); + sycl::buffer buff_datatype(datatype, sycl::range<1>(size)); { dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); @@ -1417,19 +1500,19 @@ template void gemm_4bit_inference_naive(int m, int n, int sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_B(buff_B, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); - + sycl::accessor dacc_absmax(buff_absmax, cgh, sycl::read_write); + sycl::accessor dacc_datatype(buff_datatype, cgh, sycl::read_write); sycl::local_accessor quant_map_acc_ct1(sycl::range<1>(16), cgh); cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - kgemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, item_ct1, quant_map_acc_ct1.get_pointer(), dacc_A, dacc_B, dacc_out); + kgemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, item_ct1, quant_map_acc_ct1.get_pointer(), dacc_A, dacc_B, dacc_out, dacc_absmax, dacc_datatype); }); }); } } - //================================spm coo================================== void spmm_coo(int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, sycl::half *B, int ldc, sycl::half* C, bool transposed_B) From b107b9cc7af1967fda9f16144981448bcc44095a Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 7 Jun 2024 05:50:39 -0700 Subject: [PATCH 60/66] refine --- csrc/sycl/ops.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index a8a44626e..578f3d8ce 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -94,7 +94,7 @@ template void estimateQuantiles(T *A, float *code, float offset, in std::memset(code, 0, 256*sizeof(float)); //DPCT_CHECK_ERROR(q_ct1.memset(code, 0, 256*sizeof(float)).wait()); sycl::context ctx = q_ct1.get_context(); - 
int size = 512; + int size = NUM_BLOCK; sycl::buffer buff_A(A,sycl::range<1>(size)); @@ -751,7 +751,7 @@ template void optimizerStatic8bit(T* p, T* g, //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); //DPCT_CHECK_ERROR(q_ct1.memset(new_max2, 0, 1*sizeof(float)).wait()); { - //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -788,7 +788,7 @@ template void optimizerStatic8bit(T* p, T* g, } { - //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -830,7 +830,7 @@ template void optimizerStatic8bit(T* p, T* g, std::memset(new_max1, 0, 1*sizeof(float)); //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); { - //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -859,7 +859,7 @@ template void optimizerStatic8bit(T* p, T* g, { - //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -891,7 +891,7 @@ template void optimizerStatic8bit(T* p, T* g, case LION: { - //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -922,7 +922,7 @@ template void optimizerStatic8bit(T* p, T* g, std::memset(new_max1, 0, 1*sizeof(float)); //DPCT_CHECK_ERROR(q_ct1.memset(new_max1, 0, 1*sizeof(float)).wait()); { - //dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { From 7838151df210d9a77cbeb7fb1f97d14c349423ab Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 7 Jun 2024 06:16:55 -0700 Subject: [PATCH 61/66] refine --- csrc/sycl/kernels.cpp | 98 +++++++++++++++++++++++++++++-------------- csrc/sycl/kernels.h | 26 ++++++++---- csrc/sycl/ops.cpp | 64 +++++++++++++++++++++------- 3 files changed, 132 insertions(+), 56 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index a0731c065..3c1b1faed 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -2225,15 +2225,14 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, //===============================k percentile clipping============================================ - template SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl::accessor &dacc_g) + const sycl::nd_item<3> &item_ct1,const sycl_la &tacc, const sycl::accessor &dacc_g, float *dacc_gnorm_vec) { const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 
0 : BLOCK_SIZE); int valid_items = 0; - using group_load = dpct::group::workgroup_load>; + using group_load = dpct_::group::workgroup_load>; auto *d_g = dacc_g.template get_multi_ptr().get(); T vals[NUM_VALS]; @@ -2270,10 +2269,10 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int // initialize with the same norm for all positions //#pragma unroll 10 for(int j = 0; j < 100; j++) - dpct::atomic_fetch_add(&gnorm_vec[j], local_sum); + dpct::atomic_fetch_add(&dacc_gnorm_vec[j], local_sum); } else - dpct::atomic_fetch_add(&gnorm_vec[step % 100], local_sum); + dpct::atomic_fetch_add(&dacc_gnorm_vec[step % 100], local_sum); } } @@ -3566,11 +3565,11 @@ const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out) //========================================k extract outliers====================== -template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out) +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx) { - int local_colidx = idx[item_ct1.get_group(2)]; + int local_colidx = dacc_idx[item_ct1.get_group(2)]; - if(FORMAT==COL_TURING) + if(FORMAT== COL_TURING) { // TURING FORMAT: // 8*32 tiles with 4*4 subtiles @@ -3622,7 +3621,7 @@ template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, cha int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; offset += tile_offset_cols + tile_offset_rows; - char val = A[offset]; + char val = dacc_A[offset]; int out_idx = (row*idx_size) + item_ct1.get_group(2); dacc_out[out_idx] = val; } @@ -3660,7 +3659,7 @@ template SYCL_EXTERNAL void kfunc(T *A, T *B, T value, lo template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats) + sycl::half *smem_dequant_stats, const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, const sycl_dacc &dacc_offset_rowidx, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats) { // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block @@ -3674,10 +3673,11 @@ SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int // 4. Do mma operations that accumulate into registers // 5. Each warp stores its output row into matrix C - const int count = max_count[item_ct1.get_group(2)]; - const int local_max_idx = max_idx[item_ct1.get_group(2)]; - const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; - const int local_row_idx = rowidx[offset]; + + const int count = dacc_max_count[item_ct1.get_group(2)]; + const int local_max_idx = dacc_max_idx[item_ct1.get_group(2)]; + const int offset = local_max_idx == 0 ? 
0 : dacc_offset_rowidx[local_max_idx-1]; + const int local_row_idx = dacc_rowidx[offset]; const int warp_id = item_ct1.get_local_id(2) / 32; const int warp_idx = item_ct1.get_local_id(2) % 32; @@ -3696,8 +3696,8 @@ SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int // 2. Load A into registers for(int j = 0; j < MAX_SPARSE_COUNT; j++) { - local_valA[j] = j < count ? values[offset+j] : sycl::vec(0.0f).convert()[0]; - local_colidxA[j] = j < count ? colidx[offset+j] : 0; + local_valA[j] = j < count ? dacc_values[offset+j] : sycl::vec(0.0f).convert()[0]; + local_colidxA[j] = j < count ? dacc_colidx[offset+j] : 0; } // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 @@ -3714,11 +3714,9 @@ SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int { for(int i = item_ct1.get_local_id(2); i < SMEM_SIZE; i+=item_ct1.get_local_range(2)) if((idx_col_B+i-local_idx_col_B_offset) < colsB) - smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + smem_dequant_stats[i] = dacc_dequant_stats[idx_col_B+i-local_idx_col_B_offset]; - /* - DPCT1065:204: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); } @@ -3750,7 +3748,7 @@ SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int #pragma unroll num_items for(int k = 0; k < num_items; k++) if(idx+k < colsB) - local_valsB[k] = B[row_offset+idx+k]; + local_valsB[k] = dacc_B[row_offset+idx+k]; else local_valsB[k] = 0.0f; } @@ -3797,7 +3795,7 @@ SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int #pragma unroll num_items for(int k = 0; k < num_items; k++) if(idx_col_C + k < colsB) - out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + dacc_out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; } } @@ -4667,22 +4665,57 @@ template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, in template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, 
int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); @@ -4691,9 +4724,9 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl::accessor &dacc_out, const sycl::accessor &dacc_bias); -template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); 
+template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx); -template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx); @@ -4856,9 +4889,10 @@ MAKE_optimizerStatic8bit2State(ADAM, sycl::half) MAKE_optimizerStatic8bit2State(ADAM, float) template SYCL_EXTERNAL void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, + float *dacc_gnorm_vec); template SYCL_EXTERNAL void kPercentileClipping(sycl::half * __restrict__ g, float *gnorm_vec, int step, const int n, - const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g); + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, float *dacc_gnorm_vec); #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 4f5a878c0..94bd31fa6 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -184,8 +184,8 @@ template extern SYCL_EX const sycl_dacc_float &dacc_quantiles1, const sycl_dacc_float &dacc_absmax1); - -template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g); +//=======================percentile clipping============================ +template extern SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n,const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_g, float *dacc_gnorm_vec); //===============histogram======================== @@ -194,6 +194,7 @@ extern SYCL_EXTERNAL void kHistogramScatterAdd2D(float* histogram, int *index1, const sycl::nd_item<3> &item_ct1, const sycl_dacc_float &dacc_histogram, const sycl_dacc &dacc_index1, const sycl_dacc &dacc_index2, const sycl_dacc_float &dacc_src); +//====================spm======================= template extern SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, @@ -201,8 +202,14 @@ extern SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_i float *__restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_dequant_stats); - + sycl::half *smem_dequant_stats, + const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, + const sycl_dacc &dacc_offset_rowidx, + const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); + +//=====================mm dequant 
==================================== template extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, @@ -213,6 +220,7 @@ extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl::accessor &dacc_out, const sycl::accessor &dacc_bias ); +//==================k row col stats===================== template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); @@ -234,20 +242,22 @@ template &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); - +//==============================k transfrom row col===================== template extern SYCL_EXTERNAL void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); -template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); +//========================k extract outliers========================= +template extern SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx); +//=========================gemm device============================ template extern SYCL_EXTERNAL void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, T *smem_A, T *smem_B, const sycl::accessor &dacc_A, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out); - +//=========================gemm 4 bit inf================================ template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *smem_A, unsigned char *smem_B, T *smem_C, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out); - +//====================gemm 4 bit naive inf============================ template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_absmax, const sycl_dacc_float &dacc_datatype); template extern SYCL_EXTERNAL void kfunc(T *A, T *B, T value, long n, diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 578f3d8ce..2b2c4f335 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1084,23 +1084,27 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, int size = NUM_BLOCK; sycl::buffer buff_g(g,sycl::range<1>(size)); - 
q_ct1.memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)).wait(); + std::memset(&gnorm_vec[step % 100], 0, 1*sizeof(float)); + sycl::buffer buff_gnorm_vec(gnorm_vec, sycl::range<1>(size)); { dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { - using group_load = dpct::group::workgroup_load>; + using group_load = dpct_::group::workgroup_load>; size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); sycl::accessor dacc_g(buff_g, cgh, sycl::read_write); - + sycl::accessor dacc_gnorm_vec(buff_gnorm_vec, cgh, sycl::read_write); + + //sycl::local_accessor dacc_gnorm_vec(sycl::range<1>(size), cgh); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { - kPercentileClipping(g, gnorm_vec, step, n, item_ct1, tacc, dacc_g); + kPercentileClipping(g, gnorm_vec, step, n, item_ct1, tacc, dacc_g, dacc_gnorm_vec.get_pointer()); }); }); } @@ -1288,6 +1292,7 @@ template void trans } + template void transform(int8_t *A, int8_t *out, int dim1, int dim2); template void transform( int8_t *A, int8_t *out, int dim1, int dim2); template void transform(int8_t *A, int8_t *out, int dim1, int dim2); @@ -1351,17 +1356,13 @@ catch (sycl::exception const &exc) { std::exit(1); } -/* + template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -*/ - - - //===========================gemm_host============================================ @@ -1571,24 +1572,51 @@ void spmm_coo(int *A_rowidx, int *A_colidx, sycl::half *A_vals, int A_nnz, int A template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, T *B, sycl::half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { + dpct::device_ext &dev_ct1 = dpct::get_current_device(); + sycl::queue &q_ct1 = dev_ct1.in_order_queue(); + sycl::context ctx = q_ct1.get_context(); + int size = NUM_BLOCK; + + sycl::buffer buff_max_count(max_count,sycl::range<1>(size)); + sycl::buffer buff_max_idx(max_idx,sycl::range<1>(size)); + sycl::buffer buff_offset_rowidx(offset_rowidx,sycl::range<1>(size)); + sycl::buffer buff_rowidx(rowidx,sycl::range<1>(size)); + sycl::buffer buff_colidx(colidx,sycl::range<1>(size)); + sycl::buffer buff_values(values,sycl::range<1>(size)); + sycl::buffer buff_out(out,sycl::range<1>(size)); + sycl::buffer buff_B(B, sycl::range<1>(size)); + sycl::buffer buff_dequant_stats(dequant_stats,sycl::range<1>(size)); + + { dpct::has_capability_or_fail(dpct::get_in_order_queue().get_device(), {sycl::aspect::fp16}); - dpct::get_in_order_queue().submit( + q_ct1.submit( [&](sycl::handler &cgh) { + sycl::accessor 
dacc_max_count(buff_max_count, cgh, sycl::read_write); + sycl::accessor dacc_max_idx(buff_max_idx, cgh, sycl::read_write); + sycl::accessor dacc_offset_rowidx(buff_offset_rowidx, cgh, sycl::read_write); + sycl::accessor dacc_colidx(buff_colidx, cgh, sycl::read_write); + sycl::accessor dacc_rowidx(buff_rowidx, cgh, sycl::read_write); + sycl::accessor dacc_values(buff_values, cgh, sycl::read_write); + sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_dequant_stats(buff_dequant_stats, cgh, sycl::read_write); + sycl::accessor dacc_B(buff_B, cgh, sycl::read_write); + + + //smem sycl::local_accessor smem_dequant_stats_acc_ct1(sycl::range<1>(2048/*SMEM_SIZE*/), cgh); - + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, nnz_rows) * sycl::range<3>(1, 1, 256), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - kspmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB, item_ct1, smem_dequant_stats_acc_ct1.get_pointer()); + kspmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB, item_ct1, smem_dequant_stats_acc_ct1.get_pointer(), dacc_max_count, dacc_max_idx, dacc_offset_rowidx, dacc_rowidx, dacc_colidx, dacc_values, dacc_B, dacc_out, dacc_dequant_stats); }); }); } } - //======================================non gemm 2d quants============================================ //===========================Row col stats================================= @@ -1853,7 +1881,7 @@ template void transformRowToFormat(char * A, char *o template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) { - int threads = 256; + int threads = 512; // we load 128 column values per warp int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); int tiledRows = 0; @@ -1875,17 +1903,21 @@ template void extractOutliers(char * A, int *idx, char *out, int id sycl::buffer buff_A(A,sycl::range<1>(size)); sycl::buffer buff_out(out,sycl::range<1>(size)); + sycl::buffer buff_idx(idx,sycl::range<1>(size)); dpct::get_in_order_queue().submit( [&](sycl::handler &cgh) { - sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); + + sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); sycl::accessor dacc_out(buff_out, cgh, sycl::read_write); + sycl::accessor dacc_idx(buff_idx, cgh, sycl::read_write); + cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, threads), sycl::range<3>(1, 1, threads)), [=](sycl::nd_item<3> item_ct1) { - kExtractOutliers(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols, item_ct1, dacc_A, dacc_out); + kExtractOutliers(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols, item_ct1, dacc_A, dacc_out, dacc_idx); }); }); From 46d10ba8dc779eeef17971350686a379d170c544 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Fri, 7 Jun 2024 06:32:01 -0700 Subject: [PATCH 62/66] refine --- csrc/sycl/kernels.cpp | 18 ++++++++-------- csrc/sycl/kernels.h | 2 +- csrc/sycl/ops.cpp | 50 ++++++++++++++++--------------------------- 3 files changed, 28 insertions(+), 42 deletions(-) diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index 3c1b1faed..b7e58afc8 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -2787,7 +2787,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char //==========================k get row col stats========================================== -template void kgetColRowStats(T * 
__restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A) +template void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row) { // 0. reset stats to -FLT_MAX // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) @@ -2803,7 +2803,7 @@ template>; using group_exchange = dpct::group::exchange; @@ -2934,22 +2934,22 @@ template(sycl::half *__restrict__ co -template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); -template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); +template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row); +template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row); template unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template unsigned char dQuantize<1>(float* smem_code, const float rand, float x); diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 94bd31fa6..67e90e7e4 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -223,7 +223,7 @@ extern SYCL_EXTERNAL void kdequant_mm_int32_fp16( //==================k row col stats===================== template extern SYCL_EXTERNAL void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, - const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A); + const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row); 
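// Editorial note on the convention the hunks above establish (this comment is
// not part of the patch): each kernel declaration gains one dacc_* accessor
// parameter per pointer it dereferences; for kgetColRowStats that is dacc_A
// for A, dacc_rowStats and dacc_colStats for the two stats outputs, and
// dacc_nnz_count_row for nnz_count_row. The raw pointer parameters are kept,
// so call sites stay source-compatible, while the kernel bodies are largely
// rewritten to access memory through the accessors instead (as the
// kPercentileClipping and kExtractOutliers hunks in patch 61 show).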
//===========================double row col quant=================== template buff_A(A,sycl::range<1>(size)); - - if(nnz_threshold == 0.0) - { - - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); + sycl::buffer buff_nnz_count_row(nnz_count_row,sycl::range<1>(size)); + sycl::buffer buff_rowStats(rowStats,sycl::range<1>(size)); + sycl::buffer buff_colStats(colStats,sycl::range<1>(size)); + dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); q_ct1.submit( [&](sycl::handler &cgh) { @@ -1656,7 +1655,9 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); - + sycl::accessor dacc_rowStats(buff_rowStats, cgh, sycl::read_write); + sycl::accessor dacc_colStats(buff_colStats, cgh, sycl::read_write); + sycl::accessor dacc_nnz_count_row(buff_nnz_count_row, cgh, sycl::read_write); //__shared__ vars sycl::local_accessor smem_row_absmax_values_acc_ct1(sycl::range<1>(256), cgh); sycl::local_accessor smem_row_nnz_values_acc_ct1(sycl::range<1>(256), cgh); @@ -1665,35 +1666,20 @@ void getColRowStats(sycl::half * A, float *rowStats, float *colStats, int *nnz_c cgh.parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), [=](sycl::nd_item<3> item_ct1) { + if(nnz_threshold == 0.0){ kgetColRowStats(A, rowStats, colStats, - nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), tacc, dacc_A); + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, + smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), tacc, + dacc_A, dacc_rowStats, dacc_colStats, dacc_nnz_count_row); + } + else if(nnz_threshold != 0.0){ + kgetColRowStats(A, rowStats, colStats, + nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, + smem_row_absmax_values_acc_ct1.get_pointer(),smem_row_nnz_values_acc_ct1.get_pointer(), + tacc, dacc_A, dacc_rowStats, dacc_colStats, dacc_nnz_count_row); + } }); }); - } - else if(nnz_threshold != 0.0) - { - dpct::has_capability_or_fail(q_ct1.get_device(), {sycl::aspect::fp16}); - q_ct1.submit( - [&](sycl::handler &cgh) { - - using group_load = dpct::group::workgroup_load>; - size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); - sycl::local_accessor tacc(temp_storage_size, cgh); - sycl::accessor dacc_A(buff_A, cgh, sycl::read_write); - - //__shared__ vars - sycl::local_accessor smem_row_absmax_values_acc_ct1(sycl::range<1>(256), cgh); - sycl::local_accessor smem_row_nnz_values_acc_ct1(sycl::range<1>(256), cgh); - - - cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 512), sycl::range<3>(1, 1, 512)), - [=](sycl::nd_item<3> item_ct1) { - kgetColRowStats(A, rowStats, colStats, - nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols,item_ct1, smem_row_absmax_values_acc_ct1.get_pointer(), smem_row_nnz_values_acc_ct1.get_pointer(), tacc, dacc_A); - }); - }); - } } From 8795844612cc7ec4f319472c66098c12928241eb Mon Sep 17 00:00:00 2001 From: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> Date: Fri, 7 Jun 2024 23:55:17 +0530 Subject: [PATCH 63/66] add dnn build flag --- CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 
deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b8eec42b7..c9d40e558 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,7 +199,8 @@ elseif(BUILD_SYCL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ferror-limit=80") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ldnnl") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -c") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") list(APPEND SRC_FILES ${SYCL_FILES}) From 7007a02bc321ddb4bfe75b404cda09bf5e2c154f Mon Sep 17 00:00:00 2001 From: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> Date: Thu, 27 Jun 2024 10:29:37 +0530 Subject: [PATCH 64/66] update dnn linkage --- CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c9d40e558..eaa65977e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,7 +199,6 @@ elseif(BUILD_SYCL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ldnnl") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -c") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") list(APPEND SRC_FILES ${SYCL_FILES}) @@ -248,7 +247,7 @@ if(BUILD_MPS) endif() if(BUILD_SYCL) - target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread onednn mkl_dnn) + target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_intel_ilp64 mkl_tbb_thread dnnl) endif() if(WIN32) set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") From 6f8ef480b62f1a44ddde21a18b512b76f6bc2900 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 7 Aug 2024 10:58:35 -0700 Subject: [PATCH 65/66] fix cmake --- CMakeLists.txt | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index eaa65977e..0213ffc2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,11 +24,12 @@ if(NOT CMAKE_BUILD_TYPE) endif() # Define included source files -set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) +set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp) set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) -set(SYCL_FILES csrc/sycl/ops.cpp csrc/sycl/kernels.cpp) +set(SYCL_FILES csrc/sycl/kernels.cpp csrc/sycl/ops.cpp csrc/pythonInterface.cpp) +#set(SYCL_FILES csrc/sycl/kernel_gemm.cpp csrc/sycl/op_gemm.cpp csrc/sycl/kernel_quant.cpp csrc/sycl/op_quant.cpp) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) @@ -198,8 +199,11 @@ elseif(BUILD_SYCL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -c") + if (SYCL_TARGET STREQUAL "INTEL") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=spir64 -L${MKLROOT}/lib") + elseif( SYCL_TARGET STREQUAL "NVIDIA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") + endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") list(APPEND SRC_FILES ${SYCL_FILES}) @@ -247,7 +251,11 @@ if(BUILD_MPS) endif() 
if(BUILD_SYCL) - target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_intel_ilp64 mkl_tbb_thread dnnl) + if (SYCL_TARGET STREQUAL "INTEL") + target_link_libraries(bitsandbytes PUBLIC OpenCL mkl_core pthread m dl mkl_intel_ilp64 mkl_tbb_thread dnnl) + elseif(SYCL_TARGET STREQUAL "NVIDIA") + target_link_libraries(bitsandbytes PUBLIC onemkl pthread m dl) + endif() endif() if(WIN32) set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") From a425d444092b892770630b88aa5edca81c8df1b4 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Wed, 7 Aug 2024 11:00:43 -0700 Subject: [PATCH 66/66] upgrade to 2024.2 Intel LLVM compiler release, set Nvidia build flag --- csrc/sycl/kernels.cpp | 240 ++++++++++++++++++++---------------------- csrc/sycl/kernels.h | 3 +- csrc/sycl/ops.cpp | 8 +- csrc/sycl/ops.h | 29 ----- csrc/sycl/utilities.h | 138 ++++++++++++++++++++++++ 5 files changed, 252 insertions(+), 166 deletions(-) create mode 100644 csrc/sycl/utilities.h diff --git a/csrc/sycl/kernels.cpp b/csrc/sycl/kernels.cpp index b7e58afc8..c2a4559d1 100644 --- a/csrc/sycl/kernels.cpp +++ b/csrc/sycl/kernels.cpp @@ -2,7 +2,7 @@ // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. - +#pragma once #include #include #include @@ -15,13 +15,16 @@ #define FLT_MAX std::numeric_limits::max() #define FLT_MIN std::numeric_limits::min() - +#include "utilities.h" #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 +#ifdef BLOCK_SIZE +#undef BLOCK_SIZE +#endif // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda float atomicMax(float* address, float val) { @@ -641,15 +644,15 @@ typedef sycl::accessor sycl_dacc_char; template SYCL_EXTERNAL -void kEstimateQuantiles(const T *A, float *code, const float offset, const T max_val, const int n, - const sycl::nd_item<3> &item_ct1, sycl_la tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code) +void kEstimateQuantiles(T*__restrict__ const A, float *code, const float offset, const T max_val, const int n, + const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code) { const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); int valid_items = (item_ct1.get_group(2)+1 == item_ct1.get_group_range(2)) ? n - (item_ct1.get_group(2)*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; const int base_idx = (item_ct1.get_group(2) * BLOCK_ESTIMATE); const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); - using group_load = dpct::group::workgroup_load>; + using group_load = dpct::group::workgroup_load>; //using group_radix_sort = dpct::group::radix_sort; T vals[NUM_ESTIMATE]; @@ -698,7 +701,8 @@ void kEstimateQuantiles(const T *A, float *code, const float offset, const T max // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest // 5. Repeat (3) 8 times for top 8 values in 256 // 6. 
store with byte index - group_radix_sort(tmp).sort_blocked_to_striped(item_ct1, vals); + // bypass sorting + //group_radix_sort(tmp).sort_blocked_to_striped(item_ct1, vals); item_ct1.barrier(sycl::access::fence_space::local_space); @@ -1339,7 +1343,7 @@ void kOptimizer32bit2State(T* g, T* p, template SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(T* g, T* p, - float* buff_state1, float *unorm, + float* state1, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n, const sycl::nd_item<3> &item_ct1,const sycl_la &tacc,const sycl::accessor &dacc_g,const sycl_dacc_float &dacc_state1, @@ -2232,7 +2236,7 @@ SYCL_EXTERNAL void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); int valid_items = 0; - using group_load = dpct_::group::workgroup_load>; + using group_load = dpct::group::workgroup_load>; auto *d_g = dacc_g.template get_multi_ptr().get(); T vals[NUM_VALS]; @@ -3013,7 +3017,7 @@ template void kdequant_mm_i float local_rowStats[ITEMS_PER_THREAD]; using group_load_int = dpct::group::workgroup_load>; - using group_exchange = exchange; + using group_exchange = dpct::group::exchange; auto *d_A = dacc_A.get_multi_ptr().get(); auto *tmp = tacc.get_multi_ptr().get(); @@ -4507,158 +4511,161 @@ template SYCL_EXTERNAL void kfunc(float *A, float *B, float value, template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, - sycl::half *smem_A, sycl::half *smem_B); + sycl::half *smem_A, sycl::half *smem_B, + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - 
const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out);; + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out);; template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void 
gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void gemm_device(int M, int N, int K, sycl::half * __restrict__ const A, sycl::half* B, sycl::half * out, int lda, int ldb, int ldc, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, sycl::half *smem_B, - const sycl::accessor &dacc_A, - const sycl::accessor &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_A, + const sycl::accessor &dacc_B, + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, - sycl::half *smem_B, + unsigned char *smem_B, sycl::half *smem_C, - const sycl::accessor &dacc_A, + const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, - sycl::half *smem_B, + unsigned char *smem_B, sycl::half *smem_C, - const sycl::accessor &dacc_A, + const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, - sycl::half *smem_B, + unsigned char *smem_B, sycl::half *smem_C, - const sycl::accessor &dacc_A, + const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, sycl::half *smem_A, - sycl::half *smem_B, + unsigned char *smem_B, sycl::half *smem_C, - const sycl::accessor &dacc_A, + const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, - const sycl::accessor &dacc_out); + const sycl::accessor &dacc_out); template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, sycl::half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, sycl::half * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, sycl::half *quant_map, - const sycl::accessor &dacc_A, + const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, - const sycl::accessor &dacc_out, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_absmax, const sycl_dacc_float &dacc_datatype); template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, sycl::ext::oneapi::bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> 
&item_ct1, sycl::ext::oneapi::bfloat16 *quant_map, - const sycl::accessor &dacc_A, + const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, - const sycl::accessor &dacc_out, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_absmax, const sycl_dacc_float &dacc_datatype); template SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, float *quant_map, - const sycl::accessor &dacc_A, + const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, - const sycl::accessor &dacc_out, + const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_absmax, const sycl_dacc_float &dacc_datatype); @@ -4669,7 +4676,7 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int * const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, const sycl_dacc &dacc_offset_rowidx, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, - const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, @@ -4678,7 +4685,7 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, const sycl_dacc &dacc_offset_rowidx, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, - const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, sycl::half *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, @@ -4687,7 +4694,7 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, const sycl_dacc &dacc_offset_rowidx, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, - const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, @@ -4696,7 +4703,7 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int * const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, const sycl_dacc &dacc_offset_rowidx, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, - const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half 
*values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, @@ -4705,7 +4712,7 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, const sycl_dacc &dacc_offset_rowidx, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, - const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, sycl::half *values, signed char *B, sycl::half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB, @@ -4714,7 +4721,7 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int const sycl_dacc &dacc_max_count, const sycl_dacc &dacc_max_idx, const sycl_dacc &dacc_offset_rowidx, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, - const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, + const sycl::accessor &dacc_values, const sycl::accessor &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_dequant_stats); @@ -4724,12 +4731,11 @@ template SYCL_EXTERNAL void kspmm_coo_very_sparse_naive(int template void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, sycl::half *out, float* newRowStats, float* newcolStats, sycl::half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n, const sycl::nd_item<3> &item_ct1, float *smem_rowStats, const sycl_la &tacc, const sycl_dacc &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl::accessor &dacc_out, const sycl::accessor &dacc_bias); -template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx); +template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx); template SYCL_EXTERNAL void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA, const sycl::nd_item<3> &item_ct1, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out, const sycl_dacc &dacc_idx); - template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); @@ -4743,14 +4749,9 @@ template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPER template SYCL_EXTERNAL void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char 
*out, int rows, int cols, int tiledCols, int outRows, int outCols, const sycl::nd_item<3> &item_ct1, char *smem_data, const sycl_dacc_char &dacc_A, const sycl_dacc_char &dacc_out); - - template void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); -template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normedconst sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, - const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); - - +template void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(sycl::half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, sycl::half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_stats, unsigned int *smem_nnz_row_idx, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_char &dacc_out_col_normed, const sycl_dacc_char &dacc_out_row_normed, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_rowidx, const sycl_dacc &dacc_colidx, const sycl::accessor &dacc_val, const sycl_dacc &dacc_nnz_block_ptr); template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row); template void kgetColRowStats(sycl::half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols, const sycl::nd_item<3> &item_ct1, float *smem_row_absmax_values, int *smem_row_nnz_values, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_rowStats, const sycl_dacc_float &dacc_colStats, const sycl_dacc &dacc_nnz_count_row); @@ -4758,8 +4759,8 @@ template void kgetColRowStats(sycl::half * __res template unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template unsigned char 
dQuantize<1>(float* smem_code, const float rand, float x); -template SYCL_EXTERNAL void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); -template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half *__restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); +template SYCL_EXTERNAL void kEstimateQuantiles(float* __restrict__ const A, float *code, const float offset, const float max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); +template SYCL_EXTERNAL void kEstimateQuantiles(sycl::half* __restrict__ const A, float *code, const float offset, const sycl::half max_val, const int n, const sycl::nd_item<3> &item_ct1, const sycl_la &tacc, const sycl::accessor &dacc_A, const sycl_dacc_float &dacc_code); #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template SYCL_EXTERNAL void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ @@ -4776,6 +4777,9 @@ MAKE_PreconditionOptimizer32bit1State(LION, float) MAKE_PreconditionOptimizer32bit1State(LION, sycl::ext::oneapi::bfloat16) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, sycl::half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) +MAKE_PreconditionOptimizer32bit1State(ADAM, sycl::half) +MAKE_PreconditionOptimizer32bit1State(ADAM, float) +MAKE_PreconditionOptimizer32bit1State(ADAM, sycl::ext::oneapi::bfloat16) #define MAKE_Optimizer32bit1State(oname, gtype) \ template SYCL_EXTERNAL void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ @@ -4834,27 +4838,34 @@ MAKE_PreconditionStatic8bit1State(RMSPROP, float) MAKE_PreconditionStatic8bit1State(LION, sycl::half) MAKE_PreconditionStatic8bit1State(LION, float) -#define MAKE_optimizerStatic8bit1State(oname, gtype) \ -template void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ - const float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, \ - const float beta2, \ +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ const float eps, const int step, const float lr, \ - float* __restrict__ const quantiles1, \ - float* max1, float* new_max1, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ float weight_decay, \ - const float gnorm_scale, \ - const int n, const sycl::nd_item<3> &item_ct1, float *smem_quantiles1, const sycl_la &tacc, \ - const sycl::accessor &dacc_g, const sycl::accessor &dacc_p, \ - const sycl_dacc_uc &dacc_state1, const sycl_dacc_float &dacc_unorm, const sycl_dacc_float &dacc_quantiles1, \ - const sycl_dacc_float &dacc_max1, const sycl_dacc_float &dacc_new_max1); \ + const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1,const sycl_la &tacc, \ + const sycl::accessor &dacc_g, \ + const sycl::accessor &dacc_p, \ + const sycl_dacc_uc &dacc_state1, \ + const sycl_dacc_float &dacc_quantiles1, \ + const 
sycl_dacc_float &dacc_absmax1); \ -MAKE_optimizerStatic8bit1State(MOMENTUM, sycl::half) -MAKE_optimizerStatic8bit1State(MOMENTUM, float) -MAKE_optimizerStatic8bit1State(RMSPROP, sycl::half) -MAKE_optimizerStatic8bit1State(RMSPROP, float) -MAKE_optimizerStatic8bit1State(LION, sycl::half) -MAKE_optimizerStatic8bit1State(LION, float) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::ext::oneapi::bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::ext::oneapi::bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAM, sycl::half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAM, sycl::ext::oneapi::bfloat16, 2048, 8) #define MAKE_PreconditionStatic8bit2State(oname, gtype) \ template void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ @@ -5002,30 +5013,3 @@ template void kOptimizerStatic8bit2StateBlockwise( \ - gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ - const float beta1, const float beta2, \ - const float eps, const int step, const float lr, \ - float* __restrict__ const quantiles1, \ - float* absmax1, \ - float weight_decay, \ - const float gnorm_scale, const bool skip_zeros, const int n, const sycl::nd_item<3> &item_ct1, sycl::local_accessor smem_quantiles1, float *smem_exchange1,const sycl_la &tacc, \ - const sycl::accessor &dacc_g, \ - const sycl::accessor &dacc_p, \ - const sycl_dacc_uc &dacc_state1, \ - const sycl_dacc_float &dacc_quantiles1, \ - const sycl_dacc_float &dacc_absmax1); \ - -MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) -MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, sycl::half, 2048, 8) -MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) -MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, sycl::half, 2048, 8) -MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) -MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::half, 2048, 8) -MAKE_OptimizerStatic8bit1StateBlockwise(LION, sycl::ext::oneapi::bfloat16, 2048, 8) -MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) -MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, sycl::half, 2048, 8) - diff --git a/csrc/sycl/kernels.h b/csrc/sycl/kernels.h index 67e90e7e4..33fe46354 100644 --- a/csrc/sycl/kernels.h +++ b/csrc/sycl/kernels.h @@ -254,8 +254,7 @@ template extern SYCL_EXTERNAL void gemm_devi //=========================gemm 4 bit inf================================ template extern SYCL_EXTERNAL void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *smem_A, unsigned char *smem_B, - T *smem_C, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out, const sycl::accessor &dacc_A, - const sycl_dacc_uc &dacc_B, const sycl::accessor 
&dacc_out); + T *smem_C, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out); //====================gemm 4 bit naive inf============================ template extern SYCL_EXTERNAL void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, const sycl::nd_item<3> &item_ct1, T *quant_map, const sycl::accessor &dacc_A, const sycl_dacc_uc &dacc_B, const sycl::accessor &dacc_out, const sycl_dacc_float &dacc_absmax, const sycl_dacc_float &dacc_datatype); diff --git a/csrc/sycl/ops.cpp b/csrc/sycl/ops.cpp index 8fad13302..d345224d0 100644 --- a/csrc/sycl/ops.cpp +++ b/csrc/sycl/ops.cpp @@ -1092,7 +1092,7 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, q_ct1.submit( [&](sycl::handler &cgh) { - using group_load = dpct_::group::workgroup_load>; + using group_load = dpct::group::workgroup_load>; size_t temp_storage_size = group_load::get_local_memory_size(THREADS_ESTIMATE); sycl::local_accessor tacc(temp_storage_size, cgh); @@ -1357,12 +1357,6 @@ catch (sycl::exception const &exc) { } -template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt( int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); //===========================gemm_host============================================ diff --git a/csrc/sycl/ops.h b/csrc/sycl/ops.h index 3071b8456..19887f028 100644 --- a/csrc/sycl/ops.h +++ b/csrc/sycl/ops.h @@ -27,34 +27,6 @@ #define THREADS_PER_BLOCKS (512) -inline void checkCudaStatus(int status) { - /* - DPCT1000:93: Error handling if-stmt was detected but could not be rewritten. - */ - if (status != 0) { - /* - DPCT1009:94: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - printf( - "cuda API failed with status %d: %s\n", status, - "cudaGetErrorString is not supported" /*cudaGetErrorString(status)*/); - /* - DPCT1001:92: The statement could not be removed. 
- */ - throw std::logic_error("cuda API failed"); - } -} - -inline int checkCublasStatus(int status) { - if (status != 0) { - printf("cuBLAS API failed with status %d\n", status); - //throw std::logic_error("cuBLAS API failed"); - return 1; - } - return 0; -} typedef enum Operations_t { @@ -136,7 +108,6 @@ class ContextCusparse }; - template void estimateQuantiles(T *A, float *code, float offset, int n); void quantize(float *code, float *A, unsigned char *out, int n); diff --git a/csrc/sycl/utilities.h b/csrc/sycl/utilities.h new file mode 100644 index 000000000..0e8bcd0d6 --- /dev/null +++ b/csrc/sycl/utilities.h @@ -0,0 +1,138 @@ + ++#include +#include +#include +#include +#include +#include "ops.h" +#include +#include +#include +#include +#include + +#include "oneapi/dnnl/dnnl.hpp" + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 +#define NUM_PER_THREAD 4 + +typedef sycl::ext::oneapi::bfloat16 bf16; + +using std::cout; +using std::endl; + + + +namespace dpct{ +namespace group{ +enum store_algorithm { + + BLOCK_STORE_DIRECT, + BLOCK_STORE_STRIPED, + // To-do: BLOCK_STORE_WARP_TRANSPOSE + // To-do: BLOCK_STORE_VECTORIZE + +}; + +/// Stores each work item's linear segment of items in a blocked arrangement. +template +__dpct_inline__ void store_blocked(const Item &item, OutputIteratorT block_itr, + InputT (&items)[ITEMS_PER_WORK_ITEM]) { + + // This implementation does not take into account range storage across + // workgroup items. To-do: decide whether range storage is required for group + // storage. + size_t linear_tid = item.get_local_linear_id(); + OutputIteratorT workitem_itr = block_itr + (linear_tid * ITEMS_PER_WORK_ITEM); +#pragma unroll + for (uint32_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) { + workitem_itr[idx] = items[idx]; + } +} + +/// Stores each work item's linear segment of items in a striped arrangement. +template +__dpct_inline__ void store_striped(const Item &item, OutputIteratorT block_itr, + InputT (&items)[ITEMS_PER_WORK_ITEM]) { + + // This implementation does not take into account range storage across + // workgroup items. To-do: decide whether range storage is required for group + // storage. + size_t linear_tid = item.get_local_linear_id(); + OutputIteratorT workitem_itr = block_itr + linear_tid; + size_t GROUP_WORK_ITEMS = item.get_global_range().size(); +#pragma unroll + for (uint32_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) { + workitem_itr[(idx * GROUP_WORK_ITEMS)] = items[idx]; + } +} + +/// Stores each work item's linear segment of items in a warp-striped arrangement. +// Created as free function until exchange mechanism is +// implemented. +// To-do: inline this function with BLOCK_STORE_WARP_TRANSPOSE mechanism +template +__dpct_inline__ void +store_subgroup_striped(const Item &item, OutputIteratorT block_itr, + InputT (&items)[ITEMS_PER_WORK_ITEM]) { + + // This implementation does not take into account range storage across + // workgroup items. To-do: decide whether range storage is required for group + // storage. + // This implementation uses uninitialized memory when storing linear segments + // into a warp-striped arrangement.
+ uint32_t subgroup_offset = item.get_sub_group().get_local_linear_id(); + uint32_t subgroup_size = item.get_sub_group().get_local_linear_range(); + uint32_t subgroup_idx = item.get_sub_group().get_group_linear_id(); + uint32_t initial_offset = + (subgroup_idx * ITEMS_PER_WORK_ITEM * subgroup_size) + subgroup_offset; + OutputIteratorT workitem_itr = block_itr + initial_offset; +#pragma unroll + for (uint32_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) { + workitem_itr[(idx * subgroup_size)] = items[idx]; + } +} + +// template parameters: +// ITEMS_PER_WORK_ITEM: size_t variable controlling the number of items per +// thread/work_item +// ALGORITHM: store_algorithm variable controlling the type of store operation. +// InputT: type for input sequence. +// OutputIteratorT: output iterator type +// Item: typename parameter resembling sycl::nd_item<3>. +template +class workgroup_store {
+public: + static size_t get_local_memory_size(size_t group_work_items) { return 0; } + workgroup_store(uint8_t *local_memory) : _local_memory(local_memory) {} + + __dpct_inline__ void store(const Item &item, OutputIteratorT block_itr, + InputT (&items)[ITEMS_PER_WORK_ITEM]) { + + // Pass the array itself; indexing past it with (&items)[ITEMS_PER_WORK_ITEM] + // would bind a reference one-past-the-end of the array. + if constexpr (ALGORITHM == BLOCK_STORE_DIRECT) { + store_blocked(item, block_itr, items); + } else if constexpr (ALGORITHM == BLOCK_STORE_STRIPED) { + store_striped(item, block_itr, items); + } + } + +private: + uint8_t *_local_memory; + +}; +} +} + + +
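Note on the SYCL_TARGET switch introduced in the CMake patches above: assuming BUILD_SYCL and SYCL_TARGET are passed as ordinary CMake cache variables, a configure line such as `cmake -DBUILD_SYCL=ON -DSYCL_TARGET=NVIDIA ..` would select the nvptx64-nvidia-cuda offload triple and the onemkl link line, while `-DSYCL_TARGET=INTEL` keeps the spir64 target with the MKL/oneDNN libraries. The exact invocation is illustrative, not taken from the patch.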
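The utilities.h addition above ports CUB-style block stores to SYCL. What follows is a minimal usage sketch, not part of the patch: the template parameter order (ITEMS_PER_WORK_ITEM, ALGORITHM, InputT, OutputIteratorT, Item) is inferred from the doc comment in utilities.h, and the queue, buffer, and size names are illustrative.

// Hedged sketch: stores each work item's 4-element segment into a shared
// output buffer with BLOCK_STORE_DIRECT. Assumes utilities.h is on the
// include path and that workgroup_store's template parameters follow the
// order given in its doc comment.
#include <sycl/sycl.hpp>
#include <vector>
#include "utilities.h"

int main() {
    constexpr size_t ITEMS = 4;   // items per work item
    constexpr size_t WG = 256;    // work-group size
    std::vector<float> out_host(WG * ITEMS, 0.0f);
    sycl::queue q;
    {
        sycl::buffer<float, 1> buff_out(out_host.data(), sycl::range<1>(WG * ITEMS));
        q.submit([&](sycl::handler &cgh) {
            sycl::accessor dacc_out(buff_out, cgh, sycl::read_write);
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, WG), sycl::range<3>(1, 1, WG)),
                [=](sycl::nd_item<3> item_ct1) {
                    float vals[ITEMS];
                    for (size_t i = 0; i < ITEMS; i++)
                        vals[i] = static_cast<float>(i);  // per-work-item payload
                    float *out = dacc_out.get_multi_ptr<sycl::access::decorated::no>().get();
                    // BLOCK_STORE_DIRECT: work item t writes the contiguous
                    // range [t*ITEMS, (t+1)*ITEMS) of the output.
                    using group_store = dpct::group::workgroup_store<
                        ITEMS, dpct::group::BLOCK_STORE_DIRECT, float, float *,
                        sycl::nd_item<3>>;
                    // get_local_memory_size() returns 0 for these algorithms,
                    // so no local accessor is needed.
                    group_store(nullptr).store(item_ct1, out, vals);
                });
        });
    }  // buffer destruction copies results back into out_host
    return 0;
}

Switching the algorithm to BLOCK_STORE_STRIPED would instead make work item t write element idx * group_size + t for each idx, mirroring store_striped above.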