Skip to content

Commit

Permalink
Move multiversioned functions outside of Searcher trait
Browse files (browse the repository at this point in the history)
  • Loading branch information
marmeladema committed Jun 29, 2024
1 parent b18a60d commit dbab616
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 110 deletions.
10 changes: 6 additions & 4 deletions src/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,23 +227,25 @@ impl<N: Needle> NeonSearcher<N> {
#[inline]
unsafe fn neon_2_search_in(&self, haystack: &[u8], end: usize) -> bool {
let hash = VectorHash::<uint8x2_t>::from(&self.neon_half_hash);
self.vector_search_in_neon_version(haystack, end, &hash)
crate::vector_search_in_neon_version(self.needle(), self.position(), haystack, end, &hash)
}

#[inline]
unsafe fn neon_4_search_in(&self, haystack: &[u8], end: usize) -> bool {
let hash = VectorHash::<uint8x4_t>::from(&self.neon_half_hash);
self.vector_search_in_neon_version(haystack, end, &hash)
crate::vector_search_in_neon_version(self.needle(), self.position(), haystack, end, &hash)
}

#[inline]
unsafe fn neon_8_search_in(&self, haystack: &[u8], end: usize) -> bool {
self.vector_search_in_neon_version(haystack, end, &self.neon_half_hash)
let hash = &self.neon_half_hash;
crate::vector_search_in_neon_version(self.needle(), self.position(), haystack, end, hash)
}

#[inline]
unsafe fn neon_search_in(&self, haystack: &[u8], end: usize) -> bool {
self.vector_search_in_neon_version(haystack, end, &self.neon_hash)
let hash = &self.neon_hash;
crate::vector_search_in_neon_version(self.needle(), self.position(), haystack, end, hash)
}

/// Inlined version of `search_in` for hot call sites.
Expand Down
177 changes: 91 additions & 86 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,6 @@ trait NeedleWithSize: Needle {
self.as_bytes().len()
}
}

#[inline]
fn is_empty(&self) -> bool {
self.size() == 0
}
}

impl<N: Needle + ?Sized> NeedleWithSize for N {}
Expand Down Expand Up @@ -192,96 +187,106 @@ impl<T: Vector, V: Vector + From<T>> From<&VectorHash<T>> for VectorHash<V> {
}
}

trait Searcher<N: NeedleWithSize + ?Sized> {
fn needle(&self) -> &N;
#[multiversion::multiversion]
#[clone(target = "[x86|x86_64]+avx2")]
#[clone(target = "wasm32+simd128")]
#[clone(target = "aarch64+neon")]
unsafe fn vector_search_in_chunk<N: NeedleWithSize + ?Sized, V: Vector>(
needle: &N,
position: usize,
hash: &VectorHash<V>,
start: *const u8,
mask: u32,
) -> bool {
let first = V::load(start);
let last = V::load(start.add(position));

let eq_first = V::lanes_eq(hash.first, first);
let eq_last = V::lanes_eq(hash.last, last);

let eq = V::bitwise_and(eq_first, eq_last);
let mut eq = V::to_bitmask(eq) & mask;

let chunk = start.add(1);
let size = needle.size() - 1;
let needle = needle.as_bytes().as_ptr().add(1);

while eq != 0 {
let chunk = chunk.add(eq.trailing_zeros() as usize);
let equal = match N::SIZE {
Some(0) => unreachable!(),
Some(1) => dispatch!(memcmp::specialized::<0>(chunk, needle)),
Some(2) => dispatch!(memcmp::specialized::<1>(chunk, needle)),
Some(3) => dispatch!(memcmp::specialized::<2>(chunk, needle)),
Some(4) => dispatch!(memcmp::specialized::<3>(chunk, needle)),
Some(5) => dispatch!(memcmp::specialized::<4>(chunk, needle)),
Some(6) => dispatch!(memcmp::specialized::<5>(chunk, needle)),
Some(7) => dispatch!(memcmp::specialized::<6>(chunk, needle)),
Some(8) => dispatch!(memcmp::specialized::<7>(chunk, needle)),
Some(9) => dispatch!(memcmp::specialized::<8>(chunk, needle)),
Some(10) => dispatch!(memcmp::specialized::<9>(chunk, needle)),
Some(11) => dispatch!(memcmp::specialized::<10>(chunk, needle)),
Some(12) => dispatch!(memcmp::specialized::<11>(chunk, needle)),
Some(13) => dispatch!(memcmp::specialized::<12>(chunk, needle)),
Some(14) => dispatch!(memcmp::specialized::<13>(chunk, needle)),
Some(15) => dispatch!(memcmp::specialized::<14>(chunk, needle)),
Some(16) => dispatch!(memcmp::specialized::<15>(chunk, needle)),
_ => dispatch!(memcmp::generic(chunk, needle, size)),
};
if equal {
return true;
}

fn position(&self) -> usize;
eq = dispatch!(bits::clear_leftmost_set(eq));
}

#[multiversion::multiversion]
#[clone(target = "[x86|x86_64]+avx2")]
#[clone(target = "wasm32+simd128")]
#[clone(target = "aarch64+neon")]
unsafe fn vector_search_in_chunk<V: Vector>(
&self,
hash: &VectorHash<V>,
start: *const u8,
mask: u32,
) -> bool {
let first = V::load(start);
let last = V::load(start.add(self.position()));

let eq_first = V::lanes_eq(hash.first, first);
let eq_last = V::lanes_eq(hash.last, last);

let eq = V::bitwise_and(eq_first, eq_last);
let mut eq = V::to_bitmask(eq) & mask;

let chunk = start.add(1);
let needle = self.needle().as_bytes().as_ptr().add(1);

while eq != 0 {
let chunk = chunk.add(eq.trailing_zeros() as usize);
let equal = match N::SIZE {
Some(0) => unreachable!(),
Some(1) => dispatch!(memcmp::specialized::<0>(chunk, needle)),
Some(2) => dispatch!(memcmp::specialized::<1>(chunk, needle)),
Some(3) => dispatch!(memcmp::specialized::<2>(chunk, needle)),
Some(4) => dispatch!(memcmp::specialized::<3>(chunk, needle)),
Some(5) => dispatch!(memcmp::specialized::<4>(chunk, needle)),
Some(6) => dispatch!(memcmp::specialized::<5>(chunk, needle)),
Some(7) => dispatch!(memcmp::specialized::<6>(chunk, needle)),
Some(8) => dispatch!(memcmp::specialized::<7>(chunk, needle)),
Some(9) => dispatch!(memcmp::specialized::<8>(chunk, needle)),
Some(10) => dispatch!(memcmp::specialized::<9>(chunk, needle)),
Some(11) => dispatch!(memcmp::specialized::<10>(chunk, needle)),
Some(12) => dispatch!(memcmp::specialized::<11>(chunk, needle)),
Some(13) => dispatch!(memcmp::specialized::<12>(chunk, needle)),
Some(14) => dispatch!(memcmp::specialized::<13>(chunk, needle)),
Some(15) => dispatch!(memcmp::specialized::<14>(chunk, needle)),
Some(16) => dispatch!(memcmp::specialized::<15>(chunk, needle)),
_ => dispatch!(memcmp::generic(chunk, needle, self.needle().size() - 1)),
};
if equal {
return true;
}
false
}

eq = dispatch!(bits::clear_leftmost_set(eq));
#[allow(dead_code)]
#[multiversion::multiversion]
#[clone(target = "[x86|x86_64]+avx2")]
#[clone(target = "wasm32+simd128")]
#[clone(target = "aarch64+neon")]
pub(crate) unsafe fn vector_search_in<N: NeedleWithSize + ?Sized, V: Vector>(
needle: &N,
position: usize,
haystack: &[u8],
end: usize,
hash: &VectorHash<V>,
) -> bool {
debug_assert!(haystack.len() >= needle.size());

let mut chunks = haystack[..end].chunks_exact(V::LANES);
for chunk in &mut chunks {
if dispatch!(vector_search_in_chunk(
needle,
position,
hash,
chunk.as_ptr(),
u32::MAX
)) {
return true;
}
}

false
}

#[multiversion::multiversion]
#[clone(target = "[x86|x86_64]+avx2")]
#[clone(target = "wasm32+simd128")]
#[clone(target = "aarch64+neon")]
unsafe fn vector_search_in<V: Vector>(
&self,
haystack: &[u8],
end: usize,
hash: &VectorHash<V>,
) -> bool {
debug_assert!(haystack.len() >= self.needle().size());

let mut chunks = haystack[..end].chunks_exact(V::LANES);
for chunk in &mut chunks {
if dispatch!(self.vector_search_in_chunk(hash, chunk.as_ptr(), u32::MAX)) {
return true;
}
let remainder = chunks.remainder().len();
if remainder > 0 {
let start = haystack.as_ptr().add(end - V::LANES);
let mask = u32::MAX << (V::LANES - remainder);

if dispatch!(vector_search_in_chunk(needle, position, hash, start, mask)) {
return true;
}
}

let remainder = chunks.remainder().len();
if remainder > 0 {
let start = haystack.as_ptr().add(end - V::LANES);
let mask = u32::MAX << (V::LANES - remainder);
false
}

if dispatch!(self.vector_search_in_chunk(hash, start, mask)) {
return true;
}
}
trait Searcher<N: NeedleWithSize + ?Sized> {
fn needle(&self) -> &N;

false
}
fn position(&self) -> usize;
}

#[cfg(test)]
Expand Down
35 changes: 27 additions & 8 deletions src/stdsimd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,28 +129,47 @@ impl<N: Needle> StdSimdSearcher<N> {
/// Inlined version of `search_in` for hot call sites.
#[inline]
pub fn inlined_search_in(&self, haystack: &[u8]) -> bool {
if haystack.len() <= self.needle.size() {
return haystack == self.needle.as_bytes();
let needle = self.needle();

if haystack.len() <= needle.size() {
return haystack == needle.as_bytes();
}

let end = haystack.len() - self.needle.size() + 1;
let position = self.position();
let end = haystack.len() - needle.size() + 1;

if end < Simd2::LANES {
unreachable!();
} else if end < Simd4::LANES {
let hash = from_hash::<32, 2>(&self.simd32_hash);
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
unsafe {
crate::vector_search_in_default_version(needle, position, haystack, end, &hash)
}
} else if end < Simd8::LANES {
let hash = from_hash::<32, 4>(&self.simd32_hash);
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
unsafe {
crate::vector_search_in_default_version(needle, position, haystack, end, &hash)
}
} else if end < Simd16::LANES {
let hash = from_hash::<32, 8>(&self.simd32_hash);
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
unsafe {
crate::vector_search_in_default_version(needle, position, haystack, end, &hash)
}
} else if end < Simd32::LANES {
let hash = from_hash::<32, 16>(&self.simd32_hash);
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
unsafe {
crate::vector_search_in_default_version(needle, position, haystack, end, &hash)
}
} else {
unsafe { self.vector_search_in_default_version(haystack, end, &self.simd32_hash) }
unsafe {
crate::vector_search_in_default_version(
needle,
position,
haystack,
end,
&self.simd32_hash,
)
}
}
}

Expand Down
18 changes: 11 additions & 7 deletions src/wasm32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,25 +242,29 @@ impl<N: Needle> Wasm32Searcher<N> {
#[inline]
#[target_feature(enable = "simd128")]
pub unsafe fn inlined_search_in(&self, haystack: &[u8]) -> bool {
if haystack.len() <= self.needle.size() {
return haystack == self.needle.as_bytes();
let needle = self.needle();

if haystack.len() <= needle.size() {
return haystack == needle.as_bytes();
}

let end = haystack.len() - self.needle.size() + 1;
let position = self.position();
let end = haystack.len() - needle.size() + 1;

if end < v16::LANES {
unreachable!();
} else if end < v32::LANES {
let hash = VectorHash::<v16>::from(&self.v128_hash);
self.vector_search_in_simd128_version(haystack, end, &hash)
crate::vector_search_in_simd128_version(needle, position, haystack, end, &hash)
} else if end < v64::LANES {
let hash = VectorHash::<v32>::from(&self.v128_hash);
self.vector_search_in_simd128_version(haystack, end, &hash)
crate::vector_search_in_simd128_version(needle, position, haystack, end, &hash)
} else if end < v128::LANES {
let hash = VectorHash::<v64>::from(&self.v128_hash);
self.vector_search_in_simd128_version(haystack, end, &hash)
crate::vector_search_in_simd128_version(needle, position, haystack, end, &hash)
} else {
self.vector_search_in_simd128_version(haystack, end, &self.v128_hash)
let hash = &self.v128_hash;
crate::vector_search_in_simd128_version(needle, position, haystack, end, hash)
}
}

Expand Down
12 changes: 7 additions & 5 deletions src/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,33 +319,35 @@ impl<N: Needle> Avx2Searcher<N> {
#[target_feature(enable = "avx2")]
unsafe fn sse2_2_search_in(&self, haystack: &[u8], end: usize) -> bool {
let hash = VectorHash::<__m16i>::from(&self.sse2_hash);
self.vector_search_in_avx2_version(haystack, end, &hash)
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, &hash)
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn sse2_4_search_in(&self, haystack: &[u8], end: usize) -> bool {
let hash = VectorHash::<__m32i>::from(&self.sse2_hash);
self.vector_search_in_avx2_version(haystack, end, &hash)
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, &hash)
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn sse2_8_search_in(&self, haystack: &[u8], end: usize) -> bool {
let hash = VectorHash::<__m64i>::from(&self.sse2_hash);
self.vector_search_in_avx2_version(haystack, end, &hash)
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, &hash)
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn sse2_16_search_in(&self, haystack: &[u8], end: usize) -> bool {
self.vector_search_in_avx2_version(haystack, end, &self.sse2_hash)
let hash = &self.sse2_hash;
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, hash)
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn avx2_search_in(&self, haystack: &[u8], end: usize) -> bool {
self.vector_search_in_avx2_version(haystack, end, &self.avx2_hash)
let hash = &self.avx2_hash;
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, hash)
}

/// Inlined version of `search_in` for hot call sites.
Expand Down

0 comments on commit dbab616

Please sign in to comment.