Skip to content

Commit

Permalink
add op sum_rows
Browse files Browse the repository at this point in the history
  • Loading branch information
hipudding committed Apr 8, 2024
1 parent 148d70a commit 6ec3a9c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
15 changes: 5 additions & 10 deletions ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
return false;
case GGML_OP_CONT:
ggml_cann_cont(ctx, dst);
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
Expand All @@ -445,12 +446,13 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
case GGML_OP_ALIBI:
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
return false;
case GGML_OP_SUM_ROWS:
ggml_cann_sum_rows(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_cann_argsort(ctx, dst);
break;
return false;
default:
return false;
}
Expand Down Expand Up @@ -651,25 +653,21 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
case GGML_OP_CPY:
return false;
case GGML_OP_DUP:
return true;
case GGML_OP_REPEAT:
case GGML_OP_CONCAT:
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
return true;
case GGML_OP_NORM:
return true;
case GGML_OP_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:
return true;
case GGML_OP_RMS_NORM:
return false;
case GGML_OP_SCALE:
return true;
case GGML_OP_SQR:
case GGML_OP_CLAMP:
case GGML_OP_CONT:
Expand All @@ -682,18 +680,15 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
case GGML_OP_ALIBI:
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
return false;
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
return true;
case GGML_OP_ACC:
return true;
case GGML_OP_GROUP_NORM:
return true;
case GGML_OP_UPSCALE:
return false;
case GGML_OP_PAD:
return true;
case GGML_OP_ARANGE:
return true;
case GGML_OP_TIMESTEP_EMBEDDING:
Expand Down
30 changes: 30 additions & 0 deletions ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <aclnnop/aclnn_layer_norm.h>
#include <aclnnop/aclnn_repeat.h>
#include <aclnnop/aclnn_softmax.h>
#include <aclnnop/aclnn_reduce_sum.h>

#include <cmath>
#include <cstring>
Expand Down Expand Up @@ -475,4 +476,33 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

ACL_CHECK(aclDestroyTensor(acl_src1));
ACL_CHECK(aclDestroyTensor(acl_dst));
}

void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src = dst->src[0];

aclTensor* acl_src = create_acl_tensor(src);

GGML_ASSERT(dst->ne[0] == 1);
aclTensor* acl_dst = create_acl_tensor(dst);

uint64_t workspaceSize = 0;
aclOpExecutor* executor;
void* workspaceAddr = nullptr;

int64_t reduce_dims_host[] = {3};
aclIntArray* reduce_dims = aclCreateIntArray(reduce_dims_host, 1);

ACL_CHECK(aclnnReduceSumGetWorkspaceSize(acl_src, reduce_dims, true,
type_mapping(src->type), acl_dst,
&workspaceSize, &executor));
if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(workspaceSize);
}

aclrtStream stream = ctx.stream();
ACL_CHECK(aclnnReduceSum(workspaceAddr, workspaceSize, executor, stream));

ACL_CHECK(aclDestroyTensor(acl_src));
ACL_CHECK(aclDestroyTensor(acl_dst));
}
2 changes: 2 additions & 0 deletions ggml-cann/aclnn_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);

void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);

void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);

template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
aclTensor*, uint64_t*, aclOpExecutor**),
aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>
Expand Down

0 comments on commit 6ec3a9c

Please sign in to comment.