From 0aecdee0c74ef3ea8769e2a2953638aeebc97c51 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 8 Nov 2024 17:03:31 -0500 Subject: [PATCH] Select k back --- cpp/bench/prims/CMakeLists.txt | 4 +- cpp/bench/prims/matrix/select_k.cu | 345 ++++++++++++++++++++++ cpp/test/CMakeLists.txt | 2 + cpp/test/matrix/select_k.cu | 148 ++++++++++ cpp/test/matrix/select_k.cuh | 450 +++++++++++++++++++++++++++++ cpp/test/matrix/select_large_k.cu | 36 +++ 6 files changed, 983 insertions(+), 2 deletions(-) create mode 100644 cpp/bench/prims/matrix/select_k.cu create mode 100644 cpp/test/matrix/select_k.cu create mode 100644 cpp/test/matrix/select_k.cuh create mode 100644 cpp/test/matrix/select_large_k.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 73fbb6d8ce..6bc8c802b4 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -97,8 +97,8 @@ if(BUILD_PRIMS_BENCH) ) ConfigureBench( - NAME MATRIX_BENCH PATH matrix/argmin.cu matrix/gather.cu main.cpp OPTIONAL LIB - EXPLICIT_INSTANTIATE_ONLY + NAME MATRIX_BENCH PATH matrix/argmin.cu matrix/select_k.cu matrix/gather.cu main.cpp OPTIONAL + LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( diff --git a/cpp/bench/prims/matrix/select_k.cu b/cpp/bench/prims/matrix/select_k.cu new file mode 100644 index 0000000000..ff04e6f8a8 --- /dev/null +++ b/cpp/bench/prims/matrix/select_k.cu @@ -0,0 +1,345 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +namespace raft::matrix { +using namespace raft::bench; // NOLINT + +template +struct replace_with_mask { + KeyT replacement; + int64_t line_length; + int64_t spared_inputs; + constexpr auto inline operator()(int64_t offset, KeyT x, uint8_t mask) -> KeyT + { + auto i = offset % line_length; + // don't replace all the inputs, spare a few elements at the beginning of the input + return (mask && i >= spared_inputs) ? replacement : x; + } +}; + +template +struct selection : public fixture { + explicit selection(const select::params& p) + : fixture(p.use_memory_pool), + params_(p), + in_dists_(p.batch_size * p.len, stream), + in_ids_(p.batch_size * p.len, stream), + out_dists_(p.batch_size * p.k, stream), + out_ids_(p.batch_size * p.k, stream) + { + raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream); + raft::random::RngState state{42}; + + KeyT min_value = -1.0; + KeyT max_value = 1.0; + if (p.use_same_leading_bits) { + if constexpr (std::is_same_v) { + uint32_t min_bits = 0x3F800000; // 1.0 + uint32_t max_bits = 0x3F8000FF; // 1.00003 + memcpy(&min_value, &min_bits, sizeof(KeyT)); + memcpy(&max_value, &max_bits, sizeof(KeyT)); + } else if constexpr (std::is_same_v) { + uint64_t min_bits = 0x3FF0000000000000; // 1.0 + uint64_t max_bits = 0x3FF0000FFFFFFFFF; // 1.000015 + memcpy(&min_value, &min_bits, sizeof(KeyT)); + memcpy(&max_value, &max_bits, sizeof(KeyT)); + } + } + raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value); + if (p.frac_infinities > 0.0) { + rmm::device_uvector mask_buf(p.batch_size * p.len, stream); + auto mask = make_device_vector_view(mask_buf.data(), mask_buf.size()); + raft::random::bernoulli(handle, state, 
mask, p.frac_infinities); + KeyT bound = p.select_min ? raft::upper_bound() : raft::lower_bound(); + auto mask_in = + make_device_vector_view(mask_buf.data(), mask_buf.size()); + auto dists_in = make_device_vector_view(in_dists_.data(), in_dists_.size()); + auto dists_out = make_device_vector_view(in_dists_.data(), in_dists_.size()); + raft::linalg::map_offset(handle, + dists_out, + replace_with_mask{bound, int64_t(p.len), int64_t(p.k / 2)}, + dists_in, + mask_in); + } + } + + void run_benchmark(::benchmark::State& state) override // NOLINT + { + try { + std::ostringstream label_stream; + label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; + if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; } + if (params_.frac_infinities > 0) { label_stream << "#infs-" << params_.frac_infinities; } + state.SetLabel(label_stream.str()); + common::nvtx::range case_scope("%s - %s", state.name().c_str(), label_stream.str().c_str()); + int iter = 0; + loop_on_state(state, [&iter, this]() { + common::nvtx::range lap_scope("lap-", iter++); + + std::optional> in_ids_view; + if (params_.use_index_input) { + in_ids_view = raft::make_device_matrix_view( + in_ids_.data(), params_.batch_size, params_.len); + } + + matrix::select_k(handle, + raft::make_device_matrix_view( + in_dists_.data(), params_.batch_size, params_.len), + in_ids_view, + raft::make_device_matrix_view( + out_dists_.data(), params_.batch_size, params_.k), + raft::make_device_matrix_view( + out_ids_.data(), params_.batch_size, params_.k), + params_.select_min, + false, + Algo); + }); + } catch (raft::exception& e) { + state.SkipWithError(e.what()); + } + } + + private: + const select::params params_; + rmm::device_uvector in_dists_, out_dists_; + rmm::device_uvector in_ids_, out_ids_; +}; + +const std::vector kInputs{ + {20000, 500, 1, true}, + {20000, 500, 2, true}, + {20000, 500, 4, true}, + {20000, 500, 8, true}, + {20000, 500, 16, true}, + {20000, 500, 32, true}, + {20000, 
500, 64, true}, + {20000, 500, 128, true}, + {20000, 500, 256, true}, + + {1000, 10000, 1, true}, + {1000, 10000, 2, true}, + {1000, 10000, 4, true}, + {1000, 10000, 8, true}, + {1000, 10000, 16, true}, + {1000, 10000, 32, true}, + {1000, 10000, 64, true}, + {1000, 10000, 128, true}, + {1000, 10000, 256, true}, + + {100, 100000, 1, true}, + {100, 100000, 2, true}, + {100, 100000, 4, true}, + {100, 100000, 8, true}, + {100, 100000, 16, true}, + {100, 100000, 32, true}, + {100, 100000, 64, true}, + {100, 100000, 128, true}, + {100, 100000, 256, true}, + + {10, 1000000, 1, true}, + {10, 1000000, 2, true}, + {10, 1000000, 4, true}, + {10, 1000000, 8, true}, + {10, 1000000, 16, true}, + {10, 1000000, 32, true}, + {10, 1000000, 64, true}, + {10, 1000000, 128, true}, + {10, 1000000, 256, true}, + + {10, 1000000, 1, true, false, true}, + {10, 1000000, 2, true, false, true}, + {10, 1000000, 4, true, false, true}, + {10, 1000000, 8, true, false, true}, + {10, 1000000, 16, true, false, true}, + {10, 1000000, 32, true, false, true}, + {10, 1000000, 64, true, false, true}, + {10, 1000000, 128, true, false, true}, + {10, 1000000, 256, true, false, true}, + + {10, 1000000, 1, true, false, false, true, 0.1}, + {10, 1000000, 16, true, false, false, true, 0.1}, + {10, 1000000, 64, true, false, false, true, 0.1}, + {10, 1000000, 128, true, false, false, true, 0.1}, + {10, 1000000, 256, true, false, false, true, 0.1}, + + {10, 1000000, 1, true, false, false, true, 0.9}, + {10, 1000000, 16, true, false, false, true, 0.9}, + {10, 1000000, 64, true, false, false, true, 0.9}, + {10, 1000000, 128, true, false, false, true, 0.9}, + {10, 1000000, 256, true, false, false, true, 0.9}, + {1000, 10000, 1, true, false, false, true, 0.9}, + {1000, 10000, 16, true, false, false, true, 0.9}, + {1000, 10000, 64, true, false, false, true, 0.9}, + {1000, 10000, 128, true, false, false, true, 0.9}, + {1000, 10000, 256, true, false, false, true, 0.9}, + + {10, 1000000, 1, true, false, false, true, 1.0}, 
+ {10, 1000000, 16, true, false, false, true, 1.0}, + {10, 1000000, 64, true, false, false, true, 1.0}, + {10, 1000000, 128, true, false, false, true, 1.0}, + {10, 1000000, 256, true, false, false, true, 1.0}, + {1000, 10000, 1, true, false, false, true, 1.0}, + {1000, 10000, 16, true, false, false, true, 1.0}, + {1000, 10000, 64, true, false, false, true, 1.0}, + {1000, 10000, 128, true, false, false, true, 1.0}, + {1000, 10000, 256, true, false, false, true, 1.0}, + {1000, 10000, 256, true, false, false, true, 0.999}, +}; + +#define SELECTION_REGISTER(KeyT, IdxT, A) \ + namespace BENCHMARK_PRIVATE_NAME(selection) { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ + } + +SELECTION_REGISTER(float, uint32_t, kAuto); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT + +SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, uint32_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT + +SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int64_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, int64_t, 
kWarpDistributedShm); // NOLINT + +// For learning a heuristic of which selection algorithm to use, we +// have a couple of additional constraints when generating the dataset: +// 1. We want these benchmarks to be optionally enabled from the commandline - +// there are thousands of them, and the run-time is non-trivial. This should be opt-in only +// 2. We test out larger k values - that won't work for all algorithms. This requires filtering +// the input parameters per algorithm. +// This makes the code to generate this dataset different from the code above to +// register other benchmarks +#define SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, A, input) \ + { \ + using SelectK = selection; \ + std::stringstream name; \ + name << "SelectKDataset/" << #KeyT "/" #IdxT "/" #A << "/" << input.batch_size << "/" \ + << input.len << "/" << input.k << "/" << input.use_index_input << "/" \ + << input.use_memory_pool; \ + auto* b = ::benchmark::internal::RegisterBenchmarkInternal( \ + new raft::bench::internal::Fixture(name.str(), input)); \ + b->UseManualTime(); \ + b->Unit(benchmark::kMillisecond); \ + } + +const static size_t MAX_MEMORY = 16 * 1024 * 1024 * 1024ULL; + +// registers the input for all algorithms +#define SELECTION_REGISTER_INPUT(KeyT, IdxT, input) \ + { \ + size_t mem = input.batch_size * input.len * (sizeof(KeyT) + sizeof(IdxT)); \ + if (mem < MAX_MEMORY) { \ + SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kRadix8bits, input) \ + SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kRadix11bits, input) \ + SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kRadix11bitsExtraPass, input) \ + if (input.k <= raft::matrix::detail::select::warpsort::kMaxCapacity) { \ + SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kWarpImmediate, input) \ + SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kWarpFiltered, input) \ + SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kWarpDistributed, input) \ + SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, kWarpDistributedShm, input) \ + } \ + } \ + } + +void 
add_select_k_dataset_benchmarks() +{ + // define a uniform grid + std::vector inputs; + + size_t grid_increment = 1; + std::vector k_vals; + for (size_t k = 0; k < 13; k += grid_increment) { + k_vals.push_back(1 << k); + } + // Add in values just past the limit for warp/faiss select + k_vals.push_back(257); + k_vals.push_back(2049); + + const static bool select_min = true; + const static bool use_ids = false; + + for (size_t row = 0; row < 13; row += grid_increment) { + for (size_t col = 10; col < 28; col += grid_increment) { + for (auto k : k_vals) { + inputs.push_back( + select::params{size_t(1 << row), size_t(1 << col), k, select_min, use_ids}); + } + } + } + + // also add in some random values + std::default_random_engine rng(42); + std::uniform_real_distribution<> row_dist(0, 13); + std::uniform_real_distribution<> col_dist(10, 28); + std::uniform_real_distribution<> k_dist(0, 13); + for (size_t i = 0; i < 1024; ++i) { + auto row = static_cast(pow(2, row_dist(rng))); + auto col = static_cast(pow(2, col_dist(rng))); + auto k = static_cast(pow(2, k_dist(rng))); + inputs.push_back(select::params{row, col, k, select_min, use_ids}); + } + + for (auto& input : inputs) { + SELECTION_REGISTER_INPUT(double, int64_t, input); + SELECTION_REGISTER_INPUT(double, uint32_t, input); + SELECTION_REGISTER_INPUT(float, int64_t, input); + SELECTION_REGISTER_INPUT(float, uint32_t, input); + } + + // also try again without a memory pool to see if there are significant differences + for (auto input : inputs) { + input.use_memory_pool = false; + SELECTION_REGISTER_INPUT(double, int64_t, input); + SELECTION_REGISTER_INPUT(double, uint32_t, input); + SELECTION_REGISTER_INPUT(float, int64_t, input); + SELECTION_REGISTER_INPUT(float, uint32_t, input); + } +} +} // namespace raft::matrix diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 939c299e24..3a20598d95 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -194,6 +194,8 @@ if(BUILD_TESTS) 
matrix/sample_rows.cu matrix/slice.cu matrix/triangular.cu + matrix/select_k.cu + matrix/select_large_k.cu sparse/spectral_matrix.cu LIB EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu new file mode 100644 index 0000000000..f3eb32b2e1 --- /dev/null +++ b/cpp/test/matrix/select_k.cu @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "select_k.cuh" + +namespace raft::matrix { + +auto inputs_random_longlist = testing::Values(select::params{1, 130, 15, false}, + select::params{1, 128, 15, false}, + select::params{20, 700, 1, true}, + select::params{20, 700, 2, true}, + select::params{20, 700, 3, true}, + select::params{20, 700, 4, true}, + select::params{20, 700, 5, true}, + select::params{20, 700, 6, true}, + select::params{20, 700, 7, true}, + select::params{20, 700, 8, true}, + select::params{20, 700, 9, true}, + select::params{20, 700, 10, true, false}, + select::params{20, 700, 11, true}, + select::params{20, 700, 12, true}, + select::params{20, 700, 16, true}, + select::params{100, 1700, 17, true}, + select::params{100, 1700, 31, true, false}, + select::params{100, 1700, 32, false}, + select::params{100, 1700, 33, false}, + select::params{100, 1700, 63, false}, + select::params{100, 1700, 64, false, false}, + select::params{100, 1700, 65, false}, + select::params{100, 1700, 255, true}, + select::params{100, 1700, 256, true}, 
+ select::params{100, 1700, 511, false}, + select::params{100, 1700, 512, true}, + select::params{100, 1700, 1023, false, false}, + select::params{100, 1700, 1024, true}, + select::params{100, 1700, 1700, true}); + +auto inputs_random_largesize = testing::Values(select::params{100, 100000, 1, true}, + select::params{100, 100000, 2, true}, + select::params{100, 100000, 3, true, false}, + select::params{100, 100000, 7, true}, + select::params{100, 100000, 16, true}, + select::params{100, 100000, 31, true}, + select::params{100, 100000, 32, true, false}, + select::params{100, 100000, 60, true}, + select::params{100, 100000, 100, true, false}, + select::params{100, 100000, 200, true}, + select::params{100000, 100, 100, false}, + select::params{1, 1000000000, 1, true}, + select::params{1, 1000000000, 16, false, false}, + select::params{1, 1000000000, 64, false}, + select::params{1, 1000000000, 128, true, false}, + select::params{1, 1000000000, 256, false, false}); + +auto inputs_random_largek = testing::Values(select::params{100, 100000, 1000, true}, + select::params{100, 100000, 2000, false}, + select::params{100, 100000, 100000, true, false}, + select::params{100, 100000, 2048, false}, + select::params{100, 100000, 1237, true}); + +auto inputs_random_many_infs = + testing::Values(select::params{10, 100000, 1, true, false, false, true, 0.9}, + select::params{10, 100000, 16, true, false, false, true, 0.9}, + select::params{10, 100000, 64, true, false, false, true, 0.9}, + select::params{10, 100000, 128, true, false, false, true, 0.9}, + select::params{10, 100000, 256, true, false, false, true, 0.9}, + select::params{1000, 10000, 1, true, false, false, true, 0.9}, + select::params{1000, 10000, 16, true, false, false, true, 0.9}, + select::params{1000, 10000, 64, true, false, false, true, 0.9}, + select::params{1000, 10000, 128, true, false, false, true, 0.9}, + select::params{1000, 10000, 256, true, false, false, true, 0.9}, + select::params{10, 100000, 1, true, false, 
false, true, 0.999}, + select::params{10, 100000, 16, true, false, false, true, 0.999}, + select::params{10, 100000, 64, true, false, false, true, 0.999}, + select::params{10, 100000, 128, true, false, false, true, 0.999}, + select::params{10, 100000, 256, true, false, false, true, 0.999}, + select::params{1000, 10000, 1, true, false, false, true, 0.999}, + select::params{1000, 10000, 16, true, false, false, true, 0.999}, + select::params{1000, 10000, 64, true, false, false, true, 0.999}, + select::params{1000, 10000, 128, true, false, false, true, 0.999}, + select::params{1000, 10000, 256, true, false, false, true, 0.999}); + +using ReferencedRandomFloatInt = + SelectK::params_random>; +TEST_P(ReferencedRandomFloatInt, Run) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P( // NOLINT + SelectK, + ReferencedRandomFloatInt, + testing::Combine(inputs_random_longlist, + testing::Values(SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass, + SelectAlgo::kWarpImmediate, + SelectAlgo::kWarpFiltered, + SelectAlgo::kWarpDistributed, + SelectAlgo::kWarpDistributedShm))); + +using ReferencedRandomDoubleSizeT = + SelectK::params_random>; +TEST_P(ReferencedRandomDoubleSizeT, Run) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P( // NOLINT + SelectK, + ReferencedRandomDoubleSizeT, + testing::Combine(inputs_random_longlist, + testing::Values(SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass, + SelectAlgo::kWarpImmediate, + SelectAlgo::kWarpFiltered, + SelectAlgo::kWarpDistributed, + SelectAlgo::kWarpDistributedShm))); + +using ReferencedRandomDoubleInt = + SelectK::params_random>; +TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P( // NOLINT + SelectK, + ReferencedRandomDoubleInt, + testing::Combine(inputs_random_largesize, + testing::Values(SelectAlgo::kWarpAuto, + SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass))); + +using 
ReferencedRandomFloatIntkWarpsortAsGT = + SelectK::params_random>; +TEST_P(ReferencedRandomFloatIntkWarpsortAsGT, Run) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P( // NOLINT + SelectK, + ReferencedRandomFloatIntkWarpsortAsGT, + testing::Combine(inputs_random_many_infs, + testing::Values(SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass))); + +} // namespace raft::matrix diff --git a/cpp/test/matrix/select_k.cuh b/cpp/test/matrix/select_k.cuh new file mode 100644 index 0000000000..f22f4f5fa7 --- /dev/null +++ b/cpp/test/matrix/select_k.cuh @@ -0,0 +1,450 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include + +#include + +#include + +#include + +#include +#include + +namespace raft::matrix { + +template +auto gen_simple_ids(uint32_t batch_size, uint32_t len) -> std::vector +{ + std::vector out(batch_size * len); + auto s = rmm::cuda_stream_default; + rmm::device_uvector out_d(out.size(), s); + sparse::iota_fill(out_d.data(), IdxT(batch_size), IdxT(len), s); + update_host(out.data(), out_d.data(), out.size(), s); + s.synchronize(); + return out; +} + +template +struct io_simple { + public: + bool not_supported = false; + std::optional algo = std::nullopt; + + io_simple(const select::params& spec, + const std::vector& in_dists, + const std::optional>& in_ids, + const std::vector& out_dists, + const std::vector& out_ids) + : in_dists_(in_dists), + in_ids_(in_ids.value_or(gen_simple_ids(spec.batch_size, spec.len))), + out_dists_(out_dists), + out_ids_(out_ids) + { + } + + auto get_in_dists() -> std::vector& { return in_dists_; } + auto get_in_ids() -> std::vector& { return in_ids_; } + auto get_out_dists() -> std::vector& { return out_dists_; } + auto get_out_ids() -> std::vector& { return out_ids_; } + + private: + std::vector in_dists_; + std::vector in_ids_; + std::vector out_dists_; + std::vector out_ids_; +}; + +template +struct io_computed { + public: + bool not_supported = false; + SelectAlgo algo; + + io_computed(const select::params& spec, + const SelectAlgo& algo, + const std::vector& in_dists, + const std::optional>& in_ids = std::nullopt) + : algo(algo), + in_dists_(in_dists), + in_ids_(in_ids.value_or(gen_simple_ids(spec.batch_size, spec.len))), + out_dists_(spec.batch_size * spec.k), + out_ids_(spec.batch_size * spec.k) + { + // check if the size is supported by the algorithm + switch (algo) { + case SelectAlgo::kWarpAuto: + case SelectAlgo::kWarpImmediate: + case SelectAlgo::kWarpFiltered: + case SelectAlgo::kWarpDistributed: + case SelectAlgo::kWarpDistributedShm: { + 
if (spec.k > raft::matrix::detail::select::warpsort::kMaxCapacity) { + not_supported = true; + return; + } + } break; + default: break; + } + + resources handle{}; + auto stream = resource::get_cuda_stream(handle); + + rmm::device_uvector in_dists_d(in_dists_.size(), stream); + rmm::device_uvector in_ids_d(in_ids_.size(), stream); + rmm::device_uvector out_dists_d(out_dists_.size(), stream); + rmm::device_uvector out_ids_d(out_ids_.size(), stream); + + update_device(in_dists_d.data(), in_dists_.data(), in_dists_.size(), stream); + update_device(in_ids_d.data(), in_ids_.data(), in_ids_.size(), stream); + + std::optional> in_ids_view; + if (spec.use_index_input) { + in_ids_view = raft::make_device_matrix_view( + in_ids_d.data(), spec.batch_size, spec.len); + } + + matrix::select_k( + handle, + raft::make_device_matrix_view( + in_dists_d.data(), spec.batch_size, spec.len), + in_ids_view, + raft::make_device_matrix_view(out_dists_d.data(), spec.batch_size, spec.k), + raft::make_device_matrix_view(out_ids_d.data(), spec.batch_size, spec.k), + spec.select_min, + false, + algo); + + update_host(out_dists_.data(), out_dists_d.data(), out_dists_.size(), stream); + update_host(out_ids_.data(), out_ids_d.data(), out_ids_.size(), stream); + + interruptible::synchronize(stream); + + auto p = topk_sort_permutation(out_dists_, out_ids_, spec.k, spec.select_min); + apply_permutation(out_dists_, p); + apply_permutation(out_ids_, p); + } + + auto get_in_dists() -> std::vector& { return in_dists_; } + auto get_in_ids() -> std::vector& { return in_ids_; } + auto get_out_dists() -> std::vector& { return out_dists_; } + auto get_out_ids() -> std::vector& { return out_ids_; } + + private: + std::vector in_dists_; + std::vector in_ids_; + std::vector out_dists_; + std::vector out_ids_; + + auto topk_sort_permutation(const std::vector& vec, + const std::vector& inds, + uint32_t k, + bool select_min) -> std::vector + { + std::vector p(vec.size()); + std::iota(p.begin(), p.end(), 0); + if 
(select_min) { + std::sort(p.begin(), p.end(), [&vec, &inds, k](IdxT i, IdxT j) { + const IdxT ik = i / k; + const IdxT jk = j / k; + if (ik == jk) { + if (vec[i] == vec[j]) { return inds[i] < inds[j]; } + return vec[i] < vec[j]; + } + return ik < jk; + }); + } else { + std::sort(p.begin(), p.end(), [&vec, &inds, k](IdxT i, IdxT j) { + const IdxT ik = i / k; + const IdxT jk = j / k; + if (ik == jk) { + if (vec[i] == vec[j]) { return inds[i] < inds[j]; } + return vec[i] > vec[j]; + } + return ik < jk; + }); + } + return p; + } + + template + void apply_permutation(std::vector& vec, const std::vector& p) // NOLINT + { + for (auto i = IdxT(vec.size()) - 1; i > 0; i--) { + auto j = p[i]; + while (j > i) + j = p[j]; + std::swap(vec[j], vec[i]); + } + } +}; + +template +using Params = std::tuple; + +template typename ParamsReader> +struct SelectK // NOLINT + : public testing::TestWithParam::params_t> { + const select::params spec; + const SelectAlgo algo; + typename ParamsReader::io_t ref; + io_computed res; + + explicit SelectK(Params::io_t> ps) + : spec(std::get<0>(ps)), + algo(std::get<1>(ps)), // NOLINT + ref(std::get<2>(ps)), // NOLINT + res(spec, algo, ref.get_in_dists(), ref.get_in_ids()) // NOLINT + { + } + + explicit SelectK(typename ParamsReader::params_t ps) + : SelectK(ParamsReader::read(ps)) + { + } + + SelectK() + : SelectK(testing::TestWithParam::params_t>::GetParam()) + { + } + + void run() + { + if (ref.not_supported || res.not_supported) { GTEST_SKIP(); } + ASSERT_TRUE(hostVecMatch(ref.get_out_dists(), res.get_out_dists(), Compare())); + + // If the dists (keys) are the same, different corresponding ids may end up in the selection + // due to non-deterministic nature of some implementations. 
+ auto compare_ids = [this](const IdxT& i, const IdxT& j) { + if (i == j) return true; + auto& in_ids = ref.get_in_ids(); + auto& in_dists = ref.get_in_dists(); + auto ix_i = static_cast(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); + auto ix_j = static_cast(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); + auto forgive_i = forgive_algo(ref.algo, i); + auto forgive_j = forgive_algo(res.algo, j); + // Some algorithms return invalid indices in special cases. + // TODO: https://github.com/rapidsai/raft/issues/1822 + if (static_cast(ix_i) >= in_ids.size()) return forgive_i; + if (static_cast(ix_j) >= in_ids.size()) return forgive_j; + auto dist_i = in_dists[ix_i]; + auto dist_j = in_dists[ix_j]; + if (dist_i == dist_j) return true; + const auto bound = spec.select_min ? raft::upper_bound() : raft::lower_bound(); + if (forgive_i && dist_i == bound) return true; + if (forgive_j && dist_j == bound) return true; + // Otherwise really fail + std::cout << "ERROR: ref[" << ix_i << "] = " << dist_i << " != " + << "res[" << ix_j << "] = " << dist_j << std::endl; + return false; + }; + ASSERT_TRUE(hostVecMatch(ref.get_out_ids(), res.get_out_ids(), compare_ids)); + } + + auto forgive_algo(const std::optional& algo, IdxT ix) const -> bool + { + if (!algo.has_value()) { return false; } + switch (algo.value()) { + // not sure which algo this is. + case SelectAlgo::kAuto: return true; + // warp-sort-based algos currently return zero index for inf distances. 
+ case SelectAlgo::kWarpAuto: + case SelectAlgo::kWarpImmediate: + case SelectAlgo::kWarpFiltered: + case SelectAlgo::kWarpDistributed: + case SelectAlgo::kWarpDistributedShm: return ix == 0; + // Do not forgive by default + default: return false; + } + } +}; + +template +struct params_simple { + using io_t = io_simple; + using input_t = std::tuple, + std::optional>, + std::vector, + std::vector>; + using params_t = std::tuple; + + static auto read(params_t ps) -> Params + { + auto ins = std::get<0>(ps); + auto algo = std::get<1>(ps); + return std::make_tuple( + std::get<0>(ins), + algo, + io_simple( + std::get<0>(ins), std::get<1>(ins), std::get<2>(ins), std::get<3>(ins), std::get<4>(ins))); + } +}; + +auto inf_f = std::numeric_limits::max(); +auto inputs_simple_f = testing::Values( + params_simple::input_t( + {5, 5, 5, true, true}, + {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, + {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, + 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, + {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), + params_simple::input_t( + {5, 5, 3, true, true}, + {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, + {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, + {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), + params_simple::input_t( + {5, 5, 5, true, false}, + {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, + {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, + 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, + {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), + params_simple::input_t( + {5, 5, 3, true, false}, + 
{5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + std::nullopt, + {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, + {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), + params_simple::input_t( + {5, 7, 3, true, true}, + {5.0, 4.0, 3.0, 2.0, 1.3, 7.5, 19.0, 9.0, 2.0, 3.0, 3.0, 5.0, 6.0, 4.0, 2.0, 3.0, 5.0, 1.0, + 4.0, 1.0, 1.0, 5.0, 7.0, 2.5, 4.0, 7.0, 8.0, 8.0, 1.0, 3.0, 2.0, 5.0, 4.0, 1.1, 1.2}, + std::nullopt, + {1.3, 2.0, 3.0, 2.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.5, 4.0, 5.0, 1.0, 1.1, 1.2}, + {4, 3, 2, 1, 2, 3, 3, 5, 6, 2, 3, 0, 0, 5, 6}), + params_simple::input_t({1, 7, 3, true, true}, + {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, + std::nullopt, + {1.0, 1.0, 1.0}, + {3, 5, 6}), + params_simple::input_t({1, 7, 3, false, false}, + {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, + std::nullopt, + {5.0, 4.0, 3.0}, + {2, 4, 1}), + params_simple::input_t({1, 7, 3, false, true}, + {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, + std::nullopt, + {9.0, 9.0, 9.0}, + {3, 5, 6}), + params_simple::input_t( + {1, 130, 5, false, true}, + {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, + 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, + std::nullopt, + {20, 19, 18, 17, 16}, + {129, 0, 117, 116, 115}), + params_simple::input_t( + {1, 130, 15, false, true}, + {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, + 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 
12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, + std::nullopt, + {20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, + {129, 0, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105}), + params_simple::input_t( + select::params{1, 32, 31, true, true}, + {0, 1, 2, 3, inf_f, inf_f, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + std::optional{std::vector{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, + 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, + 9, 8, 7, 6, 75, 74, 3, 2, 1, 0}}, + {0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, inf_f}, + {31, 30, 29, 28, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, + 13, 12, 11, 10, 9, 8, 7, 6, 75, 74, 3, 2, 1, 0, 27})); + /* SimpleFloatInt: value-parameterized suite that pairs every hand-written case in inputs_simple_f with each SelectAlgo value listed in the instantiation below. */ +using SimpleFloatInt = SelectK; +TEST_P(SimpleFloatInt, Run) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P( // NOLINT + SelectK, + SimpleFloatInt, + testing::Combine(inputs_simple_f, + testing::Values(SelectAlgo::kAuto, + SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass, + SelectAlgo::kWarpImmediate, + SelectAlgo::kWarpFiltered, + SelectAlgo::kWarpDistributed))); + /* replace_with_mask: binary functor returning `replacement` when `mask` is non-zero and the input `x` otherwise. */ +template +struct replace_with_mask { + KeyT replacement; + constexpr auto inline operator()(KeyT x, uint8_t mask) -> KeyT { return mask ? 
replacement : x; } +}; + /* with_ref: wrapper whose nested params_random::read turns a (spec, algo) tuple into a full test case with randomly generated keys; RefAlgo is forwarded to io_computed. */ +template +struct with_ref { + template + struct params_random { + using io_t = io_computed; + using params_t = std::tuple; + /* read: returns (spec, algo, io_computed(spec, RefAlgo, dists)), where dists holds spec.len * spec.batch_size host-side keys filled below. */ + static auto read(params_t ps) -> Params + { + auto spec = std::get<0>(ps); + auto algo = std::get<1>(ps); + std::vector dists(spec.len * spec.batch_size); + /* The inner scope bounds the lifetime of the device buffers (dists_d, mask_buf); they go out of scope only after the host copy has been synchronized. */ + raft::resources handle; + { + auto s = resource::get_cuda_stream(handle); + rmm::device_uvector dists_d(spec.len * spec.batch_size, s); + raft::random::RngState r(42); + normal(handle, r, dists_d.data(), dists_d.size(), KeyT(10.0), KeyT(100.0)); + /* Keys come from normal() with a constant RngState value of 42 (arguments 10.0 and 100.0 are presumably mean and sigma; confirm). When spec.frac_infinities is positive, entries selected by a bernoulli mask are overwritten with upper_bound() for select_min, or lower_bound() otherwise. */ + if (spec.frac_infinities > 0.0) { + rmm::device_uvector mask_buf(dists_d.size(), s); + auto mask = make_device_vector_view(mask_buf.data(), mask_buf.size()); + raft::random::bernoulli(handle, r, mask, spec.frac_infinities); + KeyT bound = spec.select_min ? raft::upper_bound() : raft::lower_bound(); + auto mask_in = + make_device_vector_view(mask_buf.data(), mask_buf.size()); + auto dists_in = make_device_vector_view(dists_d.data(), dists_d.size()); + auto dists_out = make_device_vector_view(dists_d.data(), dists_d.size()); + raft::linalg::map(handle, dists_out, replace_with_mask{bound}, dists_in, mask_in); + } + /* Copy the keys back into dists on stream s and wait for the stream before leaving the scope. */ + update_host(dists.data(), dists_d.data(), dists_d.size(), s); + s.synchronize(); + } + + return std::make_tuple(spec, algo, io_computed(spec, RefAlgo, dists)); + } + }; +}; + +} // namespace raft::matrix diff --git a/cpp/test/matrix/select_large_k.cu b/cpp/test/matrix/select_large_k.cu new file mode 100644 index 0000000000..baa07f5e87 --- /dev/null +++ b/cpp/test/matrix/select_large_k.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "select_k.cuh" + +namespace raft::matrix { + /* Large-k parameter sets: every entry starts with 100, 100000 and has a third field from 1000 up to 100000 (presumably batch_size, len, k; confirm against select::params). */ +auto inputs_random_largek = testing::Values(select::params{100, 100000, 1000, true}, + select::params{100, 100000, 2000, false}, + select::params{100, 100000, 100000, true, false}, + select::params{100, 100000, 2048, false}, + select::params{100, 100000, 1237, true}); + /* ReferencedRandomFloatSizeT: value-parameterized suite that runs each entry of inputs_random_largek with kRadix11bits and kRadix11bitsExtraPass. */ +using ReferencedRandomFloatSizeT = + SelectK::params_random>; +TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT + ReferencedRandomFloatSizeT, + testing::Combine(inputs_random_largek, + testing::Values(SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass))); + +} // namespace raft::matrix