diff --git a/ggml-cann.cpp b/ggml-cann.cpp
index 2397c6f2de088..e9709f1a0d15a 100644
--- a/ggml-cann.cpp
+++ b/ggml-cann.cpp
@@ -122,7 +122,6 @@ GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
                                                         void* dst) {
     GGML_ASSERT(tensor->op == GGML_OP_NONE);
 
-    size_t n_bytes = ggml_nbytes(tensor);
     int64_t n_elems = ggml_nelements(tensor);
     int64_t groups = n_elems / QK4_0;
     size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
@@ -131,7 +130,7 @@ GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
     uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
 
     for (int i = 0; i < groups; i++) {
-        block_q4_0* group = (block_q4_0*)((char*)src + i * sizeof(block_q4_0));
+        const block_q4_0* group = (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
         *scale_offset = group->d;
         scale_offset++;
 
@@ -161,7 +160,6 @@ GGML_CALL static void ggml_backend_cann_transform_back_q4_0(
     const ggml_tensor* tensor, void* src, void* dst) {
     GGML_ASSERT(tensor->op == GGML_OP_NONE);
 
-    size_t n_bytes = ggml_nbytes(tensor);
     int64_t n_elems = ggml_nelements(tensor);
     int64_t groups = n_elems / QK4_0;
     size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
@@ -198,7 +196,6 @@ GGML_CALL static void ggml_backend_cann_transform_back_q4_0(
 GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
                                                        const void* src,
                                                        void* dst) {
-    size_t n_bytes = ggml_nbytes(tensor);
     int64_t n_elems = ggml_nelements(tensor);
     int64_t groups = n_elems / QK8_0;
     size_t quant_bytes = n_elems * sizeof(uint8_t);
@@ -207,7 +204,7 @@ GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
     uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
 
     for (int i = 0; i < groups; i++) {
-        block_q8_0* group = (block_q8_0*)((char*)src + i * sizeof(block_q8_0));
+        const block_q8_0* group = (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
         *scale_offset = group->d;
         scale_offset++;
         size_t group_quant_size = QK8_0 * sizeof(uint8_t);
@@ -218,13 +215,12 @@ GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
 GGML_CALL static void ggml_backend_cann_transform_back_q8_0(
     const ggml_tensor* tensor, const void* src, void* dst) {
-    size_t n_bytes = ggml_nbytes(tensor);
     int64_t n_elems = ggml_nelements(tensor);
     int64_t groups = n_elems / QK8_0;
     size_t quant_bytes = n_elems * sizeof(uint8_t);
 
-    uint8_t* quant_offset = (uint8_t*)src;
-    uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
+    const uint8_t* quant_offset = (const uint8_t*)src;
+    const uint16_t* scale_offset = (const uint16_t*)((const char*)src + quant_bytes);
 
     for (int i = 0; i < groups; i++) {
         block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
@@ -292,13 +288,10 @@ GGML_CALL static void ggml_backend_cann_buffer_init_tensor(
     ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
     if (tensor->view_src != NULL && tensor->view_offs == 0) {
         GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
         set_tensor_extra(buffer, tensor);
         return;
     }
 
-    tensor->backend = GGML_BACKEND_TYPE_GPU;
-
     // TODO: can backend doesn't support quantized yet. Just leave the code
     // here.
     if (ggml_is_quantized(tensor->type)) {
@@ -320,7 +313,6 @@ GGML_CALL static void ggml_backend_cann_buffer_init_tensor(
 GGML_CALL static void ggml_backend_cann_buffer_set_tensor(
     ggml_backend_buffer_t buffer, ggml_tensor* tensor, const void* data,
     size_t offset, size_t size) {
-    GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
     GGML_ASSERT(size == ggml_nbytes(tensor));
     ggml_backend_cann_buffer_context* ctx =
         (ggml_backend_cann_buffer_context*)buffer->context;
@@ -331,18 +323,18 @@ GGML_CALL static void ggml_backend_cann_buffer_set_tensor(
     // Why aclrtSynchronizeDevice?
 
     if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpy(tensor->data, size, (char*)data + offset, size,
+        ACL_CHECK(aclrtMemcpy(tensor->data, size, (const char*)data + offset, size,
                               ACL_MEMCPY_HOST_TO_DEVICE));
     } else {
         void* transform_buffer = malloc(size);
-        ggml_backend_cann_transform(tensor, (char*)data + offset,
+        ggml_backend_cann_transform(tensor, (const char*)data + offset,
                                     transform_buffer);
 
 #ifndef NDEBUG
         void* check_buffer = malloc(size);
         ggml_backend_cann_transform_back(tensor, transform_buffer,
                                          check_buffer);
-        GGML_ASSERT(memcmp((char*)data + offset, check_buffer, size) == 0);
+        GGML_ASSERT(memcmp((const char*)data + offset, check_buffer, size) == 0);
         free(check_buffer);
 #endif
         ACL_CHECK(aclrtMemcpy(tensor->data, size, transform_buffer, size,
@@ -355,7 +347,6 @@ GGML_CALL static void ggml_backend_cann_buffer_get_tensor(
     ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
     size_t offset, size_t size) {
     GGML_ASSERT(size == ggml_nbytes(tensor));
-    GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
     ggml_backend_cann_buffer_context* ctx =
         (ggml_backend_cann_buffer_context*)buffer->context;
 
@@ -377,8 +368,6 @@ GGML_CALL static bool ggml_backend_cann_buffer_cpy_tensor(
     ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
     if (ggml_backend_buffer_is_cann(src->buffer)) {
-        GGML_ASSERT(src->backend == GGML_BACKEND_TYPE_GPU);
-        GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
         ggml_backend_cann_buffer_context* src_ctx =
             (ggml_backend_cann_buffer_context*)src->buffer->context;
         ggml_backend_cann_buffer_context* dst_ctx =
@@ -505,27 +494,12 @@ GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alloc_size(
 
     GGML_UNUSED(buft);
 }
 
-GGML_CALL static bool ggml_backend_cann_buffer_type_supports_backend(
-    ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    if (!ggml_backend_is_cann(backend)) {
-        return false;
-    }
-
-    ggml_backend_cann_buffer_type_context* buft_ctx =
-        (ggml_backend_cann_buffer_type_context*)buft->context;
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-
-    return buft_ctx->device == cann_ctx->device;
-}
-
 static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
     /* .get_name = */ ggml_backend_cann_buffer_type_name,
     /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
     /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
     /* .get_max_size = */ NULL,  // defaults to SIZE_MAX
     /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
-    /* .supports_backend = */ ggml_backend_cann_buffer_type_supports_backend,
     /* .is_host = */ NULL,
 };
 
@@ -680,9 +654,6 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_ROPE:
             ggml_cann_rope(ctx, dst);
             break;
-        case GGML_OP_ALIBI:
-            ggml_cann_alibi(ctx, dst);
-            break;
         case GGML_OP_IM2COL:
             ggml_cann_im2col(ctx, dst);
             break;
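Side note on the const-correctness changes in the transform helpers near the top of this file: ggml_backend_cann_transform_q8_0 (and the q4_0 variant) repack an array of quant blocks into one flat buffer, with all quantized bytes first and the 16-bit scales appended after them, which is why the code walks two separate write cursors. The sketch below reproduces that split layout with a simplified stand-in block type; `block_example`, `QK` and `transform_blocks` are illustrative names only, not the real ggml structures.

```cpp
// Hypothetical, simplified sketch of the "quants first, scales after" layout.
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int QK = 32;                 // elements per block (assumption)

struct block_example {                 // stand-in for a ggml quant block
    uint16_t d;                        // scale, stored as fp16 bits
    uint8_t  qs[QK];                   // quantized values
};

// dst layout: [n_blocks * QK quant bytes][n_blocks uint16_t scales]
void transform_blocks(const block_example* src, size_t n_blocks, void* dst) {
    uint8_t*  quant_offset = (uint8_t*)dst;
    uint16_t* scale_offset = (uint16_t*)((uint8_t*)dst + n_blocks * QK);
    for (size_t i = 0; i < n_blocks; i++) {
        *scale_offset++ = src[i].d;                 // append scale
        std::memcpy(quant_offset, src[i].qs, QK);   // append quant bytes
        quant_offset += QK;
    }
}

int main() {
    std::vector<block_example> blocks(4);
    std::vector<uint8_t> packed(4 * QK + 4 * sizeof(uint16_t));
    transform_blocks(blocks.data(), blocks.size(), packed.data());
    return 0;
}
```

The transform_back path in the patch is simply the inverse walk over the same two regions, which is what the NDEBUG round-trip check in buffer_set_tensor relies on.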
@@ -695,6 +666,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_ARGSORT:
             ggml_cann_argsort(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            return false;
         default:
             return false;
     }
@@ -733,7 +706,6 @@ GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
                                                           const void* data,
                                                           size_t offset,
                                                           size_t size) {
-    GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
     ggml_backend_buffer_t buf =
@@ -766,7 +738,6 @@ GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
 GGML_CALL static void ggml_backend_cann_get_tensor_async(
     ggml_backend_t backend, const ggml_tensor* tensor, void* data,
     size_t offset, size_t size) {
-    GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
     ggml_backend_buffer_t buf =
@@ -802,9 +773,6 @@ GGML_CALL static bool ggml_backend_cann_cpy_tensor_async(
         return false;
     }
 
-    GGML_ASSERT(src->backend == GGML_BACKEND_TYPE_GPU);
-    GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
-
     ggml_backend_buffer_t buf_src =
         src->view_src ? src->view_src->buffer : src->buffer;
     ggml_backend_buffer_t buf_dst =
@@ -975,7 +943,6 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
-        case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
@@ -988,6 +955,8 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+            return false;
         default:
             return false;
     }
@@ -995,12 +964,18 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
 
     GGML_UNUSED(backend);
 }
 
-GGML_CALL static bool ggml_backend_cann_offload_op(ggml_backend_t backend,
-                                                   const ggml_tensor* op) {
-    const int min_batch_size = 32;
-    GGML_UNUSED(backend);
+static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
+}
 
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+GGML_CALL static bool ggml_backend_cann_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    if (ggml_backend_buft_is_cann(buft)) {
+        ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *)backend->context;
+        ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *)buft->context;
+        return buft_ctx->device == cann_ctx->device;
+    }
+
+    return false;
 }
 
 static ggml_backend_event_t ggml_backend_cann_event_new(
@@ -1059,9 +1034,11 @@ static ggml_backend_i ggml_backend_cann_interface = {
     /* .synchronize = */ ggml_backend_cann_synchronize,
     /* .graph_plan_create = */ NULL,
     /* .graph_plan_free = */ NULL,
+    /* .graph_plan_update = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute = */ ggml_backend_cann_graph_compute,
     /* .supports_op = */ ggml_backend_cann_supports_op,
+    /* .supports_buft = */ ggml_backend_cann_supports_buft,
     /* .offload_op = */ NULL,
     /* .event_new = */ ggml_backend_cann_event_new,
     /* .event_free = */ ggml_backend_cann_event_free,
diff --git a/ggml-cann/acl_tensor.cpp b/ggml-cann/acl_tensor.cpp
index e09e2557f3b16..946b721a1f397 100644
--- a/ggml-cann/acl_tensor.cpp
+++ b/ggml-cann/acl_tensor.cpp
@@ -24,7 +24,7 @@ aclDataType type_mapping(ggml_type type) {
     return ACL_DT_UNDEFINED;
 }
 
-bool nb3_is_valid(const ggml_tensor* tensor) {
+static bool nb3_is_valid(const ggml_tensor* tensor) {
     // check tensor->nb[3] is contiguous by ne.
     if (tensor->nb[3] == tensor->ne[0] * tensor->ne[1] * tensor->ne[2] *
                              ggml_element_size(tensor)) {
@@ -45,19 +45,6 @@ bool nb3_is_valid(const ggml_tensor* tensor) {
 aclTensor* create_acl_tensor(const ggml_tensor* tensor, int64_t* bcast_ne,
                              size_t* bcast_nb, int64_t bcast_dims,
                              aclFormat format, size_t offset) {
-    size_t size = ggml_nbytes(tensor);
-    void* deviceAddr = nullptr;
-
-    if (tensor->backend == GGML_BACKEND_TYPE_GPU) {
-        deviceAddr = tensor->data;
-    } else {
-        // TODO: Consider quantification.
-        GGML_ASSERT(!ggml_is_quantized(tensor->type));
-        ACL_CHECK(aclrtMalloc(&deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMemcpy(deviceAddr, size, tensor->data, size,
-                              ACL_MEMCPY_HOST_TO_DEVICE));
-    }
-
     // If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
     // added.
     int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
@@ -114,7 +101,7 @@ aclTensor* create_acl_tensor(const ggml_tensor* tensor, int64_t* bcast_ne,
     aclTensor* acl_tensor =
         aclCreateTensor(acl_ne, dims, type_mapping(tensor->type), acl_stride,
                         offset / ggml_element_size(tensor), format,
-                        acl_storage_ne, dims, deviceAddr);
+                        acl_storage_ne, dims, tensor->data);
 
     return acl_tensor;
 }
diff --git a/ggml-cann/aclnn_ops.cpp b/ggml-cann/aclnn_ops.cpp
index 00c91378f7a5e..15d221fdbc078 100644
--- a/ggml-cann/aclnn_ops.cpp
+++ b/ggml-cann/aclnn_ops.cpp
@@ -32,7 +32,7 @@
 
 #include "kernels/ascendc_kernels.h"
 
-void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                   aclTensor* acl_dst, int64_t* repeat_array,
                   ggml_tensor* bind_tensor) {
     // repeat tensor along each dim with repeat_array
@@ -74,7 +74,7 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
-void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
+static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
                aclTensor* acl_src1, aclTensor* acl_dst,
                ggml_tensor* bind_tensor) {
     // add: dst = acl_src0 + alpha*acl_src1
@@ -160,7 +160,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
-void aclnn_concat(ggml_backend_cann_context& ctx, aclTensorList* tensorList,
+static void aclnn_concat(ggml_backend_cann_context& ctx, aclTensorList* tensorList,
                   aclTensor* acl_dst, int64_t concat_dim,
                   ggml_tensor* bind_tensor) {
     uint64_t workspaceSize = 0;
@@ -194,7 +194,7 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
-void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
+static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
                   float start, float stop, float step, int64_t n_elements,
                   ggml_tensor* bind_tensor) {
     // arange: [start, stop), out(i+1) = out(i) + step.
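Side note on the supports_buft hunk above: the device check that used to live in the buffer type's supports_backend callback now lives in the backend itself, which first asks whether the buffer type is a CANN one and then compares device ids. A minimal stand-alone sketch of that decision, using simplified stand-in contexts rather than the real ggml types:

```cpp
// Minimal sketch (simplified, not the real ggml API) of a backend deciding
// whether it can use a given buffer type.
#include <cassert>

struct buffer_type_ctx { int device; bool is_cann; };  // assumed stand-in
struct backend_ctx     { int device; };                // assumed stand-in

static bool buft_is_cann(const buffer_type_ctx& buft) { return buft.is_cann; }

static bool backend_supports_buft(const backend_ctx& backend,
                                  const buffer_type_ctx& buft) {
    // Only same-device CANN buffer types are usable; everything else
    // (including host buffers in this simplified model) is rejected.
    return buft_is_cann(buft) && buft.device == backend.device;
}

int main() {
    backend_ctx     backend{0};
    buffer_type_ctx same_dev{0, true}, other_dev{1, true}, host{0, false};
    assert( backend_supports_buft(backend, same_dev));
    assert(!backend_supports_buft(backend, other_dev));
    assert(!backend_supports_buft(backend, host));
    return 0;
}
```

Non-CANN and foreign-device buffer types fall through to `return false`, which matches the shape of the new ggml_backend_cann_supports_buft in the hunk.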
@@ -552,7 +552,7 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
-void aclnn_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+static void aclnn_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                aclTensor* acl_src, aclTensor* acl_dst,
                int64_t* paddings, float value = 0.0f) {
     aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2);
@@ -623,7 +623,6 @@ void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     // params
     const int32_t* opts = (const int32_t*)dst->op_params;
-    enum ggml_op_pool op = static_cast<enum ggml_op_pool>(opts[0]);
     const int k0 = opts[1];
     const int k1 = opts[2];
     const int s0 = opts[3];
@@ -744,7 +743,7 @@ void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyIntArray(dilations));
 }
 
-void cann_copy(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+static void cann_copy(ggml_backend_cann_context& ctx, ggml_tensor* dst,
               aclTensor* acl_src, aclTensor* acl_dst) {
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
@@ -943,7 +942,7 @@ aclnnStatus aclnnRmsNorm(void* workspace, uint64_t workspaceSize,
 }
 #endif
 
-aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                       int64_t* ne, int64_t dims, aclDataType type,
                       size_t type_size) {
     int64_t elements = 1;
@@ -964,7 +963,7 @@ aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     return zero;
 }
 
-aclTensor* aclnn_ones(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+static aclTensor* aclnn_ones(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                       int64_t* ne, int64_t dims, aclDataType type,
                       size_t type_size, float value = 1.0f) {
     aclTensor* acl_tensor = aclnn_zero(ctx, dst, ne, dims, type, type_size);
@@ -1080,7 +1079,7 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
-void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                 aclTensor* acl_dst, aclDataType cast_data_type,
                 ggml_tensor* bind_tensor) {
     uint64_t workspaceSize = 0;
@@ -1097,7 +1096,7 @@ void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, stream));
 }
 
-void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                    aclTensor* acl_dst, int64_t* new_dim, uint64_t dims,
                    ggml_tensor* bind_tensor) {
     aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims);
@@ -1153,8 +1152,6 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     const int64_t N = is_2D ? ne13 : ne12;
     const int64_t IC = is_2D ? ne12 : ne11;
 
-    const int64_t IH = is_2D ? ne11 : 1;
-    const int64_t IW = ne10;
 
     const int64_t KH = is_2D ? ne01 : 1;
     const int64_t KW = ne00;
@@ -1249,7 +1246,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyIntArray(strides));
 }
 
-void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src,
               ggml_tensor* bind_tensor) {
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
@@ -1265,7 +1262,7 @@ void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src,
         aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream()));
 }
 
-void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, float scale,
+static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, float scale,
                 ggml_tensor* bind_tensor) {
     aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
 
@@ -1285,7 +1282,7 @@ void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, float scale,
     ACL_CHECK(aclDestroyScalar(acl_scale));
 }
 
-void aclnn_inplace_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_inplace_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                        aclTensor* acl_other, ggml_tensor* bind_tensor) {
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
@@ -1301,7 +1298,7 @@ void aclnn_inplace_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
         aclnnInplaceMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
 }
 
-void aclnn_noinplcace_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_noinplcace_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                           aclTensor* acl_other, aclTensor* acl_dst,
                           ggml_tensor* bind_tensor) {
     uint64_t workspaceSize = 0;
@@ -1317,7 +1314,7 @@ void aclnn_noinplcace_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     ACL_CHECK(aclnnMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
 }
 
-void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                aclTensor* acl_dst, ggml_tensor* bind_tensor) {
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
@@ -1332,7 +1329,7 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     ACL_CHECK(aclnnCos(workspaceAddr, workspaceSize, executor, ctx.stream()));
 }
 
-void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
               aclTensor* acl_dst, ggml_tensor* bind_tensor) {
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
@@ -1452,7 +1449,7 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
-void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
+static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
                        aclTensor* acl_dst, ggml_tensor* bind_tensor) {
     // fill acl_dst with scalar value.
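Most of the hunks in this stretch only add `static` to file-local aclnn_* helpers. The effect is internal linkage: the helper's symbol is no longer exported from aclnn_ops.cpp, so it cannot clash with an identically named function in another translation unit and the compiler is free to inline or drop it. A tiny illustrative example with hypothetical names, not code from the patch:

```cpp
// helpers.cpp -- illustrative only: `static` gives the helper internal
// linkage, so the symbol is private to this translation unit and cannot
// collide with another .cpp that defines the same name.
#include <cstdio>

static int add_alpha(int a, int b, int alpha = 1) {   // internal linkage
    return a + alpha * b;
}

int public_entry(int a, int b) {    // external linkage: visible to the linker
    return add_alpha(a, b);
}

int main() {
    std::printf("%d\n", public_entry(2, 3));  // prints 5
    return 0;
}
```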
@@ -1473,7 +1470,7 @@ void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
     ACL_CHECK(aclDestroyScalar(acl_scalar));
 }
 
-void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
+static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
                              aclTensor* acl_exp, ggml_tensor* bind_tensor) {
     // acl_dst = acl_dst^acl_exp
 
@@ -1491,11 +1488,12 @@ void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
                                          executor, ctx.stream()));
 }
 
-void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                  aclTensor* acl_position, aclTensor* acl_dst,
                  const int n_head, const int64_t src_ne0, const int64_t src_ne1,
                  const int64_t src_ne2, const int64_t src_ne3,
                  const size_t src_nb0, float max_bias, ggml_tensor* dst) {
+    GGML_UNUSED(src_ne1);
     const int64_t ne2_ne3 = src_ne2 * src_ne3;
     GGML_ASSERT(src_nb0 == sizeof(float));
     GGML_ASSERT(n_head == src_ne2);
@@ -1537,7 +1535,7 @@ void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     size_t tmp_arange2_nb[] = {sizeof(dst->type)};
 
     aclTensor* tmp_arange2_tensor = create_acl_tensor(
-        tmp_arange_buffer + n_heads_log2_floor * ggml_type_size(dst->type),
+        (char*)tmp_arange_buffer + n_heads_log2_floor * ggml_type_size(dst->type),
         type_mapping(dst->type), ggml_type_size(dst->type), tmp_arange2_ne,
         tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
     aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
@@ -1560,7 +1558,7 @@ void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
     size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
     aclTensor* tmp_mk_base2_tensor = create_acl_tensor(
-        tmp_mk_base_buffer + n_heads_log2_floor * ggml_type_size(dst->type),
+        (char*)tmp_mk_base_buffer + n_heads_log2_floor * ggml_type_size(dst->type),
         type_mapping(dst->type), ggml_type_size(dst->type), tmp_mk_base2_ne,
         tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
     aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor, dst);
@@ -1615,59 +1613,11 @@ void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     ACL_CHECK(aclDestroyTensor(tmp_output_tensor));
 }
 
-void ggml_cann_alibi(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    ggml_tensor* src = dst->src[0];
-
-    const int n_head = ((int32_t*)dst->op_params)[1];
-    float max_bias;
-    memcpy(&max_bias, (int32_t*)dst->op_params + 2, sizeof(float));
-
-    const int64_t ne0 = src->ne[0];  // all_seq_len = n_past + ne1
-    const int64_t ne1 = src->ne[1];  // seq_len_without_past
-    const int64_t ne2 = src->ne[2];  // n_head -> this is k
-    const int64_t ne3 = src->ne[3];  // batch
-
-    const int64_t n = ggml_nrows(src);
-    const int64_t ne2_ne3 = n / ne1;  // ne2*ne3
-
-    const size_t nb0 = src->nb[0];
-
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(n_head == ne2);
-
-    // position arange: [0, ..., ne0)
-    float start = 0;
-    float stop = ne0;
-    float step = 1;
-    int64_t n_elements_arange = ne0;
-    int64_t tmp_position_ne[] = {ne0, 1, 1, 1};
-    size_t tmp_position_nb[] = {sizeof(dst->type)};
-
-    void* tmp_position_buffer = ctx.alloc_buffer(dst, ne0 * sizeof(dst->type));
-    aclTensor* tmp_position_tensor = create_acl_tensor(
-        tmp_position_buffer, type_mapping(dst->type), ggml_type_size(dst->type),
-        tmp_position_ne, tmp_position_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
-
-    aclnn_arange(ctx, tmp_position_tensor, start, stop, step, n_elements_arange,
-                 dst);
-
-    // call alibi
-    aclTensor* acl_src = create_acl_tensor(src);
-    aclTensor* acl_dst = create_acl_tensor(dst);
-    aclnn_alibi(ctx, acl_src, tmp_position_tensor, acl_dst, n_head, ne0, ne1,
-                ne2, ne3, nb0, max_bias, dst);
-
-    ACL_CHECK(aclDestroyTensor(tmp_position_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-}
-
 void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    ggml_tensor* src = dst->src[0];
     ggml_cann_dup(ctx, dst);
 }
 
-void aclnn_inplace_add(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_inplace_add(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                        aclTensor* acl_dst, ggml_tensor* bind_tensor) {
     aclScalar* alpha = nullptr;
     float alphaValue = 1.0f;
@@ -1823,7 +1773,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
 }
 
-void aclnn_repeat_interleave(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                              aclTensor* acl_dst, int64_t dim, int64_t repeats,
                              int64_t output_size, ggml_tensor* bind_tensor) {
     // each elem in acl_src will repeat. repeat number is `repeats`, repeats dim
@@ -1845,7 +1795,7 @@ void aclnn_repeat_interleave(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                                           executor, main_stream));
 }
 
-void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
+static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
                    aclTensor* acl_weight, aclTensor* acl_dst,
                    ggml_tensor* bind_tensor) {
     int8_t cube_math_type = 1;  // ALLOW_FP32_DOWN_PRECISION, when input is
@@ -1866,7 +1816,7 @@ void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
     ACL_CHECK(aclnnMatmul(workspaceAddr, workspaceSize, executor, main_stream));
 }
 
-void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];  // weight
     ggml_tensor* src1 = dst->src[1];  // input
 
@@ -1958,7 +1908,7 @@ void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
-void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];  // weight
     ggml_tensor* src1 = dst->src[1];  // input
 
@@ -2098,7 +2048,7 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
 }
 
-void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                 aclTensor* acl_dst, int64_t* shifts, int64_t* dims,
                 ggml_tensor* bind_tensor) {
     aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1);
@@ -2189,18 +2139,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     void* cos_buffer =
         ctx.alloc_buffer(dst, src0->ne[0] * src0->ne[2] * sizeof(float_t));
 
-    try {
-        GGML_ASSERT(param.input_ne[0] == 128);
-        aclrtlaunch_ascendc_rope_init_cache(param.position_ne[0], ctx.stream(),
-                                            position_cast_buffer, sin_buffer,
-                                            cos_buffer, param_buffer);
-        ACL_CHECK(aclrtFree(param_buffer));
-    } catch (...) {
-        for (int i = 0; i < 4; i++) {
-            printf("ne%d: %d, %d \n", i, param.input_ne[i],
-                   param.position_ne[i]);
-        }
-    }
+    aclrtlaunch_ascendc_rope_init_cache(param.position_ne[0], ctx.stream(),
+                                        position_cast_buffer, sin_buffer,
+                                        cos_buffer, param_buffer);
+    ACL_CHECK(aclrtFree(param_buffer));
 
     // reshape sin&cos
     // TODO: ne[3] != 0
@@ -2261,7 +2203,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         acl_minus_one_tensor = create_acl_tensor(
             minus_one_scale_buffer, ACL_INT64, sizeof(int64_t), minus_one_ne,
             minus_one_nb, GGML_MAX_DIMS);
-        int64_t minus_one_scale[src0->ne[0]];
+        int64_t* minus_one_scale = new int64_t[src0->ne[0]];
         for (int i = 0; i < src0->ne[0]; i += 2) {
             minus_one_scale[i] = -1.0;
             minus_one_scale[i + 1] = 1.0;
@@ -2270,6 +2212,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         aclrtMemcpy(minus_one_scale_buffer, src0->ne[0] * sizeof(int64_t),
                     minus_one_scale, src0->ne[0] * sizeof(int64_t),
                     ACL_MEMCPY_HOST_TO_DEVICE);
+        delete[] minus_one_scale;
     } else {
         // roll input: [q0,q1,q2,...] -> [q_half,q_half+1,...,
         // q0,q1,...q_half-1]
@@ -2298,7 +2241,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         acl_minus_one_tensor = create_acl_tensor(
             minus_one_scale_buffer, ACL_INT64, sizeof(int64_t), minus_one_ne,
             minus_one_nb, GGML_MAX_DIMS);
-        int64_t minus_one_scale[src0->ne[0]];
+        int64_t* minus_one_scale = new int64_t[src0->ne[0]];
        for (int i = 0; i < src0->ne[0]; i++) {
             if (i < src0->ne[0] / 2) {
                 minus_one_scale[i] = -1.0;
@@ -2310,6 +2253,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         aclrtMemcpy(minus_one_scale_buffer, src0->ne[0] * sizeof(int64_t),
                     minus_one_scale, src0->ne[0] * sizeof(int64_t),
                     ACL_MEMCPY_HOST_TO_DEVICE);
+        delete[] minus_one_scale;
     }
 
     // input * scale
diff --git a/ggml-cann/aclnn_ops.h b/ggml-cann/aclnn_ops.h
index be6a5dce6ff91..2053650ec0447 100644
--- a/ggml-cann/aclnn_ops.h
+++ b/ggml-cann/aclnn_ops.h
@@ -67,8 +67,6 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
                                   ggml_tensor* dst);
 
-void ggml_cann_alibi(ggml_backend_cann_context& ctx, ggml_tensor* dst);
-
 void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
 void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 32fec9772e4fc..baf182c6628d2 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1645,32 +1645,9 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
         ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
-
-// GGML_OP_ALIBI
-struct test_alibi : public test_case {
-    const ggml_type type;
-    const std::array<int64_t, 4> ne_a;
-    const int n_past;
-    const int n_head;
-    const float bias_max;
-
-    std::string vars() override {
-        return VARS_TO_STR5(type, ne_a, n_past, n_head, bias_max);
-    }
-
-    test_alibi(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {30, 20, 10, 1},
-            int n_past = 0, int n_head = 10,
-            float bias_max = 0.9f)
-        : type(type), ne_a(ne_a), n_past(n_past), n_head(n_head), bias_max(bias_max) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
-        ggml_tensor * out = ggml_alibi(ctx, a, n_past, n_head, bias_max);
         return out;
     }
 };
-
 enum llm_norm_type {
@@ -2353,19 +2330,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             }
         }
     }
-
-    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512));  // llama 7B
-        test_cases.emplace_back(new test_rope(type, {128, 32, 512, 1}, 128, 0, 512)); // llama 8B
-        test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512));  // llama 13B
-        test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512));  // llama 30B
-        test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512));  // llama 65B
-        test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512));    // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512));   // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512));    // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512));  // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512));   // neox (stablelm)
-        test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512));   // neox (phi-2)
     }
 
     for (int v : { 0, 1, 2, 3 }) {
@@ -2407,12 +2371,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             }
         }
     }
-
-    for (float bias_max : {-0.5, 0.5}) {
-        test_cases.emplace_back(new test_alibi(GGML_TYPE_F32, {16, 2, 10, 1}, 0, 10, bias_max));
-        test_cases.emplace_back(new test_alibi(GGML_TYPE_F32, {16, 2, 32, 1}, 0, 32, bias_max));
-        test_cases.emplace_back(new test_alibi(GGML_TYPE_F32, {128, 4, 10, 1}, 0, 10, bias_max));
-        test_cases.emplace_back(new test_alibi(GGML_TYPE_F32, {128, 4, 32, 1}, 0, 32, bias_max));
     }
 
     // these tests are disabled to save execution time, but they can be handy for debugging
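One more note on the rope hunks above: `int64_t minus_one_scale[src0->ne[0]]` was a runtime-sized stack array, which is a compiler extension in C++ and risks stack overflow for large ne[0]; the patch switches it to new[]/delete[]. If a further cleanup is ever wanted, a std::vector keeps the heap allocation but drops the manual delete[]. The sketch below only illustrates that alternative with hypothetical names; it is not part of the patch.

```cpp
// Sketch of an alternative to the new[]/delete[] pair used for
// minus_one_scale above: a std::vector releases its storage automatically,
// even on early return. Names and sizes here are illustrative only.
#include <cstdint>
#include <cstring>
#include <vector>

void fill_minus_one_pattern(int64_t ne0, void* host_staging) {
    std::vector<int64_t> minus_one_scale(ne0);     // heap-backed, RAII-managed
    for (int64_t i = 0; i + 1 < ne0; i += 2) {
        minus_one_scale[i]     = -1;
        minus_one_scale[i + 1] =  1;
    }
    // stand-in for the aclrtMemcpy host-to-device upload done in the patch
    std::memcpy(host_staging, minus_one_scale.data(), ne0 * sizeof(int64_t));
}   // no delete[] needed; the vector frees its storage here

int main() {
    std::vector<int64_t> staging(8);
    fill_minus_one_pattern(8, staging.data());
    return 0;
}
```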