Skip to content

Commit

Permalink
AMX GEMM/IGEMM microkernels use __msan_unpoison on tile buffers
Browse files Browse the repository at this point in the history
- Change res from 4 tile buffers into array of 4 tile buffers

PiperOrigin-RevId: 702144715
  • Loading branch information
fbarchard authored and xnnpack-bot committed Dec 3, 2024
1 parent ecf886d commit 51a0103
Show file tree
Hide file tree
Showing 70 changed files with 3,722 additions and 2,706 deletions.

Large diffs are not rendered by default.

175 changes: 94 additions & 81 deletions src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-16x64c4-minmax-avx512amx.c

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
#include <sanitizer/msan_interface.h>
#endif
#endif

#include <immintrin.h>

Expand Down Expand Up @@ -43,10 +48,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx(
// TODO: amxintrin.h only provide intrinsics for __x86_64__
// Update if amxintrin changes
#if defined(__x86_64__)
__attribute__((aligned(64))) int32_t res0[1 * 16];
__attribute__((aligned(64))) int32_t res1[1 * 16];
__attribute__((aligned(64))) int32_t res2[1 * 16];
__attribute__((aligned(64))) int32_t res3[1 * 16];
__attribute__((aligned(64))) int32_t res[4][1 * 16];

kc = round_up_po2(kc, 4 * sizeof(int8_t));
const size_t kremainder = (kc & 63) ? (kc & 63) : 64;
Expand All @@ -65,19 +67,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx(
// Load tile configuration
__attribute__((aligned(64))) struct __tile_config tile_data = {0};
tile_data.palette_id = 1;
tile_data.rows[0] = mr; // tmm0 = res 0
tile_data.rows[1] = mr; // tmm1 = res 1
tile_data.rows[2] = mr; // tmm2 = res 2
tile_data.rows[3] = mr; // tmm3 = res 3
tile_data.rows[0] = mr; // tmm0 = res[0]
tile_data.rows[1] = mr; // tmm1 = res[1]
tile_data.rows[2] = mr; // tmm2 = res[2]
tile_data.rows[3] = mr; // tmm3 = res[3]
tile_data.rows[4] = mr; // tmm4 = input
tile_data.rows[5] = 16; // tmm5 = weights
tile_data.rows[6] = mr; // tmm6 = input remainder
tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder

tile_data.colsb[0] = 64; // tmm0 = res 0
tile_data.colsb[1] = 64; // tmm1 = res 1
tile_data.colsb[2] = 64; // tmm2 = res 1
tile_data.colsb[3] = 64; // tmm3 = res 1
tile_data.colsb[0] = 64; // tmm0 = res[0]
tile_data.colsb[1] = 64; // tmm1 = res[1]
tile_data.colsb[2] = 64; // tmm2 = res[2]
tile_data.colsb[3] = 64; // tmm3 = res[3]
tile_data.colsb[4] = 64; // tmm4 = input
tile_data.colsb[5] = 64; // tmm5 = weights
tile_data.colsb[6] = kremainder; // tmm6 = input remainder
Expand Down Expand Up @@ -141,20 +143,31 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx(
k -= kremainder * sizeof(int8_t);
}

// Add tile to bias
_tile_stored(0, res0, 64);
_tile_stored(1, res1, 64);
_tile_stored(2, res2, 64);
_tile_stored(3, res3, 64);

// TODO: Instead of processing up to 4 tiles (16x64) consider
// quantizing 1 tile at a time (16 registers)
_tile_stored(0, &res[0][0], 64);
_tile_stored(1, &res[1][0], 64);
_tile_stored(2, &res[2][0], 64);
_tile_stored(3, &res[3][0], 64);

// TODO: Fix msan for AMX
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
__msan_unpoison(res, sizeof(res));
#endif
#endif

// TODO: Instead of processing up to 4 tiles (16x64) consider
// quantizing 1 row at a time.
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point));
__m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point));
__m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point));
__m512i vacc0xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[0].zero_point));
vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0));
vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0));
vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0));
vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0));
// Add tile to bias
vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0));
vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0));
vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0));
vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0));

__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
__m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV);
Expand Down
103 changes: 58 additions & 45 deletions src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x64c4-minmax-avx512amx.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
#include <sanitizer/msan_interface.h>
#endif
#endif

#include <immintrin.h>

Expand Down Expand Up @@ -43,10 +48,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx(
// TODO: amxintrin.h only provide intrinsics for __x86_64__
// Update if amxintrin changes
#if defined(__x86_64__)
__attribute__((aligned(64))) int32_t res0[7 * 16];
__attribute__((aligned(64))) int32_t res1[7 * 16];
__attribute__((aligned(64))) int32_t res2[7 * 16];
__attribute__((aligned(64))) int32_t res3[7 * 16];
__attribute__((aligned(64))) int32_t res[4][7 * 16];

kc = round_up_po2(kc, 4 * sizeof(int8_t));
const size_t kremainder = (kc & 63) ? (kc & 63) : 64;
Expand All @@ -65,19 +67,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx(
// Load tile configuration
__attribute__((aligned(64))) struct __tile_config tile_data = {0};
tile_data.palette_id = 1;
tile_data.rows[0] = mr; // tmm0 = res 0
tile_data.rows[1] = mr; // tmm1 = res 1
tile_data.rows[2] = mr; // tmm2 = res 2
tile_data.rows[3] = mr; // tmm3 = res 3
tile_data.rows[0] = mr; // tmm0 = res[0]
tile_data.rows[1] = mr; // tmm1 = res[1]
tile_data.rows[2] = mr; // tmm2 = res[2]
tile_data.rows[3] = mr; // tmm3 = res[3]
tile_data.rows[4] = mr; // tmm4 = input
tile_data.rows[5] = 16; // tmm5 = weights
tile_data.rows[6] = mr; // tmm6 = input remainder
tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder

tile_data.colsb[0] = 64; // tmm0 = res 0
tile_data.colsb[1] = 64; // tmm1 = res 1
tile_data.colsb[2] = 64; // tmm2 = res 1
tile_data.colsb[3] = 64; // tmm3 = res 1
tile_data.colsb[0] = 64; // tmm0 = res[0]
tile_data.colsb[1] = 64; // tmm1 = res[1]
tile_data.colsb[2] = 64; // tmm2 = res[2]
tile_data.colsb[3] = 64; // tmm3 = res[3]
tile_data.colsb[4] = 64; // tmm4 = input
tile_data.colsb[5] = 64; // tmm5 = weights
tile_data.colsb[6] = kremainder; // tmm6 = input remainder
Expand Down Expand Up @@ -165,12 +167,22 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx(
k -= kremainder * sizeof(int8_t);
}

// Add tile to bias
_tile_stored(0, res0, 64);
_tile_stored(1, res1, 64);
_tile_stored(2, res2, 64);
_tile_stored(3, res3, 64);
// TODO: Instead of processing up to 4 tiles (16x64) consider
// quantizing 1 tile at a time (16 registers)
_tile_stored(0, &res[0][0], 64);
_tile_stored(1, &res[1][0], 64);
_tile_stored(2, &res[2][0], 64);
_tile_stored(3, &res[3][0], 64);

// TODO: Fix msan for AMX
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
__msan_unpoison(res, sizeof(res));
#endif
#endif

// TODO: Instead of processing up to 4 tiles (16x64) consider
// quantizing 1 row at a time.
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point));
__m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point));
__m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point));
Expand Down Expand Up @@ -199,34 +211,35 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx(
__m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[6].zero_point));
__m512i vacc6xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[6].zero_point));
__m512i vacc6xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[6].zero_point));
vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0));
vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0));
vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0));
vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0));
vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16));
vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16));
vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16));
vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16));
vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32));
vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32));
vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32));
vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32));
vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48));
vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48));
vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48));
vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48));
vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64));
vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64));
vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64));
vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64));
vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80));
vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80));
vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80));
vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80));
vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96));
vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96));
vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96));
vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96));
// Add tile to bias
vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0));
vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0));
vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0));
vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0));
vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16));
vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16));
vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16));
vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16));
vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32));
vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32));
vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32));
vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32));
vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48));
vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48));
vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48));
vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48));
vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64));
vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64));
vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64));
vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64));
vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80));
vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80));
vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80));
vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80));
vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96));
vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96));
vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96));
vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96));

__m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
__m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV);
Expand Down
Loading

0 comments on commit 51a0103

Please sign in to comment.