Skip to content

Commit

Permalink
Merge branch 'branch-23.12' into bf_batch_query
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet authored Nov 4, 2023
2 parents d4394c1 + aa3e229 commit d8ec324
Show file tree
Hide file tree
Showing 12 changed files with 156 additions and 71 deletions.
2 changes: 1 addition & 1 deletion ci/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fi

if [[ ${package_name} == "raft-dask" ]]; then
sed -r -i "s/pylibraft==(.*)\"/pylibraft${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file}
sed -i "s/ucx-py/ucx-py${PACKAGE_CUDA_SUFFIX}/g" python/raft-dask/pyproject.toml
sed -r -i "s/ucx-py==(.*)\"/ucx-py${PACKAGE_CUDA_SUFFIX}==\1${alpha_spec}\"/g" ${pyproject_file}
else
sed -r -i "s/rmm(.*)\"/rmm${PACKAGE_CUDA_SUFFIX}\1${alpha_spec}\"/g" ${pyproject_file}
fi
Expand Down
111 changes: 75 additions & 36 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,21 @@
#include <algorithm>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstdint>
#include <fstream>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <string>
#include <unistd.h>
#include <vector>

namespace raft::bench::ann {

std::mutex init_mutex;
std::condition_variable cond_var;

static inline std::unique_ptr<AnnBase> current_algo{nullptr};
static inline std::shared_ptr<AlgoProperty> current_algo_props{nullptr};

Expand Down Expand Up @@ -172,8 +176,6 @@ void bench_search(::benchmark::State& state,
std::ptrdiff_t batch_offset = 0;
std::size_t queries_processed = 0;

double total_time = 0;

const auto& sp_json = index.search_params[search_param_ix];

if (state.thread_index() == 0) { dump_parameters(state, sp_json); }
Expand All @@ -185,6 +187,8 @@ void bench_search(::benchmark::State& state,
// Round down the query data to a multiple of the batch size to loop over full batches of data
const std::size_t query_set_size = (dataset->query_set_size() / n_queries) * n_queries;

const T* query_set = nullptr;

if (!file_exists(index.file)) {
state.SkipWithError("Index file is missing. Run the benchmark in the build mode first.");
return;
Expand All @@ -194,6 +198,7 @@ void bench_search(::benchmark::State& state,
* Make sure the first thread loads the algo and dataset
*/
if (state.thread_index() == 0) {
std::unique_lock lk(init_mutex);
// algo is static to cache it between close search runs to save time on index loading
static std::string index_file = "";
if (index.file != index_file) {
Expand Down Expand Up @@ -233,18 +238,28 @@ void bench_search(::benchmark::State& state,
return;
}
}

try {
algo->set_search_param(*search_param);

} catch (const std::exception& ex) {
state.SkipWithError("An error occurred setting search parameters: " + std::string(ex.what()));
return;
}
}

query_set = dataset->query_set(current_algo_props->query_memory_type);
cond_var.notify_all();
} else {
std::unique_lock lk(init_mutex);
// All other threads will wait for the first thread to initialize the algo.

cond_var.wait(
lk, [] { return current_algo_props.get() != nullptr && current_algo.get() != nullptr; });
// gbench ensures that all threads are synchronized at the start of the benchmark loop.
// We are accessing shared variables (like current_algo, current_algo_probs) before the
// benchmark loop, therefore the synchronization here is necessary.
}
const auto algo_property = *current_algo_props;
const T* query_set = dataset->query_set(algo_property.query_memory_type);
query_set = dataset->query_set(algo_property.query_memory_type);

/**
* Each thread will manage its own outputs
Expand All @@ -265,7 +280,6 @@ void bench_search(::benchmark::State& state,
[[maybe_unused]] auto ntx_lap = nvtx.lap();
[[maybe_unused]] auto gpu_lap = gpu_timer.lap();

auto start = std::chrono::high_resolution_clock::now();
// run the search
try {
algo->search(query_set + batch_offset * dataset->dim(),
Expand All @@ -278,24 +292,22 @@ void bench_search(::benchmark::State& state,
state.SkipWithError(std::string(e.what()));
}

auto end = std::chrono::high_resolution_clock::now();

auto elapsed_seconds = std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
// advance to the next batch
batch_offset = (batch_offset + n_queries) % query_set_size;

queries_processed += n_queries;
state.SetIterationTime(elapsed_seconds.count());
total_time += elapsed_seconds.count();
}
}
auto end = std::chrono::high_resolution_clock::now();
if (state.thread_index() == 0) {
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
state.counters.insert({{"end_to_end", duration}});
}
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
if (state.thread_index() == 0) { state.counters.insert({{"end_to_end", duration}}); }
state.counters.insert(
{"Latency", {duration / double(state.iterations()), benchmark::Counter::kAvgThreads}});

state.SetItemsProcessed(queries_processed);
if (cudart.found()) {
state.counters.insert({{"GPU", gpu_timer.total_time() / double(state.iterations())}});
double gpu_time_per_iteration = gpu_timer.total_time() / (double)state.iterations();
state.counters.insert({"GPU", {gpu_time_per_iteration, benchmark::Counter::kAvgThreads}});
}

// This will be the total number of queries across all threads
Expand Down Expand Up @@ -341,6 +353,7 @@ inline void printf_usage()
" [--index_prefix=<prefix>]\n"
" [--override_kv=<key:value1:value2:...:valueN>]\n"
" [--mode=<latency|throughput>\n"
" [--threads=min[:max]]\n"
" <conf>.json\n"
"\n"
"Note the non-standard benchmark parameters:\n"
Expand All @@ -359,8 +372,12 @@ inline void printf_usage()
" you can use this parameter multiple times to get the Cartesian product of benchmark"
" configs.\n"
" --mode=<latency|throughput>"
" run the benchmarks in latency (accumulate times spent in each batch) or "
" throughput (pipeline batches and measure end-to-end) mode\n");
" run the benchmarks in latency (accumulate times spent in each batch) or "
" throughput (pipeline batches and measure end-to-end) mode\n"
" --threads=min[:max] specify the number threads to use for throughput benchmark."
" Power of 2 values between 'min' and 'max' will be used. If only 'min' is specified,"
" then a single test is run with 'min' threads. By default min=1, max=<num hyper"
" threads>.\n");
}

template <typename T>
Expand All @@ -385,33 +402,28 @@ void register_build(std::shared_ptr<const Dataset<T>> dataset,
template <typename T>
void register_search(std::shared_ptr<const Dataset<T>> dataset,
std::vector<Configuration::Index> indices,
Objective metric_objective)
Objective metric_objective,
const std::vector<int>& threads)
{
for (auto index : indices) {
for (std::size_t i = 0; i < index.search_params.size(); i++) {
auto suf = static_cast<std::string>(index.search_params[i]["override_suffix"]);
index.search_params[i].erase("override_suffix");

int max_threads =
metric_objective == Objective::THROUGHPUT ? std::thread::hardware_concurrency() : 1;

auto* b = ::benchmark::RegisterBenchmark(
index.name + suf, bench_search<T>, index, i, dataset, metric_objective)
->Unit(benchmark::kMillisecond)
->ThreadRange(1, max_threads)

/**
* The following are important for getting accuracy QPS measurements on both CPU
* and GPU These make sure that
* - `end_to_end` ~ (`Time` * `Iterations`)
* - `items_per_second` ~ (`total_queries` / `end_to_end`)
* - `Time` = `end_to_end` / `Iterations`
*
* - Latency = `Time`
* - Throughput = `items_per_second`
*/
->MeasureProcessCPUTime()
->UseRealTime();

if (metric_objective == Objective::THROUGHPUT) { b->ThreadRange(threads[0], threads[1]); }
}
}
}
Expand All @@ -424,7 +436,8 @@ void dispatch_benchmark(const Configuration& conf,
std::string data_prefix,
std::string index_prefix,
kv_series override_kv,
Objective metric_objective)
Objective metric_objective,
const std::vector<int>& threads)
{
if (cudart.found()) {
for (auto [key, value] : cuda_info()) {
Expand Down Expand Up @@ -493,7 +506,7 @@ void dispatch_benchmark(const Configuration& conf,
index.search_params = apply_overrides(index.search_params, override_kv);
index.file = combine_path(index_prefix, index.file);
}
register_search<T>(dataset, indices, metric_objective);
register_search<T>(dataset, indices, metric_objective, threads);
}
}

Expand Down Expand Up @@ -525,6 +538,8 @@ inline auto run_main(int argc, char** argv) -> int
std::string index_prefix = "index";
std::string new_override_kv = "";
std::string mode = "latency";
std::string threads_arg_txt = "";
std::vector<int> threads = {1, -1}; // min_thread, max_thread
kv_series override_kv{};

char arg0_default[] = "benchmark"; // NOLINT
Expand All @@ -548,7 +563,18 @@ inline auto run_main(int argc, char** argv) -> int
parse_string_flag(argv[i], "--data_prefix", data_prefix) ||
parse_string_flag(argv[i], "--index_prefix", index_prefix) ||
parse_string_flag(argv[i], "--mode", mode) ||
parse_string_flag(argv[i], "--override_kv", new_override_kv)) {
parse_string_flag(argv[i], "--override_kv", new_override_kv) ||
parse_string_flag(argv[i], "--threads", threads_arg_txt)) {
if (!threads_arg_txt.empty()) {
auto threads_arg = split(threads_arg_txt, ':');
threads[0] = std::stoi(threads_arg[0]);
if (threads_arg.size() > 1) {
threads[1] = std::stoi(threads_arg[1]);
} else {
threads[1] = threads[0];
}
threads_arg_txt = "";
}
if (!new_override_kv.empty()) {
auto kvv = split(new_override_kv, ':');
auto key = kvv[0];
Expand All @@ -570,6 +596,17 @@ inline auto run_main(int argc, char** argv) -> int
Objective metric_objective = Objective::LATENCY;
if (mode == "throughput") { metric_objective = Objective::THROUGHPUT; }

int max_threads =
(metric_objective == Objective::THROUGHPUT) ? std::thread::hardware_concurrency() : 1;
if (threads[1] == -1) threads[1] = max_threads;

if (metric_objective == Objective::LATENCY) {
if (threads[0] != 1 || threads[1] != 1) {
log_warn("Latency mode enabled. Overriding threads arg, running with single thread.");
threads = {1, 1};
}
}

if (build_mode == search_mode) {
log_error("One and only one of --build and --search should be specified");
printf_usage();
Expand All @@ -596,7 +633,8 @@ inline auto run_main(int argc, char** argv) -> int
data_prefix,
index_prefix,
override_kv,
metric_objective);
metric_objective,
threads);
} else if (dtype == "uint8") {
dispatch_benchmark<std::uint8_t>(conf,
force_overwrite,
Expand All @@ -605,7 +643,8 @@ inline auto run_main(int argc, char** argv) -> int
data_prefix,
index_prefix,
override_kv,
metric_objective);
metric_objective,
threads);
} else if (dtype == "int8") {
dispatch_benchmark<std::int8_t>(conf,
force_overwrite,
Expand All @@ -614,7 +653,8 @@ inline auto run_main(int argc, char** argv) -> int
data_prefix,
index_prefix,
override_kv,
metric_objective);
metric_objective,
threads);
} else {
log_error("datatype '%s' is not supported", dtype.c_str());
return -1;
Expand All @@ -629,5 +669,4 @@ inline auto run_main(int argc, char** argv) -> int
current_algo.reset();
return 0;
}

}; // namespace raft::bench::ann
3 changes: 3 additions & 0 deletions cpp/bench/ann/src/common/cuda_stub.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ namespace stub {
{
return cudaSuccess;
}
[[gnu::weak, gnu::noinline]] cudaError_t cudaDeviceSynchronize() { return cudaSuccess; }

[[gnu::weak, gnu::noinline]] cudaError_t cudaStreamSynchronize(cudaStream_t pStream)
{
return cudaSuccess;
Expand Down Expand Up @@ -214,6 +216,7 @@ RAFT_DECLARE_CUDART(cudaFree);
RAFT_DECLARE_CUDART(cudaStreamCreate);
RAFT_DECLARE_CUDART(cudaStreamCreateWithFlags);
RAFT_DECLARE_CUDART(cudaStreamDestroy);
RAFT_DECLARE_CUDART(cudaDeviceSynchronize);
RAFT_DECLARE_CUDART(cudaStreamSynchronize);
RAFT_DECLARE_CUDART(cudaEventCreate);
RAFT_DECLARE_CUDART(cudaEventRecord);
Expand Down
13 changes: 5 additions & 8 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ void HnswLib<T>::build(const T* dataset, size_t nrow, cudaStream_t)
char buf[20];
std::time_t now = std::time(nullptr);
std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now));

printf("%s building %zu / %zu\n", buf, i, items_per_thread);
fflush(stdout);
}
Expand All @@ -163,13 +162,11 @@ void HnswLib<T>::set_search_param(const AnnSearchParam& param_)
auto param = dynamic_cast<const SearchParam&>(param_);
appr_alg_->ef_ = param.ef;
metric_objective_ = param.metric_objective;
num_threads_ = param.num_threads;

bool use_pool = (metric_objective_ == Objective::LATENCY && param.num_threads > 1) &&
(!thread_pool_ || num_threads_ != param.num_threads);
if (use_pool) {
num_threads_ = param.num_threads;
thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_);
}
// Create a pool if multiple query threads have been set and the pool hasn't been created already
bool create_pool = (metric_objective_ == Objective::LATENCY && num_threads_ > 1 && !thread_pool_);
if (create_pool) { thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_); }
}

template <typename T>
Expand All @@ -180,7 +177,7 @@ void HnswLib<T>::search(
// hnsw can only handle a single vector at a time.
get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k);
};
if (metric_objective_ == Objective::LATENCY) {
if (metric_objective_ == Objective::LATENCY && num_threads_ > 1) {
thread_pool_->submit(f, batch_size);
} else {
for (int i = 0; i < batch_size; i++) {
Expand Down
13 changes: 12 additions & 1 deletion cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <algorithm>
#include <cmath>
#include <memory>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <stdexcept>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -272,5 +273,15 @@ REGISTER_ALGO_INSTANCE(std::uint8_t);

#ifdef ANN_BENCH_BUILD_MAIN
#include "../common/benchmark.hpp"
int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); }
int main(int argc, char** argv)
{
rmm::mr::cuda_memory_resource cuda_mr;
// Construct a resource that uses a coalescing best-fit pool allocator
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{&cuda_mr};
rmm::mr::set_current_device_resource(
&pool_mr); // Updates the current device resource pointer to `pool_mr`
rmm::mr::device_memory_resource* mr =
rmm::mr::get_current_device_resource(); // Points to `pool_mr`
return raft::bench::ann::run_main(argc, argv);
}
#endif
1 change: 0 additions & 1 deletion cpp/include/raft/core/detail/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#pragma once
#include <cstdio>
#include <execution>
#include <raft/core/cuda_support.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
Expand Down
Loading

0 comments on commit d8ec324

Please sign in to comment.