diff --git a/cpp/bench/prims/matrix/select_k.cu b/cpp/bench/prims/matrix/select_k.cu index 1bff66cac4..992fda8a38 100644 --- a/cpp/bench/prims/matrix/select_k.cu +++ b/cpp/bench/prims/matrix/select_k.cu @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -38,6 +39,19 @@ namespace raft::matrix { using namespace raft::bench; // NOLINT +template +struct replace_with_mask { + KeyT replacement; + int64_t line_length; + int64_t spared_inputs; + constexpr auto inline operator()(int64_t offset, KeyT x, uint8_t mask) -> KeyT + { + auto i = offset % line_length; + // don't replace all the inputs, spare a few elements at the beginning of the input + return (mask && i >= spared_inputs) ? replacement : x; + } +}; + template struct selection : public fixture { explicit selection(const select::params& p) @@ -67,6 +81,21 @@ struct selection : public fixture { } } raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value); + if (p.frac_infinities > 0.0) { + rmm::device_uvector mask_buf(p.batch_size * p.len, stream); + auto mask = make_device_vector_view(mask_buf.data(), mask_buf.size()); + raft::random::bernoulli(handle, state, mask, p.frac_infinities); + KeyT bound = p.select_min ? raft::upper_bound() : raft::lower_bound(); + auto mask_in = + make_device_vector_view(mask_buf.data(), mask_buf.size()); + auto dists_in = make_device_vector_view(in_dists_.data(), in_dists_.size()); + auto dists_out = make_device_vector_view(in_dists_.data(), in_dists_.size()); + raft::linalg::map_offset(handle, + dists_out, + replace_with_mask{bound, int64_t(p.len), int64_t(p.k / 2)}, + dists_in, + mask_in); + } } void run_benchmark(::benchmark::State& state) override // NOLINT @@ -75,8 +104,12 @@ struct selection : public fixture { std::ostringstream label_stream; label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; } + if (params_.frac_infinities > 0) { label_stream << "#infs-" << params_.frac_infinities; } state.SetLabel(label_stream.str()); - loop_on_state(state, [this]() { + common::nvtx::range case_scope("%s - %s", state.name().c_str(), label_stream.str().c_str()); + int iter = 0; + loop_on_state(state, [&iter, this]() { + common::nvtx::range lap_scope("lap-", iter++); select::select_k_impl(handle, Algo, in_dists_.data(), @@ -149,6 +182,35 @@ const std::vector kInputs{ {10, 1000000, 64, true, false, true}, {10, 1000000, 128, true, false, true}, {10, 1000000, 256, true, false, true}, + + {10, 1000000, 1, true, false, false, true, 0.1}, + {10, 1000000, 16, true, false, false, true, 0.1}, + {10, 1000000, 64, true, false, false, true, 0.1}, + {10, 1000000, 128, true, false, false, true, 0.1}, + {10, 1000000, 256, true, false, false, true, 0.1}, + + {10, 1000000, 1, true, false, false, true, 0.9}, + {10, 1000000, 16, true, false, false, true, 0.9}, + {10, 1000000, 64, true, false, false, true, 0.9}, + {10, 1000000, 128, true, false, false, true, 0.9}, + {10, 1000000, 256, true, false, false, true, 0.9}, + {1000, 10000, 1, true, false, false, true, 0.9}, + {1000, 10000, 16, true, false, false, true, 0.9}, + {1000, 10000, 64, true, false, false, true, 0.9}, + {1000, 10000, 128, true, false, false, true, 0.9}, + {1000, 10000, 256, true, false, false, true, 0.9}, + + {10, 1000000, 1, true, false, false, true, 1.0}, + {10, 1000000, 16, true, false, false, true, 1.0}, + {10, 1000000, 64, true, false, false, true, 1.0}, + {10, 1000000, 128, true, false, false, true, 1.0}, + {10, 1000000, 256, true, false, false, true, 1.0}, + {1000, 10000, 1, true, false, false, true, 1.0}, + {1000, 10000, 16, true, false, false, true, 1.0}, + {1000, 10000, 64, true, false, false, true, 1.0}, + {1000, 10000, 128, true, false, false, true, 1.0}, + {1000, 10000, 256, true, false, false, true, 1.0}, + {1000, 10000, 256, true, false, false, true, 0.999}, }; #define SELECTION_REGISTER(KeyT, IdxT, A) \ diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index b72e67580a..1d15c5fc03 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -33,6 +33,7 @@ struct params { bool use_index_input = true; bool use_same_leading_bits = false; bool use_memory_pool = true; + double frac_infinities = 0.0; }; inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& @@ -41,8 +42,10 @@ inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& os << ", len: " << ss.len; os << ", k: " << ss.k; os << (ss.select_min ? ", asc" : ", dsc"); - os << (ss.use_index_input ? "" : ", no-input-index"); - os << (ss.use_same_leading_bits ? ", same-leading-bits}" : "}"); + if (!ss.use_index_input) { os << ", no-input-index"; } + if (ss.use_same_leading_bits) { os << ", same-leading-bits"; } + if (ss.frac_infinities > 0) { os << ", infs: " << ss.frac_infinities; } + os << "}"; return os; } diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 63f020b420..ce4e3e867e 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -70,6 +70,28 @@ auto inputs_random_largek = testing::Values(select::params{100, 100000, 1000, tr select::params{100, 100000, 2048, false}, select::params{100, 100000, 1237, true}); +auto inputs_random_many_infs = + testing::Values(select::params{10, 100000, 1, true, false, false, true, 0.9}, + select::params{10, 100000, 16, true, false, false, true, 0.9}, + select::params{10, 100000, 64, true, false, false, true, 0.9}, + select::params{10, 100000, 128, true, false, false, true, 0.9}, + select::params{10, 100000, 256, true, false, false, true, 0.9}, + select::params{1000, 10000, 1, true, false, false, true, 0.9}, + select::params{1000, 10000, 16, true, false, false, true, 0.9}, + select::params{1000, 10000, 64, true, false, false, true, 0.9}, + select::params{1000, 10000, 128, true, false, false, true, 0.9}, + select::params{1000, 10000, 256, true, false, false, true, 0.9}, + select::params{10, 100000, 1, true, false, false, true, 0.999}, + select::params{10, 100000, 16, true, false, false, true, 0.999}, + select::params{10, 100000, 64, true, false, false, true, 0.999}, + select::params{10, 100000, 128, true, false, false, true, 0.999}, + select::params{10, 100000, 256, true, false, false, true, 0.999}, + select::params{1000, 10000, 1, true, false, false, true, 0.999}, + select::params{1000, 10000, 16, true, false, false, true, 0.999}, + select::params{1000, 10000, 64, true, false, false, true, 0.999}, + select::params{1000, 10000, 128, true, false, false, true, 0.999}, + select::params{1000, 10000, 256, true, false, false, true, 0.999}); + using ReferencedRandomFloatInt = SelectK::params_random>; TEST_P(ReferencedRandomFloatInt, Run) { run(); } // NOLINT @@ -111,4 +133,16 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kRadix8bits, select::Algo::kRadix11bits, select::Algo::kRadix11bitsExtraPass))); + +using ReferencedRandomFloatIntkWarpsortAsGT = + SelectK::params_random>; +TEST_P(ReferencedRandomFloatIntkWarpsortAsGT, Run) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P( // NOLINT + SelectK, + ReferencedRandomFloatIntkWarpsortAsGT, + testing::Combine(inputs_random_many_infs, + testing::Values(select::Algo::kRadix8bits, + select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass))); + } // namespace raft::matrix diff --git a/cpp/test/matrix/select_k.cuh b/cpp/test/matrix/select_k.cuh index e0e0cad225..eaabbb3357 100644 --- a/cpp/test/matrix/select_k.cuh +++ b/cpp/test/matrix/select_k.cuh @@ -49,14 +49,16 @@ auto gen_simple_ids(uint32_t batch_size, uint32_t len) -> std::vector template struct io_simple { public: - bool not_supported = false; + bool not_supported = false; + std::optional algo = std::nullopt; io_simple(const select::params& spec, const std::vector& in_dists, + const std::optional>& in_ids, const std::vector& out_dists, const std::vector& out_ids) : in_dists_(in_dists), - in_ids_(gen_simple_ids(spec.batch_size, spec.len)), + in_ids_(in_ids.value_or(gen_simple_ids(spec.batch_size, spec.len))), out_dists_(out_dists), out_ids_(out_ids) { @@ -78,12 +80,14 @@ template struct io_computed { public: bool not_supported = false; + select::Algo algo; io_computed(const select::params& spec, const select::Algo& algo, const std::vector& in_dists, const std::optional>& in_ids = std::nullopt) - : in_dists_(in_dists), + : algo(algo), + in_dists_(in_dists), in_ids_(in_ids.value_or(gen_simple_ids(spec.batch_size, spec.len))), out_dists_(spec.batch_size * spec.k), out_ids_(spec.batch_size * spec.k) @@ -223,32 +227,62 @@ struct SelectK // NOLINT if (ref.not_supported || res.not_supported) { GTEST_SKIP(); } ASSERT_TRUE(hostVecMatch(ref.get_out_dists(), res.get_out_dists(), Compare())); - // If the dists (keys) are the same, different corresponding ids may end up in the selection due - // to non-deterministic nature of some implementations. - auto& in_ids = ref.get_in_ids(); - auto& in_dists = ref.get_in_dists(); - auto compare_ids = [&in_ids, &in_dists](const IdxT& i, const IdxT& j) { + // If the dists (keys) are the same, different corresponding ids may end up in the selection + // due to non-deterministic nature of some implementations. + auto compare_ids = [this](const IdxT& i, const IdxT& j) { if (i == j) return true; + auto& in_ids = ref.get_in_ids(); + auto& in_dists = ref.get_in_dists(); auto ix_i = static_cast(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); auto ix_j = static_cast(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); - if (static_cast(ix_i) >= in_ids.size() || static_cast(ix_j) >= in_ids.size()) - return false; + auto forgive_i = forgive_algo(ref.algo, i); + auto forgive_j = forgive_algo(res.algo, j); + // Some algorithms return invalid indices in special cases. + // This can be considered as TODO for us to fix. + if (static_cast(ix_i) >= in_ids.size()) return forgive_i; + if (static_cast(ix_j) >= in_ids.size()) return forgive_j; auto dist_i = in_dists[ix_i]; auto dist_j = in_dists[ix_j]; if (dist_i == dist_j) return true; + const auto bound = spec.select_min ? raft::upper_bound() : raft::lower_bound(); + if (forgive_i && dist_i == bound) return true; + if (forgive_j && dist_j == bound) return true; + // Otherwise really fail std::cout << "ERROR: ref[" << ix_i << "] = " << dist_i << " != " << "res[" << ix_j << "] = " << dist_j << std::endl; return false; }; ASSERT_TRUE(hostVecMatch(ref.get_out_ids(), res.get_out_ids(), compare_ids)); } + + auto forgive_algo(const std::optional& algo, IdxT ix) const -> bool + { + if (!algo.has_value()) { return false; } + switch (algo.value()) { + // not sure which algo this is. + case select::Algo::kPublicApi: return true; + // warp-sort-based algos currently return zero index for inf distances. + case select::Algo::kWarpAuto: + case select::Algo::kWarpImmediate: + case select::Algo::kWarpFiltered: + case select::Algo::kWarpDistributed: + case select::Algo::kWarpDistributedShm: return ix == 0; + // FAISS version returns a special invalid value: + case select::Algo::kFaissBlockSelect: return ix == std::numeric_limits::max(); + // Do not forgive by default + default: return false; + } + } }; template struct params_simple { - using io_t = io_simple; - using input_t = - std::tuple, std::vector, std::vector>; + using io_t = io_simple; + using input_t = std::tuple, + std::optional>, + std::vector, + std::vector>; using params_t = std::tuple; static auto read(params_t ps) -> Params @@ -259,15 +293,17 @@ struct params_simple { std::get<0>(ins), algo, io_simple( - std::get<0>(ins), std::get<1>(ins), std::get<2>(ins), std::get<3>(ins))); + std::get<0>(ins), std::get<1>(ins), std::get<2>(ins), std::get<3>(ins), std::get<4>(ins))); } }; +auto inf_f = std::numeric_limits::max(); auto inputs_simple_f = testing::Values( params_simple::input_t( {5, 5, 5, true, true}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), @@ -275,12 +311,14 @@ auto inputs_simple_f = testing::Values( {5, 5, 3, true, true}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), params_simple::input_t( {5, 5, 5, true, false}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), @@ -288,20 +326,31 @@ auto inputs_simple_f = testing::Values( {5, 5, 3, true, false}, {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), params_simple::input_t( {5, 7, 3, true, true}, {5.0, 4.0, 3.0, 2.0, 1.3, 7.5, 19.0, 9.0, 2.0, 3.0, 3.0, 5.0, 6.0, 4.0, 2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0, 5.0, 7.0, 2.5, 4.0, 7.0, 8.0, 8.0, 1.0, 3.0, 2.0, 5.0, 4.0, 1.1, 1.2}, + std::nullopt, {1.3, 2.0, 3.0, 2.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.5, 4.0, 5.0, 1.0, 1.1, 1.2}, {4, 3, 2, 1, 2, 3, 3, 5, 6, 2, 3, 0, 0, 5, 6}), - params_simple::input_t( - {1, 7, 3, true, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}), - params_simple::input_t( - {1, 7, 3, false, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}), - params_simple::input_t( - {1, 7, 3, false, true}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}), + params_simple::input_t({1, 7, 3, true, true}, + {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, + std::nullopt, + {1.0, 1.0, 1.0}, + {3, 5, 6}), + params_simple::input_t({1, 7, 3, false, false}, + {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, + std::nullopt, + {5.0, 4.0, 3.0}, + {2, 4, 1}), + params_simple::input_t({1, 7, 3, false, true}, + {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, + std::nullopt, + {9.0, 9.0, 9.0}, + {3, 5, 6}), params_simple::input_t( {1, 130, 5, false, true}, {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, @@ -309,6 +358,7 @@ auto inputs_simple_f = testing::Values( 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, + std::nullopt, {20, 19, 18, 17, 16}, {129, 0, 117, 116, 115}), params_simple::input_t( @@ -318,8 +368,20 @@ auto inputs_simple_f = testing::Values( 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, + std::nullopt, {20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, - {129, 0, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105})); + {129, 0, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105}), + params_simple::input_t( + select::params{1, 32, 31, true, true}, + {0, 1, 2, 3, inf_f, inf_f, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + std::optional{std::vector{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, + 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, + 9, 8, 7, 6, 75, 74, 3, 2, 1, 0}}, + {0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, inf_f}, + {31, 30, 29, 28, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, + 13, 12, 11, 10, 9, 8, 7, 6, 75, 74, 3, 2, 1, 0, 27})); using SimpleFloatInt = SelectK; TEST_P(SimpleFloatInt, Run) { run(); } // NOLINT @@ -335,6 +397,12 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kWarpFiltered, select::Algo::kWarpDistributed))); +template +struct replace_with_mask { + KeyT replacement; + constexpr auto inline operator()(KeyT x, uint8_t mask) -> KeyT { return mask ? replacement : x; } +}; + template struct with_ref { template @@ -354,6 +422,19 @@ struct with_ref { rmm::device_uvector dists_d(spec.len * spec.batch_size, s); raft::random::RngState r(42); normal(handle, r, dists_d.data(), dists_d.size(), KeyT(10.0), KeyT(100.0)); + + if (spec.frac_infinities > 0.0) { + rmm::device_uvector mask_buf(dists_d.size(), s); + auto mask = make_device_vector_view(mask_buf.data(), mask_buf.size()); + raft::random::bernoulli(handle, r, mask, spec.frac_infinities); + KeyT bound = spec.select_min ? raft::upper_bound() : raft::lower_bound(); + auto mask_in = + make_device_vector_view(mask_buf.data(), mask_buf.size()); + auto dists_in = make_device_vector_view(dists_d.data(), dists_d.size()); + auto dists_out = make_device_vector_view(dists_d.data(), dists_d.size()); + raft::linalg::map(handle, dists_out, replace_with_mask{bound}, dists_in, mask_in); + } + update_host(dists.data(), dists_d.data(), dists_d.size(), s); s.synchronize(); }