Skip to content

Commit

Permalink
matrix::select_k: extra tests and benchmarks (#1821)
Browse files Browse the repository at this point in the history
Add a few extra test and benchmark cases; in particular:
  1. Allow specifying non-trivial input indices
  2. Allow filling the input data with infinities to see how algorithms perform in edge cases

These tests are borrowed from the controversial workaround #1742

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1821
  • Loading branch information
achirkin authored Sep 18, 2023
1 parent 28b7894 commit b9cf917
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 24 deletions.
64 changes: 63 additions & 1 deletion cpp/bench/prims/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <common/benchmark.hpp>

#include <raft/core/device_resources.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/random/rng.cuh>
#include <raft/sparse/detail/utils.h>
#include <raft/util/cudart_utils.hpp>
Expand All @@ -38,6 +39,19 @@
namespace raft::matrix {
using namespace raft::bench; // NOLINT

template <typename KeyT>
struct replace_with_mask {
KeyT replacement;
int64_t line_length;
int64_t spared_inputs;
constexpr auto inline operator()(int64_t offset, KeyT x, uint8_t mask) -> KeyT
{
auto i = offset % line_length;
// don't replace all the inputs, spare a few elements at the beginning of the input
return (mask && i >= spared_inputs) ? replacement : x;
}
};

template <typename KeyT, typename IdxT, select::Algo Algo>
struct selection : public fixture {
explicit selection(const select::params& p)
Expand Down Expand Up @@ -67,6 +81,21 @@ struct selection : public fixture {
}
}
raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value);
if (p.frac_infinities > 0.0) {
rmm::device_uvector<uint8_t> mask_buf(p.batch_size * p.len, stream);
auto mask = make_device_vector_view<uint8_t, size_t>(mask_buf.data(), mask_buf.size());
raft::random::bernoulli(handle, state, mask, p.frac_infinities);
KeyT bound = p.select_min ? raft::upper_bound<KeyT>() : raft::lower_bound<KeyT>();
auto mask_in =
make_device_vector_view<const uint8_t, size_t>(mask_buf.data(), mask_buf.size());
auto dists_in = make_device_vector_view<const KeyT>(in_dists_.data(), in_dists_.size());
auto dists_out = make_device_vector_view<KeyT>(in_dists_.data(), in_dists_.size());
raft::linalg::map_offset(handle,
dists_out,
replace_with_mask<KeyT>{bound, int64_t(p.len), int64_t(p.k / 2)},
dists_in,
mask_in);
}
}

void run_benchmark(::benchmark::State& state) override // NOLINT
Expand All @@ -75,8 +104,12 @@ struct selection : public fixture {
std::ostringstream label_stream;
label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k;
if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; }
if (params_.frac_infinities > 0) { label_stream << "#infs-" << params_.frac_infinities; }
state.SetLabel(label_stream.str());
loop_on_state(state, [this]() {
common::nvtx::range case_scope("%s - %s", state.name().c_str(), label_stream.str().c_str());
int iter = 0;
loop_on_state(state, [&iter, this]() {
common::nvtx::range lap_scope("lap-", iter++);
select::select_k_impl<KeyT, IdxT>(handle,
Algo,
in_dists_.data(),
Expand Down Expand Up @@ -149,6 +182,35 @@ const std::vector<select::params> kInputs{
{10, 1000000, 64, true, false, true},
{10, 1000000, 128, true, false, true},
{10, 1000000, 256, true, false, true},

{10, 1000000, 1, true, false, false, true, 0.1},
{10, 1000000, 16, true, false, false, true, 0.1},
{10, 1000000, 64, true, false, false, true, 0.1},
{10, 1000000, 128, true, false, false, true, 0.1},
{10, 1000000, 256, true, false, false, true, 0.1},

{10, 1000000, 1, true, false, false, true, 0.9},
{10, 1000000, 16, true, false, false, true, 0.9},
{10, 1000000, 64, true, false, false, true, 0.9},
{10, 1000000, 128, true, false, false, true, 0.9},
{10, 1000000, 256, true, false, false, true, 0.9},
{1000, 10000, 1, true, false, false, true, 0.9},
{1000, 10000, 16, true, false, false, true, 0.9},
{1000, 10000, 64, true, false, false, true, 0.9},
{1000, 10000, 128, true, false, false, true, 0.9},
{1000, 10000, 256, true, false, false, true, 0.9},

{10, 1000000, 1, true, false, false, true, 1.0},
{10, 1000000, 16, true, false, false, true, 1.0},
{10, 1000000, 64, true, false, false, true, 1.0},
{10, 1000000, 128, true, false, false, true, 1.0},
{10, 1000000, 256, true, false, false, true, 1.0},
{1000, 10000, 1, true, false, false, true, 1.0},
{1000, 10000, 16, true, false, false, true, 1.0},
{1000, 10000, 64, true, false, false, true, 1.0},
{1000, 10000, 128, true, false, false, true, 1.0},
{1000, 10000, 256, true, false, false, true, 1.0},
{1000, 10000, 256, true, false, false, true, 0.999},
};

#define SELECTION_REGISTER(KeyT, IdxT, A) \
Expand Down
7 changes: 5 additions & 2 deletions cpp/internal/raft_internal/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct params {
bool use_index_input = true;
bool use_same_leading_bits = false;
bool use_memory_pool = true;
double frac_infinities = 0.0;
};

inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream&
Expand All @@ -41,8 +42,10 @@ inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream&
os << ", len: " << ss.len;
os << ", k: " << ss.k;
os << (ss.select_min ? ", asc" : ", dsc");
os << (ss.use_index_input ? "" : ", no-input-index");
os << (ss.use_same_leading_bits ? ", same-leading-bits}" : "}");
if (!ss.use_index_input) { os << ", no-input-index"; }
if (ss.use_same_leading_bits) { os << ", same-leading-bits"; }
if (ss.frac_infinities > 0) { os << ", infs: " << ss.frac_infinities; }
os << "}";
return os;
}

Expand Down
34 changes: 34 additions & 0 deletions cpp/test/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,28 @@ auto inputs_random_largek = testing::Values(select::params{100, 100000, 1000, tr
select::params{100, 100000, 2048, false},
select::params{100, 100000, 1237, true});

auto inputs_random_many_infs =
testing::Values(select::params{10, 100000, 1, true, false, false, true, 0.9},
select::params{10, 100000, 16, true, false, false, true, 0.9},
select::params{10, 100000, 64, true, false, false, true, 0.9},
select::params{10, 100000, 128, true, false, false, true, 0.9},
select::params{10, 100000, 256, true, false, false, true, 0.9},
select::params{1000, 10000, 1, true, false, false, true, 0.9},
select::params{1000, 10000, 16, true, false, false, true, 0.9},
select::params{1000, 10000, 64, true, false, false, true, 0.9},
select::params{1000, 10000, 128, true, false, false, true, 0.9},
select::params{1000, 10000, 256, true, false, false, true, 0.9},
select::params{10, 100000, 1, true, false, false, true, 0.999},
select::params{10, 100000, 16, true, false, false, true, 0.999},
select::params{10, 100000, 64, true, false, false, true, 0.999},
select::params{10, 100000, 128, true, false, false, true, 0.999},
select::params{10, 100000, 256, true, false, false, true, 0.999},
select::params{1000, 10000, 1, true, false, false, true, 0.999},
select::params{1000, 10000, 16, true, false, false, true, 0.999},
select::params{1000, 10000, 64, true, false, false, true, 0.999},
select::params{1000, 10000, 128, true, false, false, true, 0.999},
select::params{1000, 10000, 256, true, false, false, true, 0.999});

using ReferencedRandomFloatInt =
SelectK<float, uint32_t, with_ref<select::Algo::kPublicApi>::params_random>;
TEST_P(ReferencedRandomFloatInt, Run) { run(); } // NOLINT
Expand Down Expand Up @@ -111,4 +133,16 @@ INSTANTIATE_TEST_CASE_P( // NOLINT
select::Algo::kRadix8bits,
select::Algo::kRadix11bits,
select::Algo::kRadix11bitsExtraPass)));

using ReferencedRandomFloatIntkWarpsortAsGT =
SelectK<float, uint32_t, with_ref<select::Algo::kWarpImmediate>::params_random>;
TEST_P(ReferencedRandomFloatIntkWarpsortAsGT, Run) { run(); } // NOLINT
INSTANTIATE_TEST_CASE_P( // NOLINT
SelectK,
ReferencedRandomFloatIntkWarpsortAsGT,
testing::Combine(inputs_random_many_infs,
testing::Values(select::Algo::kRadix8bits,
select::Algo::kRadix11bits,
select::Algo::kRadix11bitsExtraPass)));

} // namespace raft::matrix
Loading

0 comments on commit b9cf917

Please sign in to comment.