Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/cagra_hnsw_serialize' into cagra…
Browse files Browse the repository at this point in the history
…_hnsw_serialize
  • Loading branch information
divyegala committed Nov 9, 2023
2 parents aa22782 + d668e05 commit b370f51
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 25 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
- checks
- conda-cpp-build
- conda-cpp-tests
- conda-cpp-checks
- conda-python-build
- conda-python-tests
- docs-build
Expand Down Expand Up @@ -43,6 +44,14 @@ jobs:
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
with:
build_type: pull-request
conda-cpp-checks:
needs: conda-cpp-build
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
with:
build_type: pull-request
enable_check_symbols: true
symbol_exclusions: (void (thrust::|cub::)|_ZN\d+raft_cutlass)
conda-python-build:
needs: conda-cpp-build
secrets: inherit
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ on:
type: string

jobs:
conda-cpp-checks:
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
with:
build_type: nightly
branch: ${{ inputs.branch }}
date: ${{ inputs.date }}
sha: ${{ inputs.sha }}
enable_check_symbols: true
symbol_exclusions: (void (thrust::|cub::)|_ZN\d+raft_cutlass)
conda-cpp-tests:
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
Expand Down
74 changes: 49 additions & 25 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <string>
#include <unistd.h>
#include <vector>
Expand Down Expand Up @@ -175,7 +176,6 @@ void bench_search(::benchmark::State& state,
std::shared_ptr<const Dataset<T>> dataset,
Objective metric_objective)
{
std::ptrdiff_t batch_offset = 0;
std::size_t queries_processed = 0;

const auto& sp_json = index.search_params[search_param_ix];
Expand All @@ -189,6 +189,20 @@ void bench_search(::benchmark::State& state,
// Round down the query data to a multiple of the batch size to loop over full batches of data
const std::size_t query_set_size = (dataset->query_set_size() / n_queries) * n_queries;

if (dataset->query_set_size() < n_queries) {
std::stringstream msg;
msg << "Not enough queries in benchmark set. Expected " << n_queries << ", actual "
<< dataset->query_set_size();
return state.SkipWithError(msg.str());
}

// Each thread start from a different offset, so that the queries that they process do not
// overlap.
std::ptrdiff_t batch_offset = (state.thread_index() * n_queries) % query_set_size;
std::ptrdiff_t queries_stride = state.threads() * n_queries;
// Output is saved into a contiguous buffer (separate buffers for each thread).
std::ptrdiff_t out_offset = 0;

const T* query_set = nullptr;

if (!file_exists(index.file)) {
Expand Down Expand Up @@ -278,7 +292,6 @@ void bench_search(::benchmark::State& state,
{
nvtx_case nvtx{state.name()};

// TODO: Have the odd threads load the queries backwards just to rule out caching.
ANN<T>* algo = dynamic_cast<ANN<T>*>(current_algo.get());
for (auto _ : state) {
[[maybe_unused]] auto ntx_lap = nvtx.lap();
Expand All @@ -289,15 +302,16 @@ void bench_search(::benchmark::State& state,
algo->search(query_set + batch_offset * dataset->dim(),
n_queries,
k,
neighbors->data + batch_offset * k,
distances->data + batch_offset * k,
neighbors->data + out_offset * k,
distances->data + out_offset * k,
gpu_timer.stream());
} catch (const std::exception& e) {
state.SkipWithError(std::string(e.what()));
}

// advance to the next batch
batch_offset = (batch_offset + n_queries) % query_set_size;
batch_offset = (batch_offset + queries_stride) % query_set_size;
out_offset = (out_offset + n_queries) % query_set_size;

queries_processed += n_queries;
}
Expand All @@ -323,31 +337,41 @@ void bench_search(::benchmark::State& state,
// last thread to finish processing notifies all
if (processed_threads-- == 0) { cond_var.notify_all(); }

// Use the last thread as a sanity check that all the threads are working.
if (state.thread_index() == state.threads() - 1) {
// evaluate recall
if (dataset->max_k() >= k) {
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t max_k = dataset->max_k();
buf<std::size_t> neighbors_host = neighbors->move(MemoryType::Host);
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);
for (std::size_t i = 0; i < rows; i++) {
for (std::uint32_t j = 0; j < k; j++) {
auto act_idx = std::int32_t(neighbors_host.data[i * k + j]);
for (std::uint32_t l = 0; l < k; l++) {
auto exp_idx = gt[i * max_k + l];
if (act_idx == exp_idx) {
match_count++;
break;
// Each thread calculates recall on their partition of queries.
// evaluate recall
if (dataset->max_k() >= k) {
const std::int32_t* gt = dataset->gt_set();
const std::uint32_t max_k = dataset->max_k();
buf<std::size_t> neighbors_host = neighbors->move(MemoryType::Host);
std::size_t rows = std::min(queries_processed, query_set_size);
std::size_t match_count = 0;
std::size_t total_count = rows * static_cast<size_t>(k);

// We go through the groundtruth with same stride as the benchmark loop.
size_t out_offset = 0;
size_t batch_offset = (state.thread_index() * n_queries) % query_set_size;
while (out_offset < rows) {
for (std::size_t i = 0; i < n_queries; i++) {
size_t i_orig_idx = batch_offset + i;
size_t i_out_idx = out_offset + i;
if (i_out_idx < rows) {
for (std::uint32_t j = 0; j < k; j++) {
auto act_idx = std::int32_t(neighbors_host.data[i_out_idx * k + j]);
for (std::uint32_t l = 0; l < k; l++) {
auto exp_idx = gt[i_orig_idx * max_k + l];
if (act_idx == exp_idx) {
match_count++;
break;
}
}
}
}
}
double actual_recall = static_cast<double>(match_count) / static_cast<double>(total_count);
state.counters.insert({{"Recall", actual_recall}});
out_offset += n_queries;
batch_offset = (batch_offset + queries_stride) % query_set_size;
}
double actual_recall = static_cast<double>(match_count) / static_cast<double>(total_count);
state.counters.insert({"Recall", {actual_recall, benchmark::Counter::kAvgThreads}});
}
}

Expand Down

0 comments on commit b370f51

Please sign in to comment.