From a56fe41b307c2595bb804cf95e565bb4eec8306a Mon Sep 17 00:00:00 2001 From: huafengchun Date: Wed, 3 Apr 2024 06:35:55 +0000 Subject: [PATCH] add group norm --- ggml-cann.cpp | 5 +++- ggml-cann/aclnn_ops.cpp | 55 +++++++++++++++++++++++++++++++++++++++++ ggml-cann/aclnn_ops.h | 2 ++ 3 files changed, 61 insertions(+), 1 deletion(-) diff --git a/ggml-cann.cpp b/ggml-cann.cpp index 1f68884f22523..86c351c9af9db 100644 --- a/ggml-cann.cpp +++ b/ggml-cann.cpp @@ -400,7 +400,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_norm(ctx, dst); break; case GGML_OP_GROUP_NORM: - return false; + ggml_cann_group_norm(ctx, dst); + break; case GGML_OP_CONCAT: ggml_cann_concat(ctx, dst); break; @@ -679,7 +680,9 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, case GGML_OP_ARGSORT: return true; case GGML_OP_ACC: + return false; case GGML_OP_GROUP_NORM: + return true; case GGML_OP_UPSCALE: return false; case GGML_OP_PAD: diff --git a/ggml-cann/aclnn_ops.cpp b/ggml-cann/aclnn_ops.cpp index cd76d2410eacd..4d0c0cdc5c72e 100644 --- a/ggml-cann/aclnn_ops.cpp +++ b/ggml-cann/aclnn_ops.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -397,3 +398,57 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { } } +void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + ggml_tensor* src = dst->src[0]; + + aclTensor* acl_src = create_acl_tensor(src); + aclTensor* acl_dst = create_acl_tensor(dst); + + const float eps = 1e-6f; // TODO: make this a parameter + int n_groups = dst->op_params[0]; + + uint64_t workspaceSize = 0; + aclOpExecutor* executor; + void* workspaceAddr = nullptr; + + int64_t N = src->ne[3]; + int64_t C = src->ne[2]; + int64_t HxW = src->ne[1] * src->ne[0]; + + size_t type_size = ggml_type_size(src->type); + int64_t ne[] = {n_groups, N}; + size_t nb[] = {type_size, type_size * n_groups}; + size_t n_bytes = N * n_groups; + void* buffer; + ACL_CHECK(aclrtMalloc(&buffer, n_bytes * 2, ACL_MEM_MALLOC_HUGE_FIRST)); + aclTensor* acl_mean_out = + create_acl_tensor(buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); + aclTensor* acl_rstd_out = create_acl_tensor( + (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); + + ACL_CHECK(aclnnGroupNormGetWorkspaceSize( + acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst, + acl_mean_out, acl_rstd_out, &workspaceSize, &executor)); + + if (workspaceSize > 0) { + ACL_CHECK(aclrtMalloc(&workspaceAddr, workspaceSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + } + + aclrtStream stream = ctx.stream(); + + ACL_CHECK(aclnnGroupNorm(workspaceAddr, workspaceSize, executor, stream)); + + ACL_CHECK(aclDestroyTensor(acl_src)); + ACL_CHECK(aclDestroyTensor(acl_dst)); + ACL_CHECK(aclDestroyTensor(acl_mean_out)); + ACL_CHECK(aclDestroyTensor(acl_rstd_out)); + + // TODO: free after sync. + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtFree(buffer)); + + if (workspaceSize > 0) { + ACL_CHECK(aclrtFree(workspaceAddr)); + } +} diff --git a/ggml-cann/aclnn_ops.h b/ggml-cann/aclnn_ops.h index 1c963a4a60cd4..9feaa0ece4f98 100644 --- a/ggml-cann/aclnn_ops.h +++ b/ggml-cann/aclnn_ops.h @@ -39,6 +39,8 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst); void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); +void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); + template