Skip to content

Commit

Permalink
[FEA] Add bitset for ANN pre-filtering and deletion (#1803)
Browse files Browse the repository at this point in the history
Related to #1600

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1803
  • Loading branch information
lowener authored Sep 26, 2023
1 parent ed42bb5 commit b9b5f44
Show file tree
Hide file tree
Showing 11 changed files with 619 additions and 3 deletions.
2 changes: 1 addition & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"
BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"

CACHE_ARGS=""
NVTX=ON
Expand Down
2 changes: 2 additions & 0 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ if(BUILD_PRIMS_BENCH)
NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu
bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
)
ConfigureBench(NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/main.cpp)

ConfigureBench(
NAME TUNE_DISTANCE PATH bench/prims/distance/tune_pairwise/kernel.cu
Expand Down Expand Up @@ -155,4 +156,5 @@ if(BUILD_PRIMS_BENCH)
LIB
EXPLICIT_INSTANTIATE_ONLY
)

endif()
74 changes: 74 additions & 0 deletions cpp/bench/prims/core/bitset.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/core/bitset.cuh>
#include <raft/core/device_mdspan.hpp>
#include <rmm/device_uvector.hpp>

namespace raft::bench::core {

struct bitset_inputs {
uint32_t bitset_len;
uint32_t mask_len;
uint32_t query_len;
}; // struct bitset_inputs

template <typename bitset_t, typename index_t>
struct bitset_bench : public fixture {
bitset_bench(const bitset_inputs& p)
: params(p),
mask{raft::make_device_vector<index_t, index_t>(res, p.mask_len)},
queries{raft::make_device_vector<index_t, index_t>(res, p.query_len)},
outputs{raft::make_device_vector<bool, index_t>(res, p.query_len)}
{
raft::random::RngState state{42};
raft::random::uniformInt(res, state, mask.view(), index_t{0}, index_t{p.bitset_len});
}

void run_benchmark(::benchmark::State& state) override
{
loop_on_state(state, [this]() {
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
this->res, raft::make_const_mdspan(mask.view()), params.bitset_len);
my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view());
});
}

private:
raft::resources res;
bitset_inputs params;
raft::device_vector<index_t, index_t> mask, queries;
raft::device_vector<bool, index_t> outputs;
}; // struct bitset

const std::vector<bitset_inputs> bitset_input_vecs{
{256 * 1024 * 1024, 64 * 1024 * 1024, 256 * 1024 * 1024}, // Standard Bench
{256 * 1024 * 1024, 64 * 1024 * 1024, 1024 * 1024 * 1024}, // Extra queries
{128 * 1024 * 1024, 1024 * 1024 * 1024, 256 * 1024 * 1024}, // Extra mask to test atomics impact
};

using Uint8_32 = bitset_bench<uint8_t, uint32_t>;
using Uint16_64 = bitset_bench<uint16_t, uint32_t>;
using Uint32_32 = bitset_bench<uint32_t, uint32_t>;
using Uint32_64 = bitset_bench<uint32_t, uint64_t>;

RAFT_BENCH_REGISTER(Uint8_32, "", bitset_input_vecs);
RAFT_BENCH_REGISTER(Uint16_64, "", bitset_input_vecs);
RAFT_BENCH_REGISTER(Uint32_32, "", bitset_input_vecs);
RAFT_BENCH_REGISTER(Uint32_64, "", bitset_input_vecs);

} // namespace raft::bench::core
Loading

0 comments on commit b9b5f44

Please sign in to comment.