diff --git a/include/valik/search/local_prefilter.hpp b/include/valik/search/local_prefilter.hpp index df4048c2..cf1e9889 100644 --- a/include/valik/search/local_prefilter.hpp +++ b/include/valik/search/local_prefilter.hpp @@ -107,7 +107,8 @@ template void find_pattern_bins(pattern_bounds const & pattern, size_t const & bin_count, binning_bitvector_t const & counting_table, - std::unordered_set & sequence_hits) + std::unordered_map & sequence_hits, + uint64_t & pattern_hits) { // counting vector for the current pattern seqan3::counting_vector total_counts(bin_count, 0); @@ -119,8 +120,9 @@ void find_pattern_bins(pattern_bounds const & pattern, auto &&count = total_counts[current_bin]; if (count >= pattern.threshold) { - // the result_set is a union of results from all patterns of a read - sequence_hits.insert(current_bin); + // the result is a union of results from all patterns of a read + sequence_hits[current_bin]++; + pattern_hits++; } } } @@ -198,14 +200,16 @@ void local_prefilter( minimiser.clear(); - std::unordered_set sequence_hits{}; + uint64_t pattern_hits{0}; + // {bin ID, pattern hit count} + std::unordered_map sequence_hits{}; pattern_begin_positions(seq.size(), arguments.pattern_size, arguments.query_every, [&](size_t const begin) { pattern_bounds const pattern = make_pattern_bounds(begin, arguments, window_span_begin, thresholder); - find_pattern_bins(pattern, bin_count, counting_table, sequence_hits); + find_pattern_bins(pattern, bin_count, counting_table, sequence_hits, pattern_hits); }); - result_cb(record, sequence_hits); + result_cb(record, sequence_hits, pattern_hits); } } diff --git a/include/valik/search/producer_threads_parallel.hpp b/include/valik/search/producer_threads_parallel.hpp index 7139284b..067d4e29 100644 --- a/include/valik/search/producer_threads_parallel.hpp +++ b/include/valik/search/producer_threads_parallel.hpp @@ -43,20 +43,27 @@ inline void prefilter_queries_parallel(seqan3::interleaved_bloom_filter records_slice{&records[start], &records[end]}; - auto prefilter_cb = [&queue,&arguments,&verbose_out,&ibf](query_t const& record, std::unordered_set const& bin_hits) + auto prefilter_cb = [&queue,&arguments,&verbose_out,&ibf](query_t const & record, + std::unordered_map const & bin_hits, + uint64_t const & total_pattern_hits) { if (bin_hits.size() > std::max((size_t) 4, (size_t) std::round(ibf.bin_count() / 2.0))) { - if (!arguments.keep_repeats) + if (arguments.verbose) + verbose_out.write_warning(record, bin_hits.size()); + if (arguments.keep_repeats) // keep bin hits that are supported by the most patterns per query segment { - verbose_out.write_disabled_record(record, bin_hits.size(), arguments.verbose); - return; + size_t mean_bin_support = std::max((size_t) 2, (size_t) std::round((double) total_pattern_hits / (double) bin_hits.size())); + for (auto const [bin, count] : bin_hits) + { + if (count > mean_bin_support) + queue.insert(bin, record); + } } - else if (arguments.verbose) - verbose_out.write_warning(record, bin_hits.size()); + return; } - for (size_t const bin : bin_hits) + for (auto const [bin, count] : bin_hits) { queue.insert(bin, record); }