Skip to content

Commit

Permalink
bf16->f32 avx512bf16 GEMM microkernels
Browse files Browse the repository at this point in the history
rsp is now always valid

PiperOrigin-RevId: 718323976
  • Loading branch information
alankelly authored and xnnpack-bot committed Jan 27, 2025
1 parent a108468 commit 1388280
Show file tree
Hide file tree
Showing 273 changed files with 21,142 additions and 1,517 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ jobs:
env:
CC: gcc-9
CXX: g++-9
BAZEL_DEFINES: --define=xnn_enable_avxvnni=false --define=xnn_enable_avx256vnni=false --define=xnn_enable_avxvnniint8=false --define=xnn_enable_avx512amx=false --define=xnn_enable_avx512fp16=false
BAZEL_DEFINES: --define=xnn_enable_avxvnni=false --define=xnn_enable_avx256vnni=false --define=xnn_enable_avxvnniint8=false --define=xnn_enable_avx512amx=false --define=xnn_enable_avx512fp16=false --define=xnn_enable_avx512bf16=false
steps:
- uses: actions/checkout@v4
- name: Update apt
Expand Down
28 changes: 28 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,18 @@ config_setting(
define_values = {"xnn_enable_avx512fp16": "false"},
)

# Enables usage of Intel AVX512-BF16 (bf16 arithmetic) kernels.
config_setting(
name = "xnn_enable_avx512bf16_explicit_true",
define_values = {"xnn_enable_avx512bf16": "true"},
)

# Disables usage of Intel AVX512-BF16 (bf16 arithmetic) kernels.
config_setting(
name = "xnn_enable_avx512bf16_explicit_false",
define_values = {"xnn_enable_avx512bf16": "false"},
)

# Enables usage of Intel AVX-VNNI (integer dot product) kernels.
config_setting(
name = "xnn_enable_avxvnni_explicit_true",
Expand Down Expand Up @@ -1662,6 +1674,22 @@ selects.config_setting_group(
],
)

selects.config_setting_group(
name = "avx512bf16_enabled_by_default",
match_any = [
"//build_config:x86_64",
],
)

alias(
name = "avx512bf16_enabled",
actual = select({
":xnn_enable_avx512bf16_explicit_true": ":xnn_enable_avx512bf16_explicit_true",
":xnn_enable_avx512bf16_explicit_false": ":xnn_enable_avx512bf16_explicit_true",
"//conditions:default": ":avx512bf16_enabled_by_default",
}),
)

selects.config_setting_group(
name = "arm_bf16_enabled_by_default",
match_any = [
Expand Down
19 changes: 19 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,18 @@ ELSEIF(CMAKE_C_COMPILER_ID STREQUAL "Clang")
ELSEIF(CMAKE_C_COMPILER_ID STREQUAL "MSVC")
SET(XNNPACK_ENABLE_AVX512FP16 OFF)
ENDIF()
OPTION(XNNPACK_ENABLE_AVX512BF16 "Build XNNPACK with AVX512-BF16 micro-kernels" ON)
IF(CMAKE_C_COMPILER_ID STREQUAL "GNU")
IF(CMAKE_C_COMPILER_VERSION VERSION_LESS "13")
SET(XNNPACK_ENABLE_AVX512BF16 OFF)
ENDIF()
ELSEIF(CMAKE_C_COMPILER_ID STREQUAL "Clang")
IF(CMAKE_C_COMPILER_VERSION VERSION_LESS "15")
SET(XNNPACK_ENABLE_AVX512BF16 OFF)
ENDIF()
ELSEIF(CMAKE_C_COMPILER_ID STREQUAL "MSVC")
SET(XNNPACK_ENABLE_AVX512BF16 OFF)
ENDIF()
OPTION(XNNPACK_ENABLE_HVX "Build XNNPACK with Hexagon HVX micro-kernels" ON)
OPTION(XNNPACK_ENABLE_KLEIDIAI "Use KleidiAI GEMM microkernels for Arm" ON)
IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm64" AND XNNPACK_ENABLE_ARM_I8MM AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC")
Expand Down Expand Up @@ -307,6 +319,7 @@ ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512VNNI=$<BOOL:${XNNPACK_ENABLE_AVX512VNN
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512VNNIGFNI=$<BOOL:${XNNPACK_ENABLE_AVX512VNNIGFNI}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512AMX=$<BOOL:${XNNPACK_ENABLE_AVX512AMX}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512FP16=$<BOOL:${XNNPACK_ENABLE_AVX512FP16}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512BF16=$<BOOL:${XNNPACK_ENABLE_AVX512BF16}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_VSX=$<BOOL:${XNNPACK_ENABLE_VSX}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_ASSEMBLY=$<BOOL:${XNNPACK_ENABLE_ASSEMBLY}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_MEMOPT=$<BOOL:${XNNPACK_ENABLE_MEMOPT}>")
Expand Down Expand Up @@ -677,6 +690,9 @@ IF(XNNPACK_TARGET_PROCESSOR MATCHES "^x86(_64)?$")
IF(XNNPACK_ENABLE_AVX512FP16)
LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVX512FP16_MICROKERNEL_SRCS})
ENDIF()
IF(XNNPACK_ENABLE_AVX512BF16)
LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVX512BF16_MICROKERNEL_SRCS})
ENDIF()
IF(XNNPACK_ENABLE_AVXVNNI)
LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVXVNNI_MICROKERNEL_SRCS})
ENDIF()
Expand Down Expand Up @@ -727,6 +743,9 @@ IF(XNNPACK_TARGET_PROCESSOR MATCHES "^x86(_64)?$")
IF(XNNPACK_ENABLE_AVX512FP16)
LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVX512FP16_MICROKERNEL_SRCS})
ENDIF()
IF(XNNPACK_ENABLE_AVX512BF16)
LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVX512BF16_MICROKERNEL_SRCS})
ENDIF()
IF(XNNPACK_ENABLE_AVXVNNI)
LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVXVNNI_MICROKERNEL_SRCS})
ENDIF()
Expand Down
267 changes: 267 additions & 0 deletions bench/qd8-f32-qc4w-gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,273 @@
#endif // XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64


#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
static void qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_ld32_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_ld32_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_ld32_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_ld32_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_ld32_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_ld32_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_ld32_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_ld32_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_ld32_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_ld32_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_ld32_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_ld32_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_ld64_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_ld64_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_ld64_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_ld64_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_ld64_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_ld64_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_ld64_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_ld64_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_ld64_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_ld64_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_ld64_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_ld64_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_ld128_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_ld128_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_ld128_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_ld128_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_ld128_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_ld128_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_ld128_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_ld128_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_ld128_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_ld128_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_ld128_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_ld128_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ld32_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ld32_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ld32_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ld32_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ld32_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ld32_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ld32_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ld32_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ld32_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld32_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld32_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld32_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ld64_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ld64_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ld64_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ld64_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ld64_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ld64_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ld64_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ld64_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ld64_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld64_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld64_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld64_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ld128_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ld128_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ld128_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ld128_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ld128_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ld128_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ld128_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ld128_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ld128_2)

static void qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld128_2(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld128_2,
xnn_init_f32_qc4w_minmax_scalar_params,
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64,
/*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1,
benchmark::utils::CheckNEONDOT);
}

BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld128_2)
#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY


#if XNN_ENABLE_ARM_DOTPROD && (XNN_ARCH_ARM || XNN_ARCH_ARM64)
static void qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__neondot(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
Expand Down
Loading

0 comments on commit 1388280

Please sign in to comment.