diff --git a/ggml-cann.cpp b/ggml-cann.cpp index cd336c1075228..2f5376c20be84 100644 --- a/ggml-cann.cpp +++ b/ggml-cann.cpp @@ -429,6 +429,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, return false; case GGML_OP_CONT: ggml_cann_cont(ctx, dst); + break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -445,12 +446,13 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_ALIBI: case GGML_OP_IM2COL: case GGML_OP_POOL_2D: - case GGML_OP_SUM_ROWS: return false; + case GGML_OP_SUM_ROWS: + ggml_cann_sum_rows(ctx, dst); + break; case GGML_OP_ARGSORT: ggml_cann_argsort(ctx, dst); break; - return false; default: return false; } @@ -651,7 +653,6 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, case GGML_OP_CPY: return false; case GGML_OP_DUP: - return true; case GGML_OP_REPEAT: case GGML_OP_CONCAT: case GGML_OP_NONE: @@ -659,9 +660,7 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - return true; case GGML_OP_NORM: - return true; case GGML_OP_ADD: case GGML_OP_MUL: case GGML_OP_DIV: @@ -669,7 +668,6 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, case GGML_OP_RMS_NORM: return false; case GGML_OP_SCALE: - return true; case GGML_OP_SQR: case GGML_OP_CLAMP: case GGML_OP_CONT: @@ -682,18 +680,15 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, case GGML_OP_ALIBI: case GGML_OP_IM2COL: case GGML_OP_POOL_2D: - case GGML_OP_SUM_ROWS: return false; + case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: - return true; case GGML_OP_ACC: - return true; case GGML_OP_GROUP_NORM: return true; case GGML_OP_UPSCALE: return false; case GGML_OP_PAD: - return true; case GGML_OP_ARANGE: return true; case GGML_OP_TIMESTEP_EMBEDDING: diff --git a/ggml-cann/aclnn_ops.cpp b/ggml-cann/aclnn_ops.cpp index a4d17330f38b2..36565db131021 100644 --- 
a/ggml-cann/aclnn_ops.cpp +++ b/ggml-cann/aclnn_ops.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -475,4 +476,37 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ACL_CHECK(aclDestroyTensor(acl_src1)); ACL_CHECK(aclDestroyTensor(acl_dst)); +} + +// Sums dst->src[0] along ACL axis 3 with aclnnReduceSum on ctx.stream() and writes the result into dst. +// Requires dst->ne[0] == 1; the `true` argument is presumably keepDims (confirm against the aclnnReduceSum docs). +void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + ggml_tensor* src = dst->src[0]; + + aclTensor* acl_src = create_acl_tensor(src); + + GGML_ASSERT(dst->ne[0] == 1); + aclTensor* acl_dst = create_acl_tensor(dst); + + uint64_t workspaceSize = 0; + aclOpExecutor* executor; + void* workspaceAddr = nullptr; + + int64_t reduce_dims_host[] = {3}; + aclIntArray* reduce_dims = aclCreateIntArray(reduce_dims_host, 1); + + ACL_CHECK(aclnnReduceSumGetWorkspaceSize(acl_src, reduce_dims, true, + type_mapping(src->type), acl_dst, + &workspaceSize, &executor)); + if (workspaceSize > 0) { + workspaceAddr = ctx.alloc_buffer(workspaceSize); + } + + aclrtStream stream = ctx.stream(); + ACL_CHECK(aclnnReduceSum(workspaceAddr, workspaceSize, executor, stream)); + + // Release the array made by aclCreateIntArray, alongside the tensor handles, so it is not leaked on every call. + ACL_CHECK(aclDestroyIntArray(reduce_dims)); + ACL_CHECK(aclDestroyTensor(acl_src)); + ACL_CHECK(aclDestroyTensor(acl_dst)); } \ No newline at end of file diff --git a/ggml-cann/aclnn_ops.h b/ggml-cann/aclnn_ops.h index c175dbce25d4f..5de4dabebebe2 100644 --- a/ggml-cann/aclnn_ops.h +++ b/ggml-cann/aclnn_ops.h @@ -45,6 +45,8 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst); void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); + template