Skip to content

Commit

Permalink
Adding thread pool to overlap faiss queries
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Oct 4, 2023
1 parent 385b4f4 commit 95c12db
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 114 deletions.
131 changes: 131 additions & 0 deletions cpp/bench/ann/src/common/thread_pool.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <atomic>
#include <future>
#include <memory>
#include <mutex>
#include <omp.h>
#include <stdexcept>
#include <thread>
#include <utility>

class FixedThreadPool {
public:
FixedThreadPool(int num_threads)
{
if (num_threads < 1) {
throw std::runtime_error("num_threads must >= 1");
} else if (num_threads == 1) {
return;
}

tasks_ = new Task_[num_threads];

threads_.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {
threads_.emplace_back([&, i] {
auto& task = tasks_[i];
while (true) {
std::unique_lock<std::mutex> lock(task.mtx);
task.cv.wait(lock,
[&] { return task.has_task || finished_.load(std::memory_order_relaxed); });
if (finished_.load(std::memory_order_relaxed)) { break; }

task.task();
task.has_task = false;
}
});
}
}

~FixedThreadPool()
{
if (threads_.empty()) { return; }

finished_.store(true, std::memory_order_relaxed);
for (unsigned i = 0; i < threads_.size(); ++i) {
auto& task = tasks_[i];
std::lock_guard<std::mutex>(task.mtx);

task.cv.notify_one();
threads_[i].join();
}

delete[] tasks_;
}

template <typename Func, typename IdxT>
void submit(Func f, IdxT len)
{
if (threads_.empty()) {
for (IdxT i = 0; i < len; ++i) {
f(i);
}
return;
}

const int num_threads = threads_.size();
// one extra part for competition among threads
const IdxT items_per_thread = len / (num_threads + 1);
std::atomic<IdxT> cnt(items_per_thread * num_threads);

auto wrapped_f = [&](IdxT start, IdxT end) {
for (IdxT i = start; i < end; ++i) {
f(i);
}

while (true) {
IdxT i = cnt.fetch_add(1, std::memory_order_relaxed);
if (i >= len) { break; }
f(i);
}
};

std::vector<std::future<void>> futures;
futures.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {
IdxT start = i * items_per_thread;
auto& task = tasks_[i];
{
std::lock_guard lock(task.mtx);
(void)lock; // stop nvcc warning
task.task = std::packaged_task<void()>([=] { wrapped_f(start, start + items_per_thread); });
futures.push_back(task.task.get_future());
task.has_task = true;
}
task.cv.notify_one();
}

for (auto& fut : futures) {
fut.wait();
}
return;
}

private:
struct alignas(64) Task_ {
std::mutex mtx;
std::condition_variable cv;
bool has_task = false;
std::packaged_task<void()> task;
};

Task_* tasks_;
std::vector<std::thread> threads_;
std::atomic<bool> finished_{false};
};
1 change: 1 addition & 0 deletions cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ void parse_search_param(const nlohmann::json& conf,
{
param.nprobe = conf.at("nprobe");
if (conf.contains("refine_ratio")) { param.refine_ratio = conf.at("refine_ratio"); }
if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); }
}

template <typename T, template <typename> class Algo>
Expand Down
28 changes: 26 additions & 2 deletions cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#pragma once

#include "../common/ann_types.hpp"
#include "../common/thread_pool.hpp"

#include <raft/core/logger.hpp>

#include <faiss/IndexFlat.h>
Expand Down Expand Up @@ -54,6 +56,7 @@ class FaissCpu : public ANN<T> {
struct SearchParam : public AnnSearchParam {
int nprobe;
float refine_ratio = 1.0;
int num_threads = omp_get_num_procs();
};

struct BuildParam {
Expand Down Expand Up @@ -116,6 +119,9 @@ class FaissCpu : public ANN<T> {
faiss::MetricType metric_type_;
int nlist_;
double training_sample_fraction_;

int num_threads_;
std::unique_ptr<FixedThreadPool> thread_pool_;
};

template <typename T>
Expand Down Expand Up @@ -160,6 +166,11 @@ void FaissCpu<T>::set_search_param(const AnnSearchParam& param)
this->index_refine_ = std::make_unique<faiss::IndexRefineFlat>(this->index_.get());
this->index_refine_.get()->k_factor = search_param.refine_ratio;
}

if (!thread_pool_ || num_threads_ != search_param.num_threads) {
num_threads_ = search_param.num_threads;
thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_);
}
}

template <typename T>
Expand All @@ -172,7 +183,13 @@ void FaissCpu<T>::search(const T* queries,
{
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");
index_->search(batch_size, queries, k, distances, reinterpret_cast<faiss::idx_t*>(neighbors));

thread_pool_->submit(
[&](int i) {
// Use thread pool for batch size = 1. FAISS multi-threads internally for batch size > 1.
index_->search(batch_size, queries, k, distances, reinterpret_cast<faiss::idx_t*>(neighbors));
},
1);
}

template <typename T>
Expand Down Expand Up @@ -275,7 +292,14 @@ class FaissCpuFlat : public FaissCpu<T> {
}

// class FaissCpu is more like a IVF class, so need special treating here
void set_search_param(const typename ANN<T>::AnnSearchParam&) override{};
void set_search_param(const typename ANN<T>::AnnSearchParam& param) override
{
auto search_param = dynamic_cast<const typename FaissCpu<T>::SearchParam&>(param);
if (!this->thread_pool_ || this->num_threads_ != search_param.num_threads) {
this->num_threads_ = search_param.num_threads;
this->thread_pool_ = std::make_unique<FixedThreadPool>(this->num_threads_);
}
};

void save(const std::string& file) const override
{
Expand Down
110 changes: 2 additions & 108 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@
#include <utility>
#include <vector>

#include <omp.h>

#include "../common/ann_types.hpp"
#include "../common/thread_pool.hpp"
#include <hnswlib.h>

namespace raft::bench::ann {
Expand All @@ -53,112 +52,6 @@ struct hnsw_dist_t<uint8_t> {
using type = int;
};

class FixedThreadPool {
public:
FixedThreadPool(int num_threads)
{
if (num_threads < 1) {
throw std::runtime_error("num_threads must >= 1");
} else if (num_threads == 1) {
return;
}

tasks_ = new Task_[num_threads];

threads_.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {
threads_.emplace_back([&, i] {
auto& task = tasks_[i];
while (true) {
std::unique_lock<std::mutex> lock(task.mtx);
task.cv.wait(lock,
[&] { return task.has_task || finished_.load(std::memory_order_relaxed); });
if (finished_.load(std::memory_order_relaxed)) { break; }

task.task();
task.has_task = false;
}
});
}
}

~FixedThreadPool()
{
if (threads_.empty()) { return; }

finished_.store(true, std::memory_order_relaxed);
for (unsigned i = 0; i < threads_.size(); ++i) {
auto& task = tasks_[i];
std::lock_guard<std::mutex>(task.mtx);

task.cv.notify_one();
threads_[i].join();
}

delete[] tasks_;
}

template <typename Func, typename IdxT>
void submit(Func f, IdxT len)
{
if (threads_.empty()) {
for (IdxT i = 0; i < len; ++i) {
f(i);
}
return;
}

const int num_threads = threads_.size();
// one extra part for competition among threads
const IdxT items_per_thread = len / (num_threads + 1);
std::atomic<IdxT> cnt(items_per_thread * num_threads);

auto wrapped_f = [&](IdxT start, IdxT end) {
for (IdxT i = start; i < end; ++i) {
f(i);
}

while (true) {
IdxT i = cnt.fetch_add(1, std::memory_order_relaxed);
if (i >= len) { break; }
f(i);
}
};

std::vector<std::future<void>> futures;
futures.reserve(num_threads);
for (int i = 0; i < num_threads; ++i) {
IdxT start = i * items_per_thread;
auto& task = tasks_[i];
{
std::lock_guard lock(task.mtx);
(void)lock; // stop nvcc warning
task.task = std::packaged_task<void()>([=] { wrapped_f(start, start + items_per_thread); });
futures.push_back(task.task.get_future());
task.has_task = true;
}
task.cv.notify_one();
}

for (auto& fut : futures) {
fut.wait();
}
return;
}

private:
struct alignas(64) Task_ {
std::mutex mtx;
std::condition_variable cv;
bool has_task = false;
std::packaged_task<void()> task;
};

Task_* tasks_;
std::vector<std::thread> threads_;
std::atomic<bool> finished_{false};
};

template <typename T>
class HnswLib : public ANN<T> {
public:
Expand Down Expand Up @@ -281,6 +174,7 @@ void HnswLib<T>::search(
{
thread_pool_->submit(
[&](int i) {
// hnsw can only handle a single vector at a time.
get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k);
},
batch_size);
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class RaftCagra : public ANN<T> {

using BuildParam = raft::neighbors::cagra::index_params;

RaftCagra(Metric metric, int dim, const BuildParam& param)
RaftCagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1)
: ANN<T>(metric, dim),
index_params_(param),
dimension_(dim),
Expand Down
13 changes: 10 additions & 3 deletions docs/source/ann_benchmarks_param_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,16 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of
| `numProbes` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. |
| `refine_ratio` | `search_params` | N| Positive Number >=0 | 0 | `refine_ratio * k` nearest neighbors are queried from the index initially and an additional refinement step improves recall by selecting only the best `k` neighbors. |

### `faiss_flat`
### `faiss_cpu_flat`

Use FAISS flat index on the CPU, which performs an exact search using brute-force and doesn't have any further build or search parameters.

### `faiss_ivf_flat`

| Parameter | Type | Required | Data Type | Default | Description |
|-----------|----------------|----------|---------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `numThreads` | `search_params` | N | Positive Integer >0 | 1 | Number of threads to use for queries. |

### `faiss_cpu_ivf_flat`

Use FAISS IVF-Flat index on CPU

Expand All @@ -106,8 +111,9 @@ Use FAISS IVF-Flat index on CPU
| `nlists` | `build_param` | Y | Positive Integer >0 | | Number of clusters to partition the vectors into. Larger values will put less points into each cluster but this will impact index build time as more clusters need to be trained. |
| `ratio` | `build_param` | N | Positive Integer >0 | 2 | `1/ratio` is the number of training points which should be used to train the clusters. |
| `nprobe` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. |
| `numThreads` | `search_params` | N | Positive Integer >0 | 1 | Number of threads to use for queries. |

### `faiss_ivf_pq`
### `faiss_cpu_ivf_pq`

Use FAISS IVF-PQ index on CPU

Expand All @@ -120,6 +126,7 @@ Use FAISS IVF-PQ index on CPU
| `bitsPerCode` | `build_param` | N | Positive Integer [4-8] | 8 | Number of bits to use for each code. |
| `numProbes` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. |
| `refine_ratio` | `search_params` | N| Positive Number >=0 | 0 | `refine_ratio * k` nearest neighbors are queried from the index initially and an additional refinement step improves recall by selecting only the best `k` neighbors. |
| `numThreads` | `search_params` | N | Positive Integer >0 | 1 | Number of threads to use for queries. |


## HNSW
Expand Down

0 comments on commit 95c12db

Please sign in to comment.