Skip to content

Commit

Permalink
add group norm
Browse files Browse the repository at this point in the history
  • Loading branch information
hipudding committed Apr 3, 2024
1 parent c4740d6 commit a56fe41
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
ggml_cann_norm(ctx, dst);
break;
case GGML_OP_GROUP_NORM:
return false;
ggml_cann_group_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
ggml_cann_concat(ctx, dst);
break;
Expand Down Expand Up @@ -679,7 +680,9 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
case GGML_OP_ARGSORT:
return true;
case GGML_OP_ACC:
return false;
case GGML_OP_GROUP_NORM:
return true;
case GGML_OP_UPSCALE:
return false;
case GGML_OP_PAD:
Expand Down
55 changes: 55 additions & 0 deletions ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <aclnnop/aclnn_layer_norm.h>
#include <aclnnop/aclnn_cast.h>
#include <aclnnop/aclnn_group_norm.h>

#include <cmath>
#include <cstring>
Expand Down Expand Up @@ -397,3 +398,57 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
}
}

void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src = dst->src[0];

aclTensor* acl_src = create_acl_tensor(src);
aclTensor* acl_dst = create_acl_tensor(dst);

const float eps = 1e-6f; // TODO: make this a parameter
int n_groups = dst->op_params[0];

uint64_t workspaceSize = 0;
aclOpExecutor* executor;
void* workspaceAddr = nullptr;

int64_t N = src->ne[3];
int64_t C = src->ne[2];
int64_t HxW = src->ne[1] * src->ne[0];

size_t type_size = ggml_type_size(src->type);
int64_t ne[] = {n_groups, N};
size_t nb[] = {type_size, type_size * n_groups};
size_t n_bytes = N * n_groups;
void* buffer;
ACL_CHECK(aclrtMalloc(&buffer, n_bytes * 2, ACL_MEM_MALLOC_HUGE_FIRST));
aclTensor* acl_mean_out =
create_acl_tensor(buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
aclTensor* acl_rstd_out = create_acl_tensor(
(char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);

ACL_CHECK(aclnnGroupNormGetWorkspaceSize(
acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst,
acl_mean_out, acl_rstd_out, &workspaceSize, &executor));

if (workspaceSize > 0) {
ACL_CHECK(aclrtMalloc(&workspaceAddr, workspaceSize,
ACL_MEM_MALLOC_HUGE_FIRST));
}

aclrtStream stream = ctx.stream();

ACL_CHECK(aclnnGroupNorm(workspaceAddr, workspaceSize, executor, stream));

ACL_CHECK(aclDestroyTensor(acl_src));
ACL_CHECK(aclDestroyTensor(acl_dst));
ACL_CHECK(aclDestroyTensor(acl_mean_out));
ACL_CHECK(aclDestroyTensor(acl_rstd_out));

// TODO: free after sync.
ACL_CHECK(aclrtSynchronizeStream(stream));
ACL_CHECK(aclrtFree(buffer));

if (workspaceSize > 0) {
ACL_CHECK(aclrtFree(workspaceAddr));
}
}
2 changes: 2 additions & 0 deletions ggml-cann/aclnn_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);

void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);

void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);

template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
aclTensor*, uint64_t*, aclOpExecutor**),
aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>
Expand Down

0 comments on commit a56fe41

Please sign in to comment.