Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement fallback to smaller vector size for swizzle_dyn #433

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 139 additions & 12 deletions crates/core_simd/src/swizzle_dyn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ where
/// A planned compiler improvement will enable using `#[target_feature]` instead.
#[inline]
pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self {
#![allow(unused_imports, unused_unsafe)]
#![allow(unused_imports, unused_unsafe, unreachable_patterns)]
#[cfg(all(
any(target_arch = "aarch64", target_arch = "arm64ec"),
target_endian = "little"
Expand Down Expand Up @@ -57,8 +57,6 @@ where
target_endian = "little"
))]
16 => transize(vqtbl1q_u8, self, idxs),
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
32 => transize(avx2_pshufb, self, idxs),
#[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
32 => {
// Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit
Expand All @@ -71,6 +69,8 @@ where
};
transize(swizzler, self, idxs)
}
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
32 => transize(avx2_pshufb, self, idxs),
// Notable absence: avx512bw pshufb shuffle
#[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
64 => {
Expand All @@ -84,20 +84,147 @@ where
};
transize(swizzler, self, idxs)
}
_ => {
let mut array = [0; N];
for (i, k) in idxs.to_array().into_iter().enumerate() {
if (k as usize) < N {
array[i] = self[k as usize];
};
}
array.into()
}
#[cfg(any(
all(
any(
target_arch = "aarch64",
target_arch = "arm64ec",
all(target_arch = "arm", target_feature = "v7")
),
target_feature = "neon",
target_endian = "little"
),
target_feature = "ssse3",
target_feature = "simd128"
))]
_ => dispatch_compat(self, idxs),
_ => swizzle_dyn_scalar(self, idxs),
}
}
}
}

#[inline(always)]
fn swizzle_dyn_scalar<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
where
LaneCount<N>: SupportedLaneCount,
{
let mut array = [0; N];
for (i, k) in idxs.to_array().into_iter().enumerate() {
if (k as usize) < N {
array[i] = bytes[k as usize];
};
}
array.into()
}

/// Dispatch to swizzle_dyn_compat and swizzle_dyn_zext according to N.
/// Should only be called if there exists some power-of-two size for which
/// the target architecture has a vectorized swizzle_dyn (e.g. pshufb, vqtbl).
///
/// Strategy per lane count:
/// - `5..16`, `17..32`, `33..64`: zero-extend to the next power of two (16, 32, 64)
///   and swizzle at that width.
/// - `32`, `64` (and `16` on little-endian ARM NEON targets only): split into
///   half-width swizzles via `swizzle_dyn_compat`.
/// - anything else: fall back to `swizzle_dyn_scalar`.
#[inline(always)]
fn dispatch_compat<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    // `non_contiguous_range_endpoints`: when the cfg-gated `16` arm below is compiled
    // out, `5..16` and `17..32` leave a gap at 16, which the `_` arm then handles.
    // NOTE(review): the remaining allowances presumably silence warnings that only
    // appear under some cfg combinations -- confirm before removing any of them.
    #![allow(
        dead_code,
        unused_unsafe,
        unreachable_patterns,
        non_contiguous_range_endpoints
    )]

    // SAFETY: only unsafe usage is transize, see comment on transize
    unsafe {
        match N {
            // Widen to 16 lanes and swizzle there.
            5..16 => swizzle_dyn_zext::<N, 16>(bytes, idxs),
            // only arm actually has 8-byte swizzle_dyn
            // On every other target this arm is absent and N == 16 reaches `_` below.
            // NOTE(review): presumably the caller already handles 16 natively on those
            // targets, so that path should not be hit in practice -- confirm.
            #[cfg(all(
                any(
                    target_arch = "aarch64",
                    target_arch = "arm64ec",
                    all(target_arch = "arm", target_feature = "v7")
                ),
                target_feature = "neon",
                target_endian = "little"
            ))]
            16 => transize(swizzle_dyn_compat::<16, 8>, bytes, idxs),
            // Widen to 32 lanes and swizzle there.
            17..32 => swizzle_dyn_zext::<N, 32>(bytes, idxs),
            // Emulate a 32-lane swizzle with 16-lane swizzles.
            32 => transize(swizzle_dyn_compat::<32, 16>, bytes, idxs),
            // Widen to 64 lanes and swizzle there.
            33..64 => swizzle_dyn_zext::<N, 64>(bytes, idxs),
            // Emulate a 64-lane swizzle with 32-lane swizzles.
            64 => transize(swizzle_dyn_compat::<64, 32>, bytes, idxs),
            // Remaining lane counts (e.g. fewer than 5 lanes) use the scalar loop.
            _ => swizzle_dyn_scalar(bytes, idxs),
        }
    }
}

/// Performs an `N`-lane `swizzle_dyn` by widening both operands to `N_EXT` lanes,
/// swizzling at that width, and truncating the result back to `N` lanes.
///
/// Both the bytes and the indices are padded with `0` when widened.
///
/// # Panics
///
/// Panics if `N_EXT` is not a power of two or if `N_EXT` is not strictly larger than `N`.
#[inline(always)]
#[allow(unused)]
fn swizzle_dyn_zext<const N: usize, const N_EXT: usize>(
    bytes: Simd<u8, N>,
    idxs: Simd<u8, N>,
) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
    LaneCount<N_EXT>: SupportedLaneCount,
{
    assert!(N_EXT.is_power_of_two(), "N_EXT should be power of two!");
    assert!(N < N_EXT, "N_EXT should be larger than N");

    // Step 1: pad both operands out to the wider lane count.
    let wide_bytes = bytes.resize::<N_EXT>(0);
    let wide_idxs = idxs.resize::<N_EXT>(0);

    // Step 2: swizzle at the wider width, then drop the padding lanes again.
    let wide_result = wide_bytes.swizzle_dyn(wide_idxs);
    wide_result.resize::<N>(0)
}

/// Emulates an `N`-lane `swizzle_dyn` with four `HALF_N`-lane `swizzle_dyn` calls.
///
/// Each half of `bytes` is swizzled by each half of the adjusted indices, and the two
/// candidate results are then blended lane by lane according to which half of `bytes`
/// the original index referred to.
///
/// This only pays off when `swizzle_dyn` has a vectorized implementation at some smaller
/// size (N/2, N/4, N/8, ...); e.g. on x86 with at least ssse3, `N = 64` can end up in
/// `pshufb`. Without one, expect this to be slower than the scalar implementation.
///
/// # Panics
///
/// Panics if `N` is not a power of two, if `N` is not below `u8::MAX`, or if
/// `HALF_N` differs from `N / 2`.
#[inline(always)]
#[allow(unused)]
fn swizzle_dyn_compat<const N: usize, const HALF_N: usize>(
    bytes: Simd<u8, N>,
    idxs: Simd<u8, N>,
) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
    LaneCount<HALF_N>: SupportedLaneCount,
{
    use crate::simd::cmp::SimdPartialOrd;
    assert!(N.is_power_of_two(), "doesn't work for non-power-of-two N");
    assert!(N < u8::MAX as usize, "doesn't work for N >= 256");
    assert_eq!(N / 2, HALF_N, "HALF_N must equal N divided by two");

    let mid: Simd<u8, N> = Simd::splat(HALF_N as u8);

    // Clear the `HALF_N` bit of every index so that an index into the upper half of
    // `bytes` becomes the matching index into that half on its own.
    // Because N is a power of two, an out-of-range index (>= N) still has a higher bit
    // set afterwards, so it stays out of range for the half-width swizzles as well.
    let local_idxs = (idxs & !mid).to_array();
    let (front_idxs, back_idxs) = local_idxs.split_at(HALF_N);
    let idx_lo = Simd::<u8, HALF_N>::from_slice(front_idxs);
    let idx_hi = Simd::<u8, HALF_N>::from_slice(back_idxs);

    let source = bytes.to_array();
    let (front_bytes, back_bytes) = source.split_at(HALF_N);
    let bytes_lo = Simd::<u8, HALF_N>::from_slice(front_bytes);
    let bytes_hi = Simd::<u8, HALF_N>::from_slice(back_bytes);

    // Swizzle one half of the source by both index halves and glue the two
    // half-width results back into a single N-lane vector.
    let swizzle_from = |half: Simd<u8, HALF_N>| -> Simd<u8, N> {
        let mut joined = [0u8; N];
        let (front, back) = joined.split_at_mut(HALF_N);
        front.copy_from_slice(&half.swizzle_dyn(idx_lo).to_array());
        back.copy_from_slice(&half.swizzle_dyn(idx_hi).to_array());
        Simd::from_array(joined)
    };

    let from_lo = swizzle_from(bytes_lo);
    let from_hi = swizzle_from(bytes_hi);

    // Lanes whose original index pointed into the lower half take the lower-half
    // result; all other lanes (upper half or out of range) take the upper-half result.
    let wants_lo = idxs.simd_lt(mid);
    wants_lo.select(from_lo, from_hi)
}

/// "vpshufb like it was meant to be" on AVX2
///
/// # Safety
Expand Down
Loading