Skip to content

Commit

Permalink
Extract serial API
Browse files Browse the repository at this point in the history
  • Loading branch information
l4l committed May 2, 2022
1 parent f45a30d commit d005d81
Show file tree
Hide file tree
Showing 4 changed files with 306 additions and 280 deletions.
15 changes: 10 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,21 @@ name = "fzyr"
name = "fzyr"
path = "src/bin/main.rs"
doc = false
required-features = ["binary-build"]


[dependencies]
ndarray = "^0.11.2"
itertools = "^0.7.8"
crossbeam = "^0.4.1"
bit-vec = "^0.5.0"
clap = "^2.32.0"
console = "^0.6.1"

itertools = { version = "^0.7.8", optional = true }
crossbeam = { version = "^0.4.1", optional = true }
clap = { version = "^2.32.0", optional = true }
console = { version = "^0.6.1", optional = true }

[features]
default = ["binary-build", "parallel"]
binary-build = ["clap", "console"]
parallel = ["itertools", "crossbeam"]

[profile.release]
opt-level = 3
Expand Down
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ mod score;
mod search;

pub use score::{config, has_match, locate, score, LocateResult, Score, ScoreResult};
pub use search::{search_locate, search_score, LocateResults, ScoreResults};
pub use search::{search_serial, locate_serial, LocateResults, ScoreResults};

#[cfg(feature = "parallel")]
pub use search::{search_locate, search_score};
289 changes: 15 additions & 274 deletions src/search/mod.rs
Original file line number Diff line number Diff line change
@@ -1,94 +1,36 @@
extern crate crossbeam;
extern crate itertools;

use std::cmp::Ordering;
use std::usize;

use self::crossbeam::channel;
use self::crossbeam::scope as thread_scope;
use self::itertools::kmerge;
use std::iter::ExactSizeIterator;

use score::{has_match, locate_inner, score_inner, LocateResult, ScoreResult};

#[cfg(feature = "parallel")]
mod parallel;

#[cfg(feature = "parallel")]
pub use self::parallel::{search_score, search_locate};

/// Collection of scores and the candidates they apply to
pub type ScoreResults = Vec<ScoreResult>;
/// Collection of scores, locations, and the candidates they apply to
pub type LocateResults = Vec<LocateResult>;

/// Search among a collection of candidates using the given query, returning
/// an ordered collection of results (highest score first)
///
/// Candidates rejected by `has_match` are left out of the result; the rest
/// are scored with `score_inner`, which also receives the candidate's
/// position in `candidates`.
pub fn search_score(
pub fn search_serial(
query: &str,
candidates: &[&str],
parallelism: usize,
candidates: impl Iterator<Item = impl AsRef<str>> + ExactSizeIterator,
) -> ScoreResults {
search_internal(query, candidates, parallelism, score_inner).collect()
search_worker(candidates, query, 0, score_inner)
}

/// Search among a collection of candidates using the given query, returning
/// an ordered collection of results (highest score first) with the locations
/// of the query in each candidate
///
/// Candidates rejected by `has_match` are left out of the result; the rest
/// are processed with `locate_inner`, which also receives the candidate's
/// position in `candidates`.
pub fn search_locate(
pub fn locate_serial(
query: &str,
candidates: &[&str],
parallelism: usize,
candidates: impl Iterator<Item = impl AsRef<str>> + ExactSizeIterator,
) -> LocateResults {
search_internal(query, candidates, parallelism, locate_inner).collect()
}

fn search_internal<T>(
query: &str,
candidates: &[&str],
parallelism: usize,
search_fn: fn(&str, &str, usize) -> T,
) -> Box<dyn Iterator<Item = T>>
where
T: PartialOrd + Sized + Send + 'static,
{
let parallelism = calculate_parallelism(candidates.len(), parallelism, query.is_empty());
let mut candidates = candidates;
let (sender, receiver) = channel::bounded::<Vec<T>>(parallelism);

if parallelism < 2 {
Box::new(search_worker(candidates, query, 0, search_fn).into_iter())
} else {
thread_scope(|scope| {
let mut remaining_candidates = candidates.len();
let per_thread_count = ceil_div(remaining_candidates, parallelism);
let mut thread_offset = 0;

// Create "parallelism" threads
while remaining_candidates > 0 {
// Search in this thread's share
let split = if remaining_candidates >= per_thread_count {
remaining_candidates -= per_thread_count;
per_thread_count
} else {
remaining_candidates = 0;
remaining_candidates
};
let split = candidates.split_at(split);
let splitted_len = split.0.len();
let sender = sender.clone();
scope.spawn(move || {
sender.send(search_worker(split.0, query, thread_offset, search_fn));
});
thread_offset += splitted_len;

// Remove that share from the candidate slice
candidates = split.1;
}

drop(sender);
});

Box::new(kmerge(receiver))
}
search_worker(candidates, query, 0, locate_inner)
}

// Search among candidates against a query in a single thread
//
// Walks `candidates` in order, skips those that `has_match` rejects, and
// calls `search_fn(query, candidate, offset_index + index)` for the rest,
// where `index` is the candidate's position within `candidates`.
// `offset_index` lets a caller that works on a sub-range report positions
// relative to the full candidate list; callers searching the whole list
// pass 0.
// NOTE(review): part of this function is hidden behind the collapsed diff
// context, so any post-processing of `out` is not visible here.
fn search_worker<T>(
candidates: &[&str],
candidates: impl IntoIterator<Item = impl AsRef<str>> + ExactSizeIterator,
query: &str,
offset_index: usize,
search_fn: fn(&str, &str, usize) -> T
Expand All @@ -98,6 +40,7 @@ where
{
// Reserve for the case where every candidate matches
let mut out = Vec::with_capacity(candidates.len());
for (index, candidate) in candidates.into_iter().enumerate() {
// Borrow the candidate as `&str` whatever its concrete type is
let candidate = candidate.as_ref();
if has_match(&query, candidate) {
out.push(search_fn(&query, candidate, offset_index + index));
}
Expand All @@ -107,205 +50,3 @@ where
out
}

/// Decide how many worker threads a search should use.
///
/// * `candidate_count` - number of candidates to be searched
/// * `configured_parallelism` - the caller's upper bound on threads
/// * `empty_query` - whether the query string is empty
///
/// The result is always at least 1 and never exceeds the configured bound
/// (unless that bound is 0), the ramp limit, or the candidate count.
fn calculate_parallelism(
    candidate_count: usize,
    configured_parallelism: usize,
    empty_query: bool,
) -> usize {
    // An empty query needs so little work that one thread is enough
    if empty_query {
        return 1;
    }

    // Ramp the thread count up with the candidate count so that small inputs
    // do not start threads unnecessarily
    let ramp_limit = if candidate_count < 17 {
        ceil_div(candidate_count, 4)
    } else if candidate_count > 32 {
        ceil_div(candidate_count, 8)
    } else {
        4
    };

    let capped = configured_parallelism
        .min(ramp_limit)
        .min(candidate_count);

    capped.max(1)
}

/// Integer ceiling division: the smallest `q` such that `q * b >= a`.
///
/// Built from the quotient and remainder instead of `(a + b - 1) / b`, so the
/// intermediate value cannot overflow for large operands.
///
/// # Panics
///
/// Panics if `b` is zero.
fn ceil_div(a: usize, b: usize) -> usize {
    let quotient = a / b;
    // A non-zero remainder implies `b >= 2`, so `quotient + 1` cannot overflow
    if a % b == 0 {
        quotient
    } else {
        quotient + 1
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Last four characters of `s`, in reverse order
    fn reversed_tail(s: &str) -> String {
        s.chars().rev().take(4).collect()
    }

    #[test]
    fn parallelism_ramp() {
        // (candidate count, configured parallelism) pairs that must all
        // collapse to a single thread
        let single_threaded = [
            (0, 0),
            (1, 0),
            (0, 1),
            (1, 1),
            (2, usize::MAX),
            (3, 4),
            (4, 2),
        ];
        for &(count, configured) in single_threaded.iter() {
            assert_eq!(1, calculate_parallelism(count, configured, false));
        }

        // (first count, one-past-last count, expected threads) per ramp band
        let bands = [(5, 9, 2), (9, 13, 3), (13, 33, 4)];
        for &(start, end, expected) in bands.iter() {
            for n in start..end {
                assert_eq!(expected, calculate_parallelism(n, usize::MAX, false));
                assert_eq!(1, calculate_parallelism(n, usize::MAX, true));
            }
        }

        // The configured bound is always respected
        for n in 1..10_000 {
            assert!(calculate_parallelism(n, 12, false) <= 12);
            assert_eq!(1, calculate_parallelism(n, 12, true));
        }
    }

    /// Searching an empty candidate list yields nothing, with or without a query
    fn search_empty_with_parallelism(parallelism: usize) {
        for query in ["", "test"].iter() {
            let rs = search_score(query, &[], parallelism);
            assert!(rs.is_empty());
        }
    }

    /// Small candidate lists: single candidates, misses, and result ordering
    fn search_with_parallelism(parallelism: usize) {
        search_empty_with_parallelism(parallelism);

        // A lone candidate comes back for an empty query and for a
        // non-ASCII match
        let lone_cases = [("", "tags"), ("♺", "ñîƹ♺à")];
        for &(query, candidate) in lone_cases.iter() {
            let rs = search_score(query, &[candidate], parallelism);
            assert_eq!(1, rs.len());
            assert_eq!(0, rs[0].candidate_index);
        }

        let cs = ["tags", "test"];

        assert_eq!(2, search_score("", &cs, parallelism).len());

        let rs = search_score("te", &cs, parallelism);
        assert_eq!(1, rs.len());
        assert_eq!(1, rs[0].candidate_index);

        assert_eq!(0, search_score("foobar", &cs, parallelism).len());

        let rs = search_score("ts", &cs, parallelism);
        assert_eq!(2, rs.len());
        let order: Vec<_> = rs.iter().map(|r| r.candidate_index).collect();
        assert_eq!(vec![1, 0], order);
    }

    /// Twenty candidates: enough for the ramp to allow four threads
    fn search_med_parallelism(parallelism: usize) {
        let cs = [
            "one",
            "two",
            "three",
            "four",
            "five",
            "six",
            "seven",
            "eight",
            "nine",
            "ten",
            "eleven",
            "twelve",
            "thirteen",
            "fourteen",
            "fifteen",
            "sixteen",
            "seventeen",
            "eighteen",
            "nineteen",
            "twenty",
        ];

        let rs = search_score("", &cs, parallelism);
        assert_eq!(cs.len(), rs.len());

        // Every "teen" match must actually end in "teen"
        let rs = search_score("teen", &cs, parallelism);
        assert_eq!(7, rs.len());
        for r in rs.iter() {
            assert_eq!("neet", reversed_tail(cs[r.candidate_index]));
        }

        // The best "tee" match is one of the "-teen" words
        let rs = search_score("tee", &cs, parallelism);
        assert_eq!(9, rs.len());
        assert_eq!("neet", reversed_tail(cs[rs[0].candidate_index]));

        let rs = search_score("six", &cs, parallelism);
        assert_eq!("six", cs[rs[0].candidate_index]);
    }

    /// One hundred thousand numeric candidates
    fn search_large_parallelism(parallelism: usize) {
        let n: usize = 100_000;
        let candidates: Vec<String> = (0..n).map(|i| i.to_string()).collect();
        let borrowed: Vec<&str> = candidates.iter().map(String::as_str).collect();

        let rs = search_score("12", &borrowed, parallelism);

        // This has been precalculated
        // e.g. via `$ seq 0 99999 | grep '.*1.*2.*' | wc -l`
        assert_eq!(8146, rs.len());
        assert_eq!("12", candidates[rs[0].candidate_index]);
    }

    // TODO: test locate

    #[test]
    fn search_single() {
        search_with_parallelism(0);
        search_with_parallelism(1);
        search_large_parallelism(1);
    }

    #[test]
    fn search_double() {
        search_with_parallelism(2);
        search_large_parallelism(2);
    }

    #[test]
    fn search_quad() {
        search_med_parallelism(4);
        search_large_parallelism(4);
    }

    #[test]
    fn search_quin() {
        search_med_parallelism(4);
        search_large_parallelism(5);
    }

    #[test]
    fn search_large() {
        search_med_parallelism(4);
        search_large_parallelism(16);
    }
}
Loading

0 comments on commit d005d81

Please sign in to comment.