diff --git a/crates/core_simd/src/swizzle_dyn.rs b/crates/core_simd/src/swizzle_dyn.rs index 8a1079042f0..b033f1cb87e 100644 --- a/crates/core_simd/src/swizzle_dyn.rs +++ b/crates/core_simd/src/swizzle_dyn.rs @@ -15,7 +15,7 @@ where /// A planned compiler improvement will enable using `#[target_feature]` instead. #[inline] pub fn swizzle_dyn(self, idxs: Simd) -> Self { - #![allow(unused_imports, unused_unsafe)] + #![allow(unused_imports, unused_unsafe, unreachable_patterns)] #[cfg(all( any(target_arch = "aarch64", target_arch = "arm64ec"), target_endian = "little" @@ -66,20 +66,146 @@ where // FIXME: initial AVX512VBMI variant didn't actually pass muster // #[cfg(target_feature = "avx512vbmi")] // 64 => transize(x86::_mm512_permutexvar_epi8, self, idxs), - _ => { - let mut array = [0; N]; - for (i, k) in idxs.to_array().into_iter().enumerate() { - if (k as usize) < N { - array[i] = self[k as usize]; - }; - } - array.into() - } + #[cfg(any( + all( + any( + target_arch = "aarch64", + target_arch = "arm64ec", + all(target_arch = "arm", target_feature = "v7") + ), + target_feature = "neon", + target_endian = "little" + ), + target_feature = "ssse3", + target_feature = "simd128" + ))] + _ => dispatch_compat(self, idxs), + _ => swizzle_dyn_scalar(self, idxs), } } } } +#[inline(always)] +fn swizzle_dyn_scalar(bytes: Simd, idxs: Simd) -> Simd +where + LaneCount: SupportedLaneCount, +{ + let mut array = [0; N]; + for (i, k) in idxs.to_array().into_iter().enumerate() { + if (k as usize) < N { + array[i] = bytes[k as usize]; + }; + } + array.into() +} + +/// Dispatch two swizzle_dyn_compat and swizzle_dyn_zext according to N. +/// Should only be called if the target architecture has a vectorized swizzle_dyn for some power-of-two size (e.g 8, 16). +#[inline(always)] +fn dispatch_compat(bytes: Simd, idxs: Simd) -> Simd +where + LaneCount: SupportedLaneCount, +{ + #![allow( + dead_code, + unused_unsafe, + unreachable_patterns, + non_contiguous_range_endpoints + )] + + // SAFETY: only unsafe usage is transize, see comment on transize + unsafe { + match N { + 5..16 => swizzle_dyn_zext::(bytes, idxs), + // only arm actually has 8-byte swizzle_dyn + #[cfg(all( + any( + target_arch = "aarch64", + target_arch = "arm64ec", + all(target_arch = "arm", target_feature = "v7") + ), + target_feature = "neon", + target_endian = "little" + ))] + 16 => transize(swizzle_dyn_compat::<16, 8>, bytes, idxs), + 17..32 => swizzle_dyn_zext::(bytes, idxs), + 32 => transize(swizzle_dyn_compat::<32, 16>, bytes, idxs), + 33..64 => swizzle_dyn_zext::(bytes, idxs), + 64 => transize(swizzle_dyn_compat::<64, 32>, bytes, idxs), + _ => swizzle_dyn_scalar(bytes, idxs), + } + } +} + +/// Implement swizzle_dyn for N by temporarily zero extending to N_EXT. +#[inline(always)] +#[allow(unused)] +fn swizzle_dyn_zext( + bytes: Simd, + idxs: Simd, +) -> Simd +where + LaneCount: SupportedLaneCount, + LaneCount: SupportedLaneCount, +{ + assert!(N_EXT.is_power_of_two(), "N_EXT should be power of two!"); + assert!(N < N_EXT, "N_EXT should be larger than N"); + Simd::swizzle_dyn(bytes.resize::(0), idxs.resize::(0)).resize::(0) +} + +/// "Downgrades" a swizzle_dyn op on N lanes to 4 swizzle_dyn ops on N/2 lanes. +/// +/// This only makes sense if swizzle_dyn actually has a vectorized implementation for a lower size (N/2, N/4, N/8, etc). +/// e.g. on x86, swizzle_dyn_compat for N=64 can be efficient if we have at least ssse3 for pshufb +/// +/// If there is no vectorized implementation for a lower size, +/// this runs in N*logN time and will be slower than the scalar implementation. +#[inline(always)] +#[allow(unused)] +fn swizzle_dyn_compat( + bytes: Simd, + idxs: Simd, +) -> Simd +where + LaneCount: SupportedLaneCount, + LaneCount: SupportedLaneCount, +{ + use crate::simd::cmp::SimdPartialOrd; + assert!(N.is_power_of_two(), "doesn't work for non-power-of-two N"); + assert!(N < u8::MAX as usize, "doesn't work for N >= 256"); + assert_eq!(N / 2, HALF_N, "HALF_N must equal N divided by two"); + + let mid = Simd::splat(HALF_N as u8); + + // unset the "mid" bit from the indices, e.g. 8..15 -> 0..7, 16..31 -> 8..15, + // ensuring that a half-swizzle on the higher half of `bytes` will select the correct indices + // since N is a power of two, any zeroing indices will remain zeroing + let idxs_trunc = idxs & !mid; + + let idx_lo = Simd::::from_slice(&idxs_trunc[..HALF_N]); + let idx_hi = Simd::::from_slice(&idxs_trunc[HALF_N..]); + + let bytes_lo = Simd::::from_slice(&bytes[..HALF_N]); + let bytes_hi = Simd::::from_slice(&bytes[HALF_N..]); + + macro_rules! half_swizzle { + ($bytes:ident) => {{ + let lo = Simd::swizzle_dyn($bytes, idx_lo); + let hi = Simd::swizzle_dyn($bytes, idx_hi); + + let mut res = [0; N]; + res[..HALF_N].copy_from_slice(&lo[..]); + res[HALF_N..].copy_from_slice(&hi[..]); + Simd::from_array(res) + }}; + } + + let result_lo = half_swizzle!(bytes_lo); + let result_hi = half_swizzle!(bytes_hi); + idxs.simd_lt(mid).select(result_lo, result_hi) +} + /// "vpshufb like it was meant to be" on AVX2 /// /// # Safety