optimize op repeat
hipudding committed Apr 7, 2024
1 parent ab167c2 commit 339a3fb
Showing 1 changed file with 23 additions and 39 deletions.
ggml-cann/aclnn_ops.cpp: 62 changes (23 additions & 39 deletions)
@@ -4,13 +4,12 @@
 #include <aclnnop/aclnn_cast.h>
 #include <aclnnop/aclnn_group_norm.h>
 #include <aclnnop/aclnn_softmax.h>
+#include <aclnnop/aclnn_repeat.h>
 
 #include <cmath>
 #include <cstring>
 #include <vector>
 
-// TODO: repeat is implemented through add to apply bcast. Optimize it.
-// change to use aclnnRepeat
 void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
     GGML_ASSERT(ggml_can_repeat(src, dst));
@@ -20,45 +19,30 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // Set dst to a zero tensor.
     ACL_CHECK(aclrtMemsetAsync(dst->data, nbytes, 0, nbytes, main_stream));
 
-    aclTensor* acl_src;
-    aclTensor* acl_dst;
+    aclTensor* acl_src = create_acl_tensor(src);
+    aclTensor* acl_dst = create_acl_tensor(dst);
 
-    // Short cut for same shape.
-    if (ggml_are_same_shape(src, dst)) {
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, nbytes, src->data, nbytes,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE, main_stream));
-    } else {
-        if (need_bcast(dst, src)) {
-            BCAST_SHAPE(dst, src);
-            acl_dst = create_acl_tensor(dst, BCAST_PARAM(dst));
-            acl_src = create_acl_tensor(src, BCAST_PARAM(src));
-        } else {
-            acl_dst = create_acl_tensor(dst);
-            acl_src = create_acl_tensor(src);
-        }
-
-        // Add src0 to dst.
-        aclScalar* alpha = nullptr;
-        int alphaValue = 1;
-        alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_INT32);
-
-        uint64_t workspaceSize = 0;
-        aclOpExecutor* executor;
-        void* workspaceAddr = nullptr;
-
-        ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src, alpha,
-                                                  &workspaceSize, &executor));
-        if (workspaceSize > 0) {
-            workspaceAddr = ctx.alloc_buffer(workspaceSize);
-        }
-
-        ACL_CHECK(aclnnInplaceAdd(workspaceAddr, workspaceSize, executor,
-                                  main_stream));
-
-        ACL_CHECK(aclDestroyScalar(alpha));
-        ACL_CHECK(aclDestroyTensor(acl_src));
-        ACL_CHECK(aclDestroyTensor(acl_dst));
-    }
+    int64_t repeatsArray[] = {dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2],
+                              dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]};
+
+    aclIntArray *repeats = aclCreateIntArray(repeatsArray, GGML_MAX_DIMS);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst, &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        workspaceAddr = ctx.alloc_buffer(workspaceSize);
+    }
+
+    aclrtStream stream = ctx.stream();
+    ACL_CHECK(aclnnRepeat(workspaceAddr, workspaceSize, executor, stream));
+    ACL_CHECK(aclDestroyIntArray(repeats));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
 
 }
 
 void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
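A note on the reversed indexing in repeatsArray: ggml stores shapes with ne[0] as the innermost, fastest-varying dimension, while the repeat counts handed to aclCreateIntArray are listed outermost-first (the same reversal create_acl_tensor appears to apply when building ACL tensors). The standalone sketch below is not part of the commit; it uses hypothetical shapes and plain C++ only to show the mapping.

#include <cstdint>
#include <cstdio>

// Standalone sketch (not from the commit): derive repeat counts the way the
// new code does, i.e. dst->ne divided by src->ne, emitted outermost-first.
int main() {
    // Hypothetical shapes: src is repeated 4 times along ggml dimension 2.
    const int64_t src_ne[4] = {2, 3, 1, 1};
    const int64_t dst_ne[4] = {2, 3, 4, 1};

    int64_t repeats[4];
    for (int i = 0; i < 4; ++i) {
        repeats[i] = dst_ne[3 - i] / src_ne[3 - i];  // reverse: ne[3] first
    }

    // Prints: repeats = {1, 4, 1, 1}
    std::printf("repeats = {%lld, %lld, %lld, %lld}\n",
                (long long)repeats[0], (long long)repeats[1],
                (long long)repeats[2], (long long)repeats[3]);
    return 0;
}

With these counts, aclnnRepeat tiles src into dst in a single operator launch instead of routing the copy through a broadcast add.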
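For context on the deleted path: the old implementation emulated repeat by zero-filling dst and then broadcast-adding src into it with aclnnInplaceAdd, using the bcast helpers to make the shapes line up. The illustration below is not ACL code and uses made-up sizes; it only restates that idea on a flat array.

#include <cstdio>

// Illustration only (no ACL, made-up sizes): the old path computed
// dst = 0, then dst += broadcast(src). Tiling src directly, as aclnnRepeat
// now does, yields the same values in one pass.
int main() {
    const int n = 3, rep = 2;
    const float src[n] = {1.0f, 2.0f, 3.0f};
    float dst[n * rep] = {};                  // step 1: zero the output

    for (int i = 0; i < n * rep; ++i) {
        dst[i] += src[i % n];                 // step 2: broadcast add of src
    }

    for (int i = 0; i < n * rep; ++i) {
        std::printf("%.0f ", dst[i]);         // prints: 1 2 3 1 2 3
    }
    std::printf("\n");
    return 0;
}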
