Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SVE Implementation for SDDMMCOO with copyrhs op #7857

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ if (${BUILD_TYPE} STREQUAL "dev")
if (MSVC)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Od")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Od")
elseif ( CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(AARCH64)")
# Check if the compiler supports ARMv8.2-A or later with SVE
include(CheckCCompilerFlag)
# Try to detect whether the system supports SVE
check_c_compiler_flag("-march=armv8.2-a+sve" SUPPORTS_SVE)
# Output the result
if(SUPPORTS_SVE)
message(STATUS "Hardware supports SVE")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0 -g3 -ggdb -march=armv8.2-a+sve")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g3 -ggdb -march=armv8.2-a+sve")
endif()
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0 -g3 -ggdb")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g3 -ggdb")
Expand Down
52 changes: 52 additions & 0 deletions src/array/cpu/sddmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#include <dgl/bcast.h>
#include <dgl/runtime/parallel_for.h>

#ifdef __ARM_FEATURE_SVE
#include <arm_sve.h> // to leverage sve intrinsics
#endif

#include "../selector.h"

namespace dgl {
Expand Down Expand Up @@ -222,6 +226,54 @@ struct Dot {

} // namespace op

// SDDMMCoo Specialization
#ifdef __ARM_FEATURE_SVE
template <>
void SDDMMCoo <int32_t, float, dgl::aten::cpu::op::CopyRhs<float>, 0, 2> (
const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out) {
const bool has_idx = !IsNullArray(coo.data);
const int32_t* row = coo.row.Ptr<int32_t>();
const int32_t* col = coo.col.Ptr<int32_t>();
const int32_t* edges = coo.data.Ptr<int32_t>();
const float* X = lhs.Ptr<float>();
const float* Y = rhs.Ptr<float>();
const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
float* O = out.Ptr<float>();
#pragma omp parallel for
for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
const int32_t rid = row[i];
const int32_t cid = col[i];
const int32_t eid = has_idx ? edges[i] : i;
float* out_off = O + eid * dim;
if (!bcast.use_bcast && reduce_size == 1) {
for (int64_t k = 0; k < dim; k += svcntw()) {
svbool_t pgk = svwhilelt_b32(k, dim);
int64_t rhs_base1 = cid * rhs_dim;
svfloat32_t rhs_off_vector = svld1_f32(pgk, &Y[rhs_base1 + k]);
svst1_f32(pgk, &out_off[k], rhs_off_vector);
}
} else {
//with bcast.use_bcast == true, Op::use_lhs == false, and Op::Call
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const float* lhs_off =
dgl::aten::cpu::op::CopyRhs<float>::use_lhs ? X + rid * lhs_dim +
lhs_add * reduce_size
: nullptr;

const float* rhs_off =
dgl::aten::cpu::op::CopyRhs<float>::use_rhs ? Y + cid * rhs_dim +
rhs_add * reduce_size
: nullptr;
out_off[k] = dgl::aten::cpu::op::CopyRhs<float>::Call(lhs_off, rhs_off, bcast.reduce_size);
}
}
}
}
#endif

} // namespace cpu
} // namespace aten
} // namespace dgl
Expand Down