[experimental] Kleidi - add operator level tests (#1173)
digantdesai authored Oct 30, 2024
1 parent 581d8e0 commit 186e578
Showing 6 changed files with 100 additions and 9 deletions.
6 changes: 6 additions & 0 deletions torchao/experimental/CMakeLists.txt
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)
include(CMakeDependentOption)

project(torchao)

@@ -21,6 +22,7 @@ if(NOT TORCHAO_INCLUDE_DIRS)
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
include(CMakePrintHelpers)

add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
@@ -30,6 +32,10 @@ message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
include_directories(${TORCHAO_INCLUDE_DIRS})

if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
if(TORCHAO_BUILD_KLEIDIAI)
message(STATUS "Building with Arm KleidiAI library")
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
endif()
# Defines target torchao_kernels_aarch64
add_subdirectory(kernels/cpu/aarch64)
add_subdirectory(ops/linear_8bit_act_xbit_weight)
@@ -94,7 +94,6 @@ void kernel(
const void* activation_data,
float clamp_min,
float clamp_max) {
assert(output_m_stride == n);
if (clamp_min == 0 && clamp_max == 0) {
clamp_min = std::numeric_limits<float>::lowest();
clamp_max = std::numeric_limits<float>::max();
@@ -95,7 +95,6 @@ void kernel(
const void* activation_data,
float clamp_min,
float clamp_max) {
assert(output_m_stride == n);
if (clamp_min == 0 && clamp_max == 0) {
clamp_min = std::numeric_limits<float>::lowest();
clamp_max = std::numeric_limits<float>::max();
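
Both kernel hunks above drop the assert(output_m_stride == n) guard and keep the same clamp-sentinel convention: a (clamp_min, clamp_max) pair of (0, 0) means "no clamp requested", and the bounds are widened to the full float range. A minimal self-contained sketch of that convention (the helper name is illustrative, not part of the kernels' API):

#include <cassert>
#include <limits>

// Sketch of the sentinel convention used in both kernels above:
// clamp_min == clamp_max == 0 means "no clamp", so widen the bounds
// to cover every representable float before clamping the output.
inline void resolve_clamp_bounds(float& clamp_min, float& clamp_max) {
  if (clamp_min == 0 && clamp_max == 0) {
    clamp_min = std::numeric_limits<float>::lowest();
    clamp_max = std::numeric_limits<float>::max();
  }
  assert(clamp_min <= clamp_max);
}

One consequence worth noting: an intentional clamp to exactly [0, 0] cannot be expressed through this interface.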
3 changes: 3 additions & 0 deletions torchao/experimental/ops/tests/CMakeLists.txt
@@ -21,6 +21,9 @@ FetchContent_Declare(
)
FetchContent_MakeAvailable(googletest)
enable_testing()
if(TORCHAO_BUILD_KLEIDIAI)
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
endif()

include_directories(${TORCHAO_INCLUDE_DIRS})

14 changes: 12 additions & 2 deletions torchao/experimental/ops/tests/build_and_run_tests.sh
@@ -5,10 +5,20 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

IS_ARM64=0
hash arch; retval=$?
if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then
IS_ARM64=1
fi

export CMAKE_OUT=/tmp/cmake-out/torchao/tests
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S . -B ${CMAKE_OUT}
cmake \
-DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
-DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \
-S . \
-B ${CMAKE_OUT}

cmake --build ${CMAKE_OUT}

# Run
${CMAKE_OUT}/test_linear_8bit_act_xbit_weight
@@ -12,14 +12,19 @@
#include <torchao/experimental/ops/memory.h>
#include <torchao/experimental/ops/parallel.h>

#if defined(TORCHAO_ENABLE_KLEIDI)
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
#endif // TORCHAO_ENABLE_KLEIDI

const float kTol = 1.0e-5;

using namespace torchao::ops::linear_8bit_act_xbit_weight;

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp, bool has_kleidi=false>
UKernelConfig get_ukernel_config() {
UKernelConfig config;

if constexpr (!has_kleidi) {
namespace ukernel = torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
config.mr = 1;
@@ -36,14 +41,40 @@ UKernelConfig get_ukernel_config() {
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>;
config.kernel_fn =
&ukernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>;
} else {
#if defined(TORCHAO_ENABLE_KLEIDI)
assert (weight_nbit == 4);
assert (!has_weight_zeros);

namespace kernel = torchao::kernels::cpu::aarch64::kleidi::
kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;

auto uk = kernel::get_ukernel();
config.mr = uk.get_mr();
config.nr = uk.get_nr();

config.activation_data_size_fn = &kernel::activation_data_size;
config.weight_data_size_fn = &kernel::weight_data_size;

config.preferred_activation_data_alignment = kernel::get_preferred_alignement();
config.preferred_weight_data_alignment = kernel::get_preferred_alignement();

config.prepare_activation_data_fn = &kernel::prepare_activation_data;
config.prepare_weight_data_fn = &kernel::prepare_weight_data;

config.kernel_fn = &kernel::kernel;
#else
assert (false);
#endif // TORCHAO_ENABLE_KLEIDI
}

return config;
}
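
With the extra has_kleidi parameter, the same helper yields either backend's configuration, and if constexpr discards the branch that is not taken, so the KleidiAI symbols are only referenced when explicitly requested. A hedged usage sketch, assuming it sits in the same translation unit as the test above (function and variable names are illustrative):

// Illustrative only: both calls use template arguments that appear in the
// tests below. The trailing flag defaults to false, so existing call sites
// keep selecting the native aarch64 micro-kernel unchanged.
void example_select_backend() {
  // Native 1x8x16 neondot path (has_kleidi defaults to false).
  auto native_config = get_ukernel_config<
      /*weight_nbit=*/4, /*has_weight_zeros=*/false,
      /*has_bias=*/false, /*has_clamp=*/false>();

  // KleidiAI-backed path; only valid in builds with TORCHAO_ENABLE_KLEIDI,
  // and the asserts above restrict it to 4-bit weights without zero points.
  auto kleidi_config = get_ukernel_config<
      /*weight_nbit=*/4, /*has_weight_zeros=*/false,
      /*has_bias=*/false, /*has_clamp=*/false, /*has_kleidi=*/true>();

  (void)native_config;
  (void)kleidi_config;
}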

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp, bool has_kleidi=false>
void test_linear_8bit_act_xbit_weight(int m, int n, int k, int group_size) {
auto ukernel_config =
get_ukernel_config<weight_nbit, has_weight_zeros, has_bias, has_clamp>();
auto ukernel_config =
get_ukernel_config<weight_nbit, has_weight_zeros, has_bias, has_clamp, has_kleidi>();

auto test_case = torchao::
channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate(
@@ -54,7 +85,8 @@ void test_linear_8bit_act_xbit_weight(int m, int n, int k, int group_size) {
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp);
has_clamp,
/*round_weight_scales_to_bf16=*/has_kleidi);

auto output = std::vector<float>(m * n);

@@ -230,3 +262,45 @@ TEST(test_linear_8bit_act_xbit_weight, GroupSizeNotDivisibleBy16) {
},
std::runtime_error);
}

#if defined(TORCHAO_ENABLE_KLEIDI)
TEST(test_linear_8bit_act_xbit_weight, KleidiSmall) {
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32);
}

TEST(test_linear_8bit_act_xbit_weight, KleidiStandard) {
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/13, /*n=*/20, /*k=*/32, /*group_size=*/32);
}

TEST(test_linear_8bit_act_xbit_weight, KleidiHasClamp) {
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
true /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/17, /*n=*/10, /*k=*/32 * 2, /*group_size=*/32);
}

TEST(test_linear_8bit_act_xbit_weight, KleidiHasBias) {
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
true /*has_bias*/,
true /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/23, /*n=*/18, /*k=*/32 * 3, /*group_size=*/32);
}
#endif // TORCHAO_ENABLE_KLEIDI
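
The Kleidi test variants above also generate their reference data with round_weight_scales_to_bf16 enabled, presumably because the KleidiAI qsi4c32p weight format stores the groupwise scales as bf16, so a reference computed with full-precision scales would not meet kTol. A minimal sketch of what rounding a scale to bf16 amounts to (the helper is illustrative; NaN handling is omitted for brevity):

#include <cstdint>
#include <cstring>

// Illustrative: round a float to the nearest bf16 value (ties to even) by
// operating on the bit pattern, then widen back to float. bf16 keeps the
// float's sign and exponent but only the top 7 mantissa bits.
inline float round_scale_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFF + ((bits >> 16) & 1);  // add rounding bias (round-to-nearest-even)
  bits &= 0xFFFF0000u;                  // drop the low 16 mantissa bits
  float rounded;
  std::memcpy(&rounded, &bits, sizeof(rounded));
  return rounded;
}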
