From 0b9692b25f78cd1b27631e354e3f8921a976645c Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 19 Mar 2024 18:07:21 +0100 Subject: [PATCH 1/2] random sampling of dataset rows with improved memory utilization (#2155) The random sampling of IVF methods was reverted (#2144) due to large memory utilization #2141. This PR improves the memory consumption of subsamling: it is O(n_train) where n_train is the size of the subsampled dataset. This PR adds the following new APIs: - random::excess_sampling (todo may just call as sample_without_replacement) - matrix::sample_rows - matrix::gather for host input matrix Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2155 --- cpp/bench/prims/CMakeLists.txt | 2 +- cpp/bench/prims/matrix/gather.cu | 38 ++++- cpp/bench/prims/random/subsample.cu | 112 ++++++++++++++ cpp/include/raft/matrix/detail/gather.cuh | 87 +++++++++++ .../raft/matrix/detail/sample_rows.cuh | 57 +++++++ cpp/include/raft/matrix/sample_rows.cuh | 75 ++++++++++ cpp/include/raft/random/detail/rng_impl.cuh | 138 ++++++++++++++++- cpp/include/raft/random/rng.cuh | 26 ++++ cpp/test/CMakeLists.txt | 2 + cpp/test/matrix/sample_rows.cu | 140 ++++++++++++++++++ cpp/test/random/excess_sampling.cu | 114 ++++++++++++++ 11 files changed, 786 insertions(+), 5 deletions(-) create mode 100644 cpp/bench/prims/random/subsample.cu create mode 100644 cpp/include/raft/matrix/detail/sample_rows.cuh create mode 100644 cpp/include/raft/matrix/sample_rows.cuh create mode 100644 cpp/test/matrix/sample_rows.cu create mode 100644 cpp/test/random/excess_sampling.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 903f4e4347..95361e19ca 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -128,7 +128,7 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME RANDOM_BENCH PATH bench/prims/random/make_blobs.cu bench/prims/random/permute.cu - bench/prims/random/rng.cu bench/prims/main.cpp + bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp ) ConfigureBench(NAME SPARSE_BENCH PATH bench/prims/sparse/convert_csr.cu bench/prims/main.cpp) diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index e6f26ba925..078f9e6198 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -16,34 +16,48 @@ #include +#include +#include #include #include #include #include #include +#include namespace raft::bench::matrix { template struct GatherParams { IdxT rows, cols, map_length; + bool host; }; template inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::ostream& { - os << p.rows << "#" << p.cols << "#" << p.map_length; + os << p.rows << "#" << p.cols << "#" << p.map_length << (p.host ? "#host" : "#device"); return os; } template struct Gather : public fixture { Gather(const GatherParams& p) - : params(p), matrix(this->handle), map(this->handle), out(this->handle), stencil(this->handle) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)), + matrix(this->handle), + map(this->handle), + out(this->handle), + stencil(this->handle), + matrix_h(this->handle) { + rmm::mr::set_current_device_resource(&pool_mr); } + ~Gather() { rmm::mr::set_current_device_resource(old_mr); } + void allocate_data(const ::benchmark::State& state) override { matrix = raft::make_device_matrix(handle, params.rows, params.cols); @@ -59,6 +73,11 @@ struct Gather : public fixture { if constexpr (Conditional) { raft::random::uniform(handle, rng, stencil.data_handle(), params.map_length, T(-1), T(1)); } + + if (params.host) { + matrix_h = raft::make_host_matrix(handle, params.rows, params.cols); + raft::copy(matrix_h.data_handle(), matrix.data_handle(), matrix.size(), stream); + } resource::sync_stream(handle, stream); } @@ -77,14 +96,22 @@ struct Gather : public fixture { raft::matrix::gather_if( handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); } else { - raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + if (params.host) { + raft::matrix::detail::gather( + handle, make_const_mdspan(matrix_h.view()), map_const_view, out.view()); + } else { + raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + } } }); } private: GatherParams params; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; raft::device_matrix matrix, out; + raft::host_matrix matrix_h; raft::device_vector stencil; raft::device_vector map; }; // struct Gather @@ -100,4 +127,9 @@ RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); + +auto inputs_host = raft::util::itertools::product>( + {10000000}, {100}, {1000, 1000000, 10000000}, {true}); +RAFT_BENCH_REGISTER((Gather), "Host", inputs_host); + } // namespace raft::bench::matrix diff --git a/cpp/bench/prims/random/subsample.cu b/cpp/bench/prims/random/subsample.cu new file mode 100644 index 0000000000..4c8ca2bf31 --- /dev/null +++ b/cpp/bench/prims/random/subsample.cu @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::bench::random { + +struct sample_inputs { + int n_samples; + int n_train; + int method; +}; // struct sample_inputs + +inline auto operator<<(std::ostream& os, const sample_inputs& p) -> std::ostream& +{ + os << p.n_samples << "#" << p.n_train << "#" << p.method; + return os; +} + +// Sample with replacement. We use this as a baseline. +template +auto bernoulli_subsample(raft::resources const& res, IdxT n_samples, IdxT n_subsamples, int seed) + -> raft::device_vector +{ + RAFT_EXPECTS(n_subsamples <= n_samples, "Cannot have more training samples than dataset vectors"); + + auto indices = raft::make_device_vector(res, n_subsamples); + raft::random::RngState state(123456ULL); + raft::random::uniformInt( + res, state, indices.data_handle(), n_subsamples, IdxT(0), IdxT(n_samples)); + return indices; +} + +template +struct sample : public fixture { + sample(const sample_inputs& p) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB), + in(make_device_vector(res, p.n_samples)), + out(make_device_vector(res, p.n_train)) + { + rmm::mr::set_current_device_resource(&pool_mr); + raft::random::RngState r(123456ULL); + } + + ~sample() { rmm::mr::set_current_device_resource(old_mr); } + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + raft::random::RngState r(123456ULL); + loop_on_state(state, [this, &r]() { + if (params.method == 1) { + this->out = + bernoulli_subsample(this->res, this->params.n_samples, this->params.n_train, 137); + } else if (params.method == 2) { + this->out = raft::random::excess_subsample( + this->res, r, this->params.n_samples, this->params.n_train); + } + }); + } + + private: + float GiB = 1073741824.0f; + raft::device_resources res; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; + sample_inputs params; + raft::device_vector out, in; +}; // struct sample + +const std::vector input_vecs = {{100000000, 10000000, 1}, + {100000000, 50000000, 1}, + {100000000, 100000000, 1}, + {100000000, 10000000, 2}, + {100000000, 50000000, 2}, + {100000000, 100000000, 2}}; + +RAFT_BENCH_REGISTER(sample, "", input_vecs); + +} // namespace raft::bench::random diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 651fec81c3..05cc9204bf 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -16,9 +16,19 @@ #pragma once +#include +#include +#include +#include +#include #include +#include +#include +#include #include +#include + #include namespace raft { @@ -336,6 +346,83 @@ void gather_if(const InputIteratorT in, gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); } +/** + * Helper function to gather a set of vectors from a (host) dataset. + */ +template +void gather_buff(host_matrix_view dataset, + host_vector_view indices, + MatIdxT offset, + pinned_matrix_view buff) +{ + raft::common::nvtx::range fun_scope("gather_host_buff"); + IdxT batch_size = std::min(buff.extent(0), indices.extent(0) - offset); + +#pragma omp for + for (IdxT i = 0; i < batch_size; i++) { + IdxT in_idx = indices(offset + i); + for (IdxT k = 0; k < buff.extent(1); k++) { + buff(i, k) = dataset(in_idx, k); + } + } +} + +template +void gather(raft::resources const& res, + host_matrix_view dataset, + device_vector_view indices, + raft::device_matrix_view output) +{ + raft::common::nvtx::range fun_scope("gather"); + IdxT n_dim = output.extent(1); + IdxT n_train = output.extent(0); + auto indices_host = raft::make_host_vector(n_train); + raft::copy( + indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); + resource::sync_stream(res); + + const size_t buffer_size = 32768 * 1024; // bytes + const size_t max_batch_size = + std::min(round_up_safe(buffer_size / n_dim, 32), n_train); + RAFT_LOG_DEBUG("Gathering data with batch size %zu", max_batch_size); + + // Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync + // and gathering the data. + auto out_tmp1 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + auto out_tmp2 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + + // Usually a limited number of threads provide sufficient bandwidth for gathering data. + int n_threads = std::min(omp_get_max_threads(), 32); + + // The gather_buff function has a parallel for loop. We start the the omp parallel + // region here, to avoid repeated overhead within the device_offset loop. +#pragma omp parallel num_threads(n_threads) + { + auto view1 = out_tmp1.view(); + auto view2 = out_tmp2.view(); + gather_buff(dataset, make_const_mdspan(indices_host.view()), (MatIdxT)0, view1); + for (MatIdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) { + MatIdxT batch_size = std::min(max_batch_size, n_train - device_offset); + +#pragma omp master + raft::copy(output.data_handle() + device_offset * n_dim, + view1.data_handle(), + batch_size * n_dim, + resource::get_cuda_stream(res)); + // Start gathering the next batch on the host. + MatIdxT host_offset = device_offset + batch_size; + batch_size = std::min(max_batch_size, n_train - host_offset); + if (batch_size > 0) { + gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2); + } +#pragma omp master + resource::sync_stream(res); +#pragma omp barrier + std::swap(view1, view2); + } + } +} + } // namespace detail } // namespace matrix } // namespace raft diff --git a/cpp/include/raft/matrix/detail/sample_rows.cuh b/cpp/include/raft/matrix/detail/sample_rows.cuh new file mode 100644 index 0000000000..e28ad648da --- /dev/null +++ b/cpp/include/raft/matrix/detail/sample_rows.cuh @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::matrix::detail { + +/** Select rows randomly from input and copy to output. */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + const T* input, + IdxT n_rows_input, + raft::device_matrix_view output) +{ + IdxT n_dim = output.extent(1); + IdxT n_samples = output.extent(0); + + raft::device_vector train_indices = + raft::random::excess_subsample(res, random_state, n_rows_input, n_samples); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input)); + T* ptr = reinterpret_cast(attr.devicePointer); + if (ptr != nullptr) { + raft::matrix::gather(res, + raft::make_device_matrix_view(ptr, n_rows_input, n_dim), + raft::make_const_mdspan(train_indices.view()), + output); + } else { + auto dataset = raft::make_host_matrix_view(input, n_rows_input, n_dim); + raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output); + } +} +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/sample_rows.cuh b/cpp/include/raft/matrix/sample_rows.cuh new file mode 100644 index 0000000000..7925d344e4 --- /dev/null +++ b/cpp/include/raft/matrix/sample_rows.cuh @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::matrix { + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param output subsampled dataset + */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + raft::device_matrix_view output) +{ + RAFT_EXPECTS(dataset.extent(1) == output.extent(1), + "dataset dims must match, but received %ld vs %ld", + static_cast(dataset.extent(1)), + static_cast(output.extent(1))); + detail::sample_rows(res, random_state, dataset.data_handle(), dataset.extent(0), output); +} + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param n_samples number of rows in the returned matrix + * + * @return subsampled dataset + * */ +template +raft::device_matrix sample_rows( + raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + IdxT n_samples) +{ + auto output = raft::make_device_matrix(res, n_samples, dataset.extent(1)); + sample_rows(res, random_state, dataset, output.view()); + return output; +} + +} // namespace raft::matrix diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 57f4c8d33d..61a944e9b6 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,20 @@ #pragma once #include +#include +#include +#include +#include #include #include #include #include #include +#include + +#include + namespace raft { namespace random { namespace detail { @@ -278,6 +286,7 @@ std::enable_if_t> discrete(RngState& rng_state, len); } +/** Note the memory space requirements are O(4*len) */ template void sampleWithoutReplacement(RngState& rng_state, DataT* out, @@ -328,6 +337,133 @@ void affine_transform_params(RngState const& rng_state, IdxT n, IdxT& a, IdxT& b b = mt_rng() % n; } +/** @brief Sample without replacement from range 0..N-1. + * + * Elements are sampled uniformly. + * The algorithm will allocate a workspace of size O(4*n_samples) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1 + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) + -> raft::device_vector +{ + RAFT_EXPECTS(n_samples <= N, "Cannot have more training samples than dataset vectors"); + + // Number of samples we'll need to sample (with replacement), to expect 'k' + // unique samples from 'n' is given by the following equation: log(1 - k/n)/log(1 - 1/n) ref: + // https://stats.stackexchange.com/questions/296005/the-expected-number-of-unique-elements-drawn-with-replacement + IdxT n_excess_samples = + n_samples < N + ? std::ceil(raft::log(1 - double(n_samples) / double(N)) / (raft::log(1 - 1 / double(N)))) + : N; + + // There is a variance of n_excess_samples, we take 10% more elements. + n_excess_samples += std::max(0.1 * n_samples, 100); + + // n_excess_sampless will be larger than N around k = 0.64*N. When we reach N, then instead of + // doing rejection sampling, we simply shuffle the range [0..N-1] using N random numbers. + n_excess_samples = std::min(n_excess_samples, N); + auto rnd_idx = raft::make_device_vector(res, n_excess_samples); + + auto linear_idx = raft::make_device_vector(res, rnd_idx.size()); + raft::linalg::map_offset(res, linear_idx.view(), identity_op()); + + uniformInt(res, state, rnd_idx.data_handle(), rnd_idx.size(), IdxT(0), IdxT(N)); + + // Sort indices according to rnd keys + size_t workspace_size = 0; + auto stream = resource::get_cuda_stream(res); + cub::DeviceMergeSort::SortPairs(nullptr, + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + auto workspace = raft::make_device_vector(res, workspace_size); + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + + if (rnd_idx.size() == static_cast(N)) { + // We shuffled the linear_idx array by sorting it according to rnd_idx. + // We return the first n_samples elements. + if (n_samples == N) { return linear_idx; } + rnd_idx = raft::make_device_vector(res, n_samples); + raft::copy(rnd_idx.data_handle(), linear_idx.data_handle(), n_samples, stream); + return rnd_idx; + } + // Else we do a rejection sampling (or excess sampling): we generated more random indices than + // needed and reject the duplicates. + auto keys_out = raft::make_device_vector(res, rnd_idx.size()); + auto values_out = raft::make_device_vector(res, rnd_idx.size()); + rmm::device_scalar num_selected(stream); + size_t worksize2 = 0; + cub::DeviceSelect::UniqueByKey(nullptr, + worksize2, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + if (worksize2 > workspace.size()) { + workspace = raft::make_device_vector(res, worksize2); + workspace_size = workspace.size(); + } + + cub::DeviceSelect::UniqueByKey(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + IdxT selected = num_selected.value(stream); + + if (selected < n_samples) { + RAFT_LOG_DEBUG("Subsampling returned with less unique indices (%zu) than requested (%zu)", + (size_t)selected, + (size_t)n_samples); + + // We continue to select n_samples elements, this will now contains a few duplicates. + } + + // After duplicates are removed, we need to shuffle back to random order + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + values_out.data_handle(), + keys_out.data_handle(), + n_samples, + raft::less_op{}, + stream); + + values_out = raft::make_device_vector(res, n_samples); + raft::copy(values_out.data_handle(), keys_out.data_handle(), n_samples, stream); + return values_out; +} + }; // end namespace detail }; // end namespace random }; // end namespace raft diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 4e63669f98..7fd461980f 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -813,6 +813,32 @@ void sampleWithoutReplacement(raft::resources const& handle, rng_state, out, outIdx, in, wts, sampledLen, len, resource::get_cuda_stream(handle)); } +/** @brief Sample from range 0..N-1. + * + * Elements are sampled uniformly. The method aims to sample without replacement, + * but there is a small probability of a few having duplicate elements. + * + * The algorithm will allocate a workspace of size 4*n_samples*sizeof(IdxT) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1. + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) +{ + return detail::excess_subsample(res, state, N, n_samples); +} + /** * @brief Generates the 'a' and 'b' parameters for a modulo affine * transformation equation: `(ax + b) % n` diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index bf44cf9c60..ecb871fccc 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -267,6 +267,7 @@ if(BUILD_TESTS) test/matrix/matrix.cu test/matrix/norm.cu test/matrix/reverse.cu + test/matrix/sample_rows.cu test/matrix/slice.cu test/matrix/triangular.cu test/sparse/spectral_matrix.cu @@ -294,6 +295,7 @@ if(BUILD_TESTS) test/random/rng_int.cu test/random/rmat_rectangular_generator.cu test/random/sample_without_replacement.cu + test/random/excess_sampling.cu ) ConfigureTest( diff --git a/cpp/test/matrix/sample_rows.cu b/cpp/test/matrix/sample_rows.cu new file mode 100644 index 0000000000..e332a918fe --- /dev/null +++ b/cpp/test/matrix/sample_rows.cu @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace raft { +namespace matrix { + +struct inputs { + int N; + int dim; + int n_samples; + bool host; +}; + +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "#" << p.dim << "#" << p.n_samples << (p.host ? "#host" : "#device"); + return os; +} + +template +class SampleRowsTest : public ::testing::TestWithParam { + public: + SampleRowsTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL}, + in(make_device_matrix(res, params.N, params.dim)), + out(make_device_matrix(res, 0, 0)), + in_h(make_host_matrix(res, params.N, params.dim)), + out_h(make_host_matrix(res, params.n_samples, params.dim)) + { + raft::random::uniform(res, state, in.data_handle(), in.size(), T(-1.0), T(1.0)); + for (int64_t i = 0; i < params.N; i++) { + for (int64_t k = 0; k < params.dim; k++) + in_h(i, k) = i * 1000 + k; + } + raft::copy(in.data_handle(), in_h.data_handle(), in_h.size(), stream); + } + + void check() + { + if (params.host) { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in_h.view()), (int64_t)params.n_samples); + } else { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in.view()), (int64_t)params.n_samples); + } + + raft::copy(out_h.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + ASSERT_TRUE(out.extent(0) == params.n_samples); + ASSERT_TRUE(out.extent(1) == params.dim); + + std::unordered_set occurrence; + + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = out_h(i, 0) / 1000; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " params=" << params; + EXPECT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val << " params=" << params; + occurrence.insert(val); + for (int64_t k = 0; k < params.dim; k++) { + ASSERT_TRUE(raft::match(out_h(i, k), val * 1000 + k, raft::CompareApprox(1e-6))); + } + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + random::RngState state; + device_matrix in, out; + host_matrix in_h, out_h; +}; + +inline std::vector generate_inputs() +{ + std::vector input1 = + raft::util::itertools::product({10}, {1, 17, 96}, {1, 6, 9, 10}, {false}); + + std::vector input2 = + raft::util::itertools::product({137}, {1, 17, 128}, {1, 10, 100, 137}, {false}); + input1.insert(input1.end(), input2.begin(), input2.end()); + + input2 = raft::util::itertools::product( + {100000}, {1, 42}, {1, 137, 1000, 10000, 50000, 62000, 100000}, {false}); + + input1.insert(input1.end(), input2.begin(), input2.end()); + + int n = input1.size(); + // Add same tests for host data + for (int i = 0; i < n; i++) { + inputs x = input1[i]; + x.host = true; + input1.push_back(x); + } + return input1; +} + +const std::vector inputs1 = generate_inputs(); + +using SampleRowsTestInt64 = SampleRowsTest; +TEST_P(SampleRowsTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(SampleRowsTests, SampleRowsTestInt64, ::testing::ValuesIn(inputs1)); + +} // namespace matrix +} // namespace raft diff --git a/cpp/test/random/excess_sampling.cu b/cpp/test/random/excess_sampling.cu new file mode 100644 index 0000000000..e86436fb7d --- /dev/null +++ b/cpp/test/random/excess_sampling.cu @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft { +namespace random { + +using namespace raft::random; + +struct inputs { + int64_t N; + int64_t n_samples; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "/" << p.n_samples; + return os; +} + +template +class ExcessSamplingTest : public ::testing::TestWithParam { + public: + ExcessSamplingTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL} + { + } + + void check() + { + device_vector out = + raft::random::excess_subsample(res, state, params.N, params.n_samples); + ASSERT_TRUE(out.extent(0) == params.n_samples); + + auto h_out = make_host_vector(res, params.n_samples); + raft::copy(h_out.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + std::unordered_set occurrence; + int64_t sum = 0; + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = h_out(i); + sum += val; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " n_samples=" << params.n_samples; + ASSERT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val; + occurrence.insert(val); + } + float avg = sum / (float)params.n_samples; + if (params.n_samples >= 100 && params.N / params.n_samples < 100) { + ASSERT_TRUE(raft::match(avg, (params.N - 1) / 2.0f, raft::CompareApprox(0.2))) + << "non-uniform sample"; + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + RngState state; +}; + +const std::vector input1 = {{1, 0}, + {1, 1}, + {10, 0}, + {10, 1}, + {10, 2}, + {10, 10}, + {137, 42}, + {200, 0}, + {200, 1}, + {200, 100}, + {200, 130}, + {200, 200}, + {10000, 893}, + {10000000000, 1023}}; + +using ExcessSamplingTestInt64 = ExcessSamplingTest; +TEST_P(ExcessSamplingTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(ExcessSamplingTests, ExcessSamplingTestInt64, ::testing::ValuesIn(input1)); + +} // namespace random +} // namespace raft From 335236c705c0c53da8a4bf6a22835fdbe669f1df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malte=20F=C3=B6rster?= <97973773+mfoerste4@users.noreply.github.com> Date: Wed, 20 Mar 2024 15:08:19 +0100 Subject: [PATCH 2/2] Performance optimization of IVF-flat / select_k (#2221) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR is a followup to #2169. To enable IVF-flat with k>256 we need an additional select_k invocation which was unexpectedly slow. There are two reasons for that: First problem is the data handed to select_k: The valid data length per row is much smaller than the conservative maximum that could be achieved by probing the N largest probes. Therefore each query row contains roughly ~50% dummy values. This is also the case for IVF-PQ, but did not show up as prominent due to the second reason. The second problem, and also a difference to the IVF-PQ algorithm - is that a 64bit payload data type is used for selectK. The performance of selectK with 64bit index type is significantly slower than with 32bit, especially when many elements are in the same range: ``` Benchmark Time CPU Iterations ----------------------------------------------------------------------------------------------------- SelectK/float/uint32_t/kRadix11bitsExtraPass/1/manual_time 1.68 ms 1.74 ms 413 1357#200000#512 SelectK/float/uint32_t/kRadix11bitsExtraPass/3/manual_time 2.31 ms 2.37 ms 302 1357#200000#512#same-leading-bits SelectK/float/int64_t/kRadix11bitsExtraPass/1/manual_time 5.92 ms 5.98 ms 116 1357#200000#512 SelectK/float/int64_t/kRadix11bitsExtraPass/3/manual_time 83.7 ms 83.8 ms 8 1357#200000#512#same-leading-bits ----------------------------------------------------------------------------------------------------- ``` The data distribution within a IVF-flat benchmark resulted in a select_k time of ~24ms. ### scope: * additional parameter added to select_k to optionally pass individual row lengths for every batch entry. This parameter is utilized by both IVF-Flat and IVF-PQ and results in a ~2x speedup (50 nodes out of 5000) of the final `select_k`. * refactor ivf-flat search to work with 32bit indices by storing positions instead of actual indices. This allows to utilize 32bit index type select_k for ~10x speedup in the final `select_k`. FYI @tfeher @achirkin ### not in scope: * General optimization of select_k: In the current implementation there is no difference in the type of the payload and the actual index type. Especially the type of the histogram has a large effect on performance (due to the atomics). Authors: - Malte Förster (https://github.com/mfoerste4) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2221 --- .../raft/matrix/detail/select_k-ext.cuh | 8 +- .../raft/matrix/detail/select_k-inl.cuh | 16 ++-- .../raft/matrix/detail/select_radix.cuh | 35 ++++++- .../raft/neighbors/detail/ivf_common.cuh | 20 ++-- .../detail/ivf_flat_interleaved_scan-ext.cuh | 4 +- .../detail/ivf_flat_interleaved_scan-inl.cuh | 25 ++--- .../neighbors/detail/ivf_flat_search-inl.cuh | 91 +++++++++++-------- .../raft/neighbors/detail/ivf_pq_search.cuh | 5 +- .../raft/neighbors/detail/refine_device.cuh | 36 +++++++- .../matrix/detail/select_k_double_int64_t.cu | 3 +- .../matrix/detail/select_k_double_uint32_t.cu | 3 +- cpp/src/matrix/detail/select_k_float_int32.cu | 3 +- .../matrix/detail/select_k_float_int64_t.cu | 3 +- .../matrix/detail/select_k_float_uint32_t.cu | 3 +- .../matrix/detail/select_k_half_int64_t.cu | 3 +- .../matrix/detail/select_k_half_uint32_t.cu | 3 +- ...at_interleaved_scan_float_float_int64_t.cu | 2 +- ...flat_interleaved_scan_half_half_int64_t.cu | 2 +- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 2 +- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 2 +- cpp/test/neighbors/ann_cagra.cuh | 8 +- cpp/test/neighbors/ann_utils.cuh | 43 ++++++++- 22 files changed, 221 insertions(+), 99 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 6a7847d8a0..506cbffcb9 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -41,8 +41,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -58,7 +59,8 @@ void select_k(raft::resources const& handle, IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 8f40e6ae00..93d233152b 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -229,6 +229,9 @@ void segmented_sort_by_key(raft::resources const& handle, * whether to make sure selected pairs are sorted by value * @param[in] algo * the selection algorithm to use + * @param[in] len_i + * array of size (batch_size) providing lengths for each individual row + * only radix select-k supported */ template void select_k(raft::resources const& handle, @@ -240,8 +243,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) { common::nvtx::range fun_scope( "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); @@ -262,9 +266,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - true // fused_last_filter - ); - + true, // fused_last_filter + len_i); } else { bool fused_last_filter = algo == SelectAlgo::kRadix11bits; detail::select::radix::select_k(handle, @@ -276,7 +279,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - fused_last_filter); + fused_last_filter, + len_i); } if (sorted) { auto offsets = make_device_mdarray( diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 82983b7cd2..36a346fda3 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -557,6 +557,7 @@ RAFT_KERNEL radix_kernel(const T* in, Counter* counters, IdxT* histograms, const IdxT len, + const IdxT* len_i, const IdxT k, const bool select_min, const int pass) @@ -598,6 +599,14 @@ RAFT_KERNEL radix_kernel(const T* in, in_buf += batch_id * buf_len; in_idx_buf += batch_id * buf_len; } + + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = max_len; + } + // "current_len > buf_len" means current pass will skip writing buffer if (pass == 0 || current_len > buf_len) { out_buf = nullptr; @@ -829,6 +838,7 @@ void radix_topk(const T* in, IdxT* out_idx, bool select_min, bool fused_last_filter, + const IdxT* len_i, unsigned grid_dim, int sm_cnt, rmm::cuda_stream_view stream, @@ -868,6 +878,7 @@ void radix_topk(const T* in, const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; T* chunk_out = out + offset * k; IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; const T* in_buf = nullptr; const IdxT* in_idx_buf = nullptr; @@ -905,6 +916,7 @@ void radix_topk(const T* in, counters.data(), histograms.data(), len, + chunk_len_i, k, select_min, pass); @@ -1007,6 +1019,7 @@ template RAFT_KERNEL radix_topk_one_block_kernel(const T* in, const IdxT* in_idx, const IdxT len, + const IdxT* len_i, const IdxT k, T* out, IdxT* out_idx, @@ -1057,6 +1070,13 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_idx_buf = nullptr; } + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = max_len; + } + filter_and_histogram_for_one_block(in_buf, in_idx_buf, out_buf, @@ -1106,6 +1126,7 @@ void radix_topk_one_block(const T* in, T* out, IdxT* out_idx, bool select_min, + const IdxT* len_i, int sm_cnt, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -1121,10 +1142,12 @@ void radix_topk_one_block(const T* in, max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { - int chunk_size = std::min(max_chunk_size, batch_size - offset); + int chunk_size = std::min(max_chunk_size, batch_size - offset); + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; kernel<<>>(in + offset * len, in_idx ? (in_idx + offset * len) : nullptr, len, + chunk_len_i, k, out + offset * k, out_idx + offset * k, @@ -1188,6 +1211,8 @@ void radix_topk_one_block(const T* in, * blocks is called. The later case is preferable when leading bits of input data are almost the * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. + * @param len_i + * optional array of size (batch_size) providing lengths for each individual row */ template void select_k(raft::resources const& res, @@ -1199,7 +1224,8 @@ void select_k(raft::resources const& res, T* out, IdxT* out_idx, bool select_min, - bool fused_last_filter) + bool fused_last_filter, + const IdxT* len_i) { auto stream = resource::get_cuda_stream(res); auto mr = resource::get_workspace_resource(res); @@ -1223,13 +1249,13 @@ void select_k(raft::resources const& res, if (len <= BlockSize * items_per_thread) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { impl::radix_topk(in, in_idx, @@ -1240,6 +1266,7 @@ void select_k(raft::resources const& res, out_idx, select_min, fused_last_filter, + len_i, grid_dim, sm_cnt, stream, diff --git a/cpp/include/raft/neighbors/detail/ivf_common.cuh b/cpp/include/raft/neighbors/detail/ivf_common.cuh index ef7ae7c804..df0319e181 100644 --- a/cpp/include/raft/neighbors/detail/ivf_common.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_common.cuh @@ -147,11 +147,11 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT return ix_min; } -template +template __launch_bounds__(BlockDim) RAFT_KERNEL - postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] + postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -170,7 +170,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices); const bool valid = chunk_ix < n_probes; neighbors_out[k] = - valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; + valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; } /** @@ -180,10 +180,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL * probed clusters / defined by the `chunk_indices`. * We assume the searched sample sizes (for a single query) fit into `uint32_t`. */ -template -void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] +template +void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -193,7 +193,7 @@ void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, to { constexpr int kPNThreads = 256; const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); - postprocess_neighbors_kernel + postprocess_neighbors_kernel <<>>(neighbors_out, neighbors_in, db_indices, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 7c2d1d2157..140a9f17c8 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) RAFT_EXPLICIT; @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 6fc528e26b..9cd8b70148 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -690,7 +690,6 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, const T* const* list_data_ptrs, const uint32_t* list_sizes, const uint32_t queries_offset, @@ -700,7 +699,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t* chunk_indices, const uint32_t dim, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances) { extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; @@ -719,8 +718,8 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) distances += query_id * k * gridDim.x + blockIdx.x * k; } else { distances += query_id * uint64_t(max_samples); - chunk_indices += (n_probes * query_id); } + chunk_indices += (n_probes * query_id); coarse_index += query_id * n_probes; } @@ -728,7 +727,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); __syncthreads(); - using local_topk_t = block_sort_t; + using local_topk_t = block_sort_t; local_topk_t queue(k); { using align_warp = Pow2; @@ -752,11 +751,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 uint32_t sample_offset = 0; - if constexpr (!kManageLocalTopK) { - if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } - assert(list_length == chunk_indices[probe_id] - sample_offset); - assert(sample_offset + list_length <= max_samples); - } + if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } + assert(list_length == chunk_indices[probe_id] - sample_offset); + assert(sample_offset + list_length <= max_samples); constexpr int kUnroll = WarpSize / Veclen; constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; @@ -806,8 +803,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) // Enqueue one element per thread const float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; if constexpr (kManageLocalTopK) { - const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); + queue.add(val, sample_offset + vec_id); } else { if (vec_id < list_length) distances[sample_offset + vec_id] = val; } @@ -873,7 +869,7 @@ void launch_kernel(Lambda lambda, const uint32_t max_samples, const uint32_t* chunk_indices, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) @@ -927,7 +923,6 @@ void launch_kernel(Lambda lambda, query_smem_elems, queries, coarse_index, - index.inds_ptrs().data_handle(), index.data_ptrs().data_handle(), index.list_sizes().data_handle(), queries_offset + query_offset, @@ -945,8 +940,8 @@ void launch_kernel(Lambda lambda, distances += grid_dim_y * grid_dim_x * k; } else { distances += grid_dim_y * max_samples; - chunk_indices += grid_dim_y * n_probes; } + chunk_indices += grid_dim_y * n_probes; coarse_index += grid_dim_y * n_probes; } } @@ -1161,7 +1156,7 @@ void ivfflat_interleaved_scan(const index& index, const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 98bdeda42f..441fb76b2f 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -67,13 +67,16 @@ void search_impl(raft::resources const& handle, // Optional structures if postprocessing is required // The topk distance value of candidate vectors from each cluster(list) rmm::device_uvector distances_tmp_dev(0, stream, search_mr); - // The topk index of candidate vectors from each cluster(list) - rmm::device_uvector indices_tmp_dev(0, stream, search_mr); // Number of samples for each query rmm::device_uvector num_samples(0, stream, search_mr); // Offsets per probe for each query rmm::device_uvector chunk_index(0, stream, search_mr); + // The topk index of candidate vectors from each cluster(list), local index offset + // also we might need additional storage for select_k + rmm::device_uvector indices_tmp_dev(0, stream, search_mr); + rmm::device_uvector neighbors_uint32_buf(0, stream, search_mr); + size_t float_query_size; if constexpr (std::is_integral_v) { float_query_size = n_queries * index.dim(); @@ -175,23 +178,29 @@ void search_impl(raft::resources const& handle, grid_dim_x = 1; } + num_samples.resize(n_queries, stream); + chunk_index.resize(n_queries_probes, stream); + + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + num_samples.data(), + stream); + auto distances_dev_ptr = distances; - auto indices_dev_ptr = neighbors; + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(IdxT) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(neighbors); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), stream); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + + uint32_t* indices_dev_ptr = nullptr; bool manage_local_topk = is_local_topk_feasible(k); if (!manage_local_topk || grid_dim_x > 1) { - if (!manage_local_topk) { - num_samples.resize(n_queries, stream); - chunk_index.resize(n_queries_probes, stream); - - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( - index.list_sizes().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - num_samples.data(), - stream); - } - auto target_size = std::size_t(n_queries) * (manage_local_topk ? grid_dim_x * k : max_samples); distances_tmp_dev.resize(target_size, stream); @@ -199,6 +208,8 @@ void search_impl(raft::resources const& handle, distances_dev_ptr = distances_tmp_dev.data(); indices_dev_ptr = indices_tmp_dev.data(); + } else { + indices_dev_ptr = neighbors_uint32; } ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( @@ -224,31 +235,33 @@ void search_impl(raft::resources const& handle, // Merge topk values from different blocks if (!manage_local_topk || grid_dim_x > 1) { - matrix::detail::select_k(handle, - distances_tmp_dev.data(), - indices_tmp_dev.data(), - n_queries, - manage_local_topk ? (k * grid_dim_x) : max_samples, - k, - distances, - neighbors, - select_min); - - if (!manage_local_topk) { - // post process distances && neighbor IDs - ivf::detail::postprocess_distances( - distances, distances, index.metric(), n_queries, k, 1.0, false, stream); - ivf::detail::postprocess_neighbors(neighbors, - neighbors, - index.inds_ptrs().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - n_queries, - n_probes, - k, - stream); - } + matrix::detail::select_k(handle, + distances_tmp_dev.data(), + indices_tmp_dev.data(), + n_queries, + manage_local_topk ? (k * grid_dim_x) : max_samples, + k, + distances, + neighbors_uint32, + select_min, + false, + matrix::SelectAlgo::kAuto, + manage_local_topk ? nullptr : num_samples.data()); + } + if (!manage_local_topk) { + // post process distances && neighbor IDs + ivf::detail::postprocess_distances( + distances, distances, index.metric(), n_queries, k, 1.0, false, stream); } + ivf::detail::postprocess_neighbors(neighbors, + neighbors_uint32, + index.inds_ptrs().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + n_queries, + n_probes, + k, + stream); } /** See raft::neighbors::ivf_flat::search docs */ diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index d445f909e5..4c5da38092 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -447,7 +447,10 @@ void ivfpq_search_worker(raft::resources const& handle, topK, topk_dists.data(), neighbors_uint32, - true); + true, + false, + matrix::SelectAlgo::kAuto, + manage_local_topk ? nullptr : num_samples.data()); // Postprocessing ivf::detail::postprocess_distances( diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh index e76e52657b..bdc29ca121 100644 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -88,6 +88,27 @@ void refine_device(raft::resources const& handle, n_queries, n_candidates); uint32_t grid_dim_x = 1; + + // the neighbor ids will be computed in uint32_t as offset + rmm::device_uvector neighbors_uint32_buf(0, resource::get_cuda_stream(handle)); + // Offsets per probe for each query [n_queries] as n_probes = 1 + rmm::device_uvector chunk_index(n_queries, resource::get_cuda_stream(handle)); + + // we know that each cluster has exactly n_candidates entries + thrust::fill(resource::get_thrust_policy(handle), + chunk_index.data(), + chunk_index.data() + n_queries, + uint32_t(n_candidates)); + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(idx_t) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(indices.data_handle()); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), + resource::get_cuda_stream(handle)); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, typename raft::spatial::knn::detail::utils::config::value_t, @@ -100,13 +121,24 @@ void refine_device(raft::resources const& handle, 1, k, 0, - nullptr, + chunk_index.data(), raft::distance::is_min_close(metric), raft::neighbors::filtering::none_ivf_sample_filter(), - indices.data_handle(), + neighbors_uint32, distances.data_handle(), grid_dim_x, resource::get_cuda_stream(handle)); + + // postprocessing -- neighbors from position to actual id + ivf::detail::postprocess_neighbors(indices.data_handle(), + neighbors_uint32, + refinement_index.inds_ptrs().data_handle(), + fake_coarse_idx.data(), + chunk_index.data(), + n_queries, + 1, + k, + resource::get_cuda_stream(handle)); } } // namespace raft::neighbors::detail diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index e32b4ef6f0..bf234aacbf 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, int64_t); diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 21c954ca46..7f0511a76a 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -29,7 +29,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index 7f163a0b0d..e68b1e32df 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int); diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 87b6525356..5aa40d8c9d 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index e698f811d8..9aba147edf 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index 0eee20b1fa..bc513e4aeb 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index f4e6bae21f..e46c7d46bb 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index def33e493e..5ac820e0dd 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu index e96600ee02..4d847cdeb1 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu @@ -35,7 +35,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 13c9d2e283..8a0e8f0118 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 51f02343fc..7cad992e2b 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index a111de0762..7278f71a24 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -549,6 +549,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { EXPECT_FALSE(unacceptable_node); double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -556,7 +557,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), @@ -668,6 +670,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { } double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -675,7 +678,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index afd083d512..6be2ac7fc7 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -35,6 +35,7 @@ #include #include +#include namespace raft::neighbors { @@ -153,13 +154,40 @@ auto calc_recall(const std::vector& expected_idx, static_cast(match_count) / static_cast(total_count), match_count, total_count); } +/** check uniqueness of indices + */ +template +auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t cols) +{ + size_t max_count; + std::set unique_indices; + for (size_t i = 0; i < rows; ++i) { + unique_indices.clear(); + max_count = 0; + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + if (act_idx == std::numeric_limits::max()) { + max_count++; + } else if (unique_indices.find(act_idx) == unique_indices.end()) { + unique_indices.insert(act_idx); + } else { + return testing::AssertionFailure() + << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + } + } + } + return testing::AssertionSuccess(); +} + template auto eval_recall(const std::vector& expected_idx, const std::vector& actual_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, rows, cols); @@ -176,7 +204,10 @@ auto eval_recall(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } /** Overload of calc_recall to account for distances @@ -224,7 +255,8 @@ auto eval_neighbours(const std::vector& expected_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); @@ -241,7 +273,10 @@ auto eval_neighbours(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } template