Skip to content

Commit

Permalink
Implement fallback to smaller vector size for swizzle_dyn
Browse files Browse the repository at this point in the history
  • Loading branch information
cvijdea-bd committed Aug 25, 2024
1 parent 4697d39 commit 169c541
Showing 1 changed file with 136 additions and 10 deletions.
146 changes: 136 additions & 10 deletions crates/core_simd/src/swizzle_dyn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ where
/// A planned compiler improvement will enable using `#[target_feature]` instead.
#[inline]
pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self {
#![allow(unused_imports, unused_unsafe)]
#![allow(unused_imports, unused_unsafe, unreachable_patterns)]
#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm64ec"),
target_endian = "little"
Expand Down Expand Up @@ -66,20 +66,146 @@ where
// FIXME: initial AVX512VBMI variant didn't actually pass muster
// #[cfg(target_feature = "avx512vbmi")]
// 64 => transize(x86::_mm512_permutexvar_epi8, self, idxs),
_ => {
let mut array = [0; N];
for (i, k) in idxs.to_array().into_iter().enumerate() {
if (k as usize) < N {
array[i] = self[k as usize];
};
}
array.into()
}
#[cfg(any(
all(
any(
target_arch = "aarch64",
target_arch = "arm64ec",
all(target_arch = "arm", target_feature = "v7")
),
target_feature = "neon",
target_endian = "little"
),
target_feature = "ssse3",
target_feature = "simd128"
))]
_ => dispatch_compat(self, idxs),
_ => swizzle_dyn_scalar(self, idxs),
}
}
}
}

#[inline(always)]
fn swizzle_dyn_scalar<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
where
LaneCount<N>: SupportedLaneCount,
{
let mut array = [0; N];
for (i, k) in idxs.to_array().into_iter().enumerate() {
if (k as usize) < N {
array[i] = bytes[k as usize];
};
}
array.into()
}

/// Dispatches to `swizzle_dyn_compat` or `swizzle_dyn_zext` according to `N`.
///
/// * `N` of 16 (little-endian NEON targets only), 32 or 64 is split into half-width swizzles
///   through `swizzle_dyn_compat`.
/// * `N` strictly between 4 and 16, 16 and 32, or 32 and 64 is zero-extended to 16, 32 or 64
///   lanes respectively through `swizzle_dyn_zext`.
/// * Every other `N` uses `swizzle_dyn_scalar`.
///
/// Should only be called if the target architecture has a vectorized swizzle_dyn for some
/// power-of-two size (e.g. 8, 16).
#[inline(always)]
fn dispatch_compat<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    // The `16` arm is compiled out on some targets, which leaves a gap between the
    // neighbouring range patterns; the lints below are silenced for that reason.
    #![allow(
        dead_code,
        unused_unsafe,
        unreachable_patterns,
        non_contiguous_range_endpoints
    )]

    match N {
        5..16 => swizzle_dyn_zext::<N, 16>(bytes, idxs),
        // only arm actually has 8-byte swizzle_dyn
        #[cfg(all(
            any(
                target_arch = "aarch64",
                target_arch = "arm64ec",
                all(target_arch = "arm", target_feature = "v7")
            ),
            target_feature = "neon",
            target_endian = "little"
        ))]
        16 => {
            // SAFETY: only unsafe usage is transize, see comment on transize
            unsafe { transize(swizzle_dyn_compat::<16, 8>, bytes, idxs) }
        }
        17..32 => swizzle_dyn_zext::<N, 32>(bytes, idxs),
        32 => {
            // SAFETY: only unsafe usage is transize, see comment on transize
            unsafe { transize(swizzle_dyn_compat::<32, 16>, bytes, idxs) }
        }
        33..64 => swizzle_dyn_zext::<N, 64>(bytes, idxs),
        64 => {
            // SAFETY: only unsafe usage is transize, see comment on transize
            unsafe { transize(swizzle_dyn_compat::<64, 32>, bytes, idxs) }
        }
        _ => swizzle_dyn_scalar(bytes, idxs),
    }
}

/// Implements `swizzle_dyn` for `N` lanes by temporarily widening both operands to `N_EXT`
/// lanes, swizzling at that width, and narrowing the result back to `N` lanes.
///
/// The extra lanes of both the table and the indices are filled with `0`.
///
/// # Panics
///
/// Panics if `N_EXT` is not a power of two or if `N >= N_EXT`.
#[inline(always)]
#[allow(unused)]
fn swizzle_dyn_zext<const N: usize, const N_EXT: usize>(
    bytes: Simd<u8, N>,
    idxs: Simd<u8, N>,
) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
    LaneCount<N_EXT>: SupportedLaneCount,
{
    // These must stay runtime checks: `dispatch_compat` names this function for every `N`
    // inside match arms that are never executed for non-matching sizes.
    assert!(N_EXT.is_power_of_two(), "N_EXT should be power of two!");
    assert!(N < N_EXT, "N_EXT should be larger than N");

    let wide_bytes = bytes.resize::<N_EXT>(0);
    let wide_idxs = idxs.resize::<N_EXT>(0);
    let wide_result = wide_bytes.swizzle_dyn(wide_idxs);
    wide_result.resize::<N>(0)
}

/// "Downgrades" a swizzle_dyn op on `N` lanes to 4 swizzle_dyn ops on `N / 2` lanes.
///
/// Each half of `bytes` is swizzled by each half of the indices, and the two `N`-lane
/// candidates are then blended per lane depending on which half the index pointed into.
///
/// This only makes sense if swizzle_dyn actually has a vectorized implementation for a lower
/// size (N/2, N/4, N/8, etc).
/// e.g. on x86, swizzle_dyn_compat for N=64 can be efficient if we have at least ssse3 for pshufb
///
/// Without a vectorized implementation at a lower size this performs four half-width swizzles
/// plus a blend, so it will be slower than the scalar implementation.
///
/// # Panics
///
/// Panics if `N` is not a power of two, if `N` is not below `u8::MAX`, or if
/// `HALF_N != N / 2`.
#[inline(always)]
#[allow(unused)]
fn swizzle_dyn_compat<const N: usize, const HALF_N: usize>(
    bytes: Simd<u8, N>,
    idxs: Simd<u8, N>,
) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
    LaneCount<HALF_N>: SupportedLaneCount,
{
    use crate::simd::cmp::SimdPartialOrd;
    assert!(N.is_power_of_two(), "doesn't work for non-power-of-two N");
    assert!(N < u8::MAX as usize, "doesn't work for N >= 256");
    assert_eq!(N / 2, HALF_N, "HALF_N must equal N divided by two");

    let mid: Simd<u8, N> = Simd::splat(HALF_N as u8);

    // Clear the `HALF_N` bit of every index (e.g. with HALF_N = 8, 8..=15 becomes 0..=7) so
    // that indices aimed at the upper half address the right lane of a half-width table.
    // N is a power of two, so an index >= N keeps a higher bit set and stays >= HALF_N,
    // which means it still selects zero in the half-width swizzles.
    let local_idxs = idxs & !mid;

    let idxs_front = Simd::<u8, HALF_N>::from_slice(&local_idxs[..HALF_N]);
    let idxs_back = Simd::<u8, HALF_N>::from_slice(&local_idxs[HALF_N..]);

    let table_lo = Simd::<u8, HALF_N>::from_slice(&bytes[..HALF_N]);
    let table_hi = Simd::<u8, HALF_N>::from_slice(&bytes[HALF_N..]);

    // Swizzles one half-width table by both index halves and glues the two half-width
    // results back together into a full-width vector.
    let swizzle_half = |table: Simd<u8, HALF_N>| -> Simd<u8, N> {
        let front = Simd::swizzle_dyn(table, idxs_front);
        let back = Simd::swizzle_dyn(table, idxs_back);

        let mut joined = [0u8; N];
        let (joined_front, joined_back) = joined.split_at_mut(HALF_N);
        joined_front.copy_from_slice(front.as_array());
        joined_back.copy_from_slice(back.as_array());
        Simd::from_array(joined)
    };

    let from_lo = swizzle_half(table_lo);
    let from_hi = swizzle_half(table_hi);

    // Indices below HALF_N take their byte from the low-table result; everything else takes
    // the high-table result, which is already zero for out-of-range indices.
    let picks_low_half = idxs.simd_lt(mid);
    picks_low_half.select(from_lo, from_hi)
}

/// "vpshufb like it was meant to be" on AVX2
///
/// # Safety
Expand Down

0 comments on commit 169c541

Please sign in to comment.