Skip to content

Commit

Permalink
xnn_f16_vprelu_ukernel__avx512fp16_u64 asan fix
Browse files Browse the repository at this point in the history
- use maskz_load for remainder handler

PiperOrigin-RevId: 704522366
  • Loading branch information
fbarchard authored and xnnpack-bot committed Dec 10, 2024
1 parent 8d512f2 commit 6c4a68c
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/f16-vbinary/gen/f16-vprelu-avx512fp16-u32.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void xnn_f16_vprelu_ukernel__avx512fp16_u32(
const __m512h va = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, a));

const __mmask32 vsign = _mm512_cmp_ph_mask(va, vzero, _CMP_LT_OQ);
__m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_loadu_ph(b));
__m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, b)));

_mm512_mask_storeu_epi16(o, vmask, _mm512_castph_si512(vacc));
}
Expand Down
2 changes: 1 addition & 1 deletion src/f16-vbinary/gen/f16-vprelu-avx512fp16-u64.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void xnn_f16_vprelu_ukernel__avx512fp16_u64(
const __m512h va = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, a));

const __mmask32 vsign = _mm512_cmp_ph_mask(va, vzero, _CMP_LT_OQ);
__m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_loadu_ph(b));
__m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, b)));

_mm512_mask_storeu_epi16(o, vmask, _mm512_castph_si512(vacc));
}
Expand Down
2 changes: 1 addition & 1 deletion src/f16-vbinary/vop-avx512fp16.c.in
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ void xnn_f16_v${OP.lower()}_ukernel__avx512fp16_u${BATCH_TILE}(

$if OP == "PRELU":
const __mmask32 vsign = _mm512_cmp_ph_mask(va, vzero, _CMP_LT_OQ);
__m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_loadu_ph(b));
__m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, b)));
$else:
__m512h vacc = ${_MM512_MASKZ_OP_ph}(vmask, va, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, b)));

Expand Down

0 comments on commit 6c4a68c

Please sign in to comment.