diff --git a/ggml-cann/aclnn_ops.cpp b/ggml-cann/aclnn_ops.cpp index 0aa9ddca45d00..04dfe51eeda5f 100644 --- a/ggml-cann/aclnn_ops.cpp +++ b/ggml-cann/aclnn_ops.cpp @@ -4,13 +4,12 @@ #include #include #include +#include #include #include #include -// TODO: repeat is implemented through add to apply bcast. Optimize it. -// change to use aclnnRepeat void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src = dst->src[0]; GGML_ASSERT(ggml_can_repeat(src, dst)); @@ -20,45 +19,30 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // Set dst to a zero tensor. ACL_CHECK(aclrtMemsetAsync(dst->data, nbytes, 0, nbytes, main_stream)); - aclTensor* acl_src; - aclTensor* acl_dst; + aclTensor* acl_src = create_acl_tensor(src); + aclTensor* acl_dst = create_acl_tensor(dst); - // Short cut for same shape. - if (ggml_are_same_shape(src, dst)) { - ACL_CHECK(aclrtMemcpyAsync(dst->data, nbytes, src->data, nbytes, - ACL_MEMCPY_DEVICE_TO_DEVICE, main_stream)); - } else { - if (need_bcast(dst, src)) { - BCAST_SHAPE(dst, src); - acl_dst = create_acl_tensor(dst, BCAST_PARAM(dst)); - acl_src = create_acl_tensor(src, BCAST_PARAM(src)); - } else { - acl_dst = create_acl_tensor(dst); - acl_src = create_acl_tensor(src); - } - - // Add src0 to dst. - aclScalar* alpha = nullptr; - int alphaValue = 1; - alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_INT32); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src, alpha, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - workspaceAddr = ctx.alloc_buffer(workspaceSize); - } - - ACL_CHECK(aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, - main_stream)); - - ACL_CHECK(aclDestroyScalar(alpha)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + int64_t repeatsArray[] = {dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2], + dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]}; + + aclIntArray *repeats = aclCreateIntArray(repeatsArray, GGML_MAX_DIMS); + + uint64_t workspaceSize = 0; + aclOpExecutor* executor; + void* workspaceAddr = nullptr; + + ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst, &workspaceSize, &executor)); + + if (workspaceSize > 0) { + workspaceAddr = ctx.alloc_buffer(workspaceSize); } + + aclrtStream stream = ctx.stream(); + ACL_CHECK(aclnnRepeat(workspaceAddr, workspaceSize, executor, stream)); + ACL_CHECK(aclDestroyIntArray(repeats)); + ACL_CHECK(aclDestroyTensor(acl_src)); + ACL_CHECK(aclDestroyTensor(acl_dst)); + } void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {