Skip to content

Commit

Permalink
Updating faiss cpu to override search params
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Sep 13, 2023
1 parent 74e6a5d commit fcd029f
Showing 1 changed file with 69 additions and 21 deletions.
90 changes: 69 additions & 21 deletions cpp/bench/ann/src/faiss/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class FaissGpu : public ANN<T> {

void build(const T* dataset, size_t nrow, cudaStream_t stream = 0) final;

void set_search_param(const AnnSearchParam& param) override;
virtual void set_search_param(const AnnSearchParam& param) {}

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
Expand Down Expand Up @@ -147,6 +147,7 @@ class FaissGpu : public ANN<T> {
cudaEvent_t sync_{nullptr};
cudaStream_t faiss_default_stream_{nullptr};
double training_sample_fraction_;
std::unique_ptr<faiss::SearchParameters> search_params_;
};

template <typename T>
Expand Down Expand Up @@ -181,20 +182,6 @@ void FaissGpu<T>::build(const T* dataset, size_t nrow, cudaStream_t stream)
stream_wait(stream);
}

/**
 * Apply per-search settings to the wrapped index.
 *
 * Sets the number of probes on the underlying GpuIndexIVF and, when `refine_ratio` is greater
 * than 1, (re)creates `index_refine_` as an IndexRefineFlat over `index_` with `k_factor` set to
 * that ratio.
 *
 * @param param Must refer to a SearchParam; the reference dynamic_cast throws std::bad_cast
 *              otherwise.
 */
template <typename T>
void FaissGpu<T>::set_search_param(const AnnSearchParam& param)
{
  // Bind by reference: a plain `auto` would deduce a value type and copy the parameter object.
  const auto& search_param = dynamic_cast<const SearchParam&>(param);
  const int nprobe         = search_param.nprobe;
  assert(nprobe <= nlist_);

  // A pointer dynamic_cast yields nullptr when index_ is not a GpuIndexIVF; check it before
  // dereferencing instead of calling through a possibly-null pointer.
  auto* ivf_index = dynamic_cast<faiss::gpu::GpuIndexIVF*>(index_.get());
  assert(ivf_index != nullptr);
  ivf_index->setNumProbes(nprobe);

  // NOTE(review): a previously created index_refine_ is left in place when refine_ratio <= 1 --
  // confirm that is the intended behavior.
  if (search_param.refine_ratio > 1.0) {
    this->index_refine_           = std::make_unique<faiss::IndexRefineFlat>(this->index_.get());
    this->index_refine_->k_factor = search_param.refine_ratio;
  }
}

template <typename T>
void FaissGpu<T>::search(const T* queries,
int batch_size,
Expand All @@ -203,10 +190,14 @@ void FaissGpu<T>::search(const T* queries,
float* distances,
cudaStream_t stream) const
{
static_assert(sizeof(size_t) == sizeof(faiss::Index::idx_t),
"sizes of size_t and faiss::Index::idx_t are different");
index_->search(
batch_size, queries, k, distances, reinterpret_cast<faiss::Index::idx_t*>(neighbors));
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");
index_->search(batch_size,
queries,
k,
distances,
reinterpret_cast<faiss::idx_t*>(neighbors),
search_params_.get());
stream_wait(stream);
}

Expand Down Expand Up @@ -245,6 +236,22 @@ class FaissGpuIVFFlat : public FaissGpu<T> {
&(this->gpu_resource_), dim, param.nlist, this->metric_type_, config);
}

/**
 * Translate the benchmark search parameters into FAISS IVF search parameters.
 *
 * Stores a faiss::IVFSearchParameters carrying `nprobe` in `search_params_` and, when
 * `refine_ratio` is greater than 1, (re)creates `index_refine_` as an IndexRefineFlat over
 * `index_` with `k_factor` set to that ratio.
 *
 * @param param Must refer to a FaissGpu<T>::SearchParam; the reference dynamic_cast throws
 *              std::bad_cast otherwise.
 */
void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
{
  // Bind by reference: a plain `auto` would deduce a value type and copy the parameter object.
  const auto& search_param = dynamic_cast<const typename FaissGpu<T>::SearchParam&>(param);
  const int nprobe         = search_param.nprobe;
  // nlist_ belongs to the dependent base FaissGpu<T>, so it has to be reached through this->;
  // the unqualified name does not compile whenever NDEBUG is not defined.
  assert(nprobe <= this->nlist_);

  faiss::IVFSearchParameters faiss_search_params;
  faiss_search_params.nprobe = nprobe;
  this->search_params_ = std::make_unique<faiss::IVFSearchParameters>(faiss_search_params);

  if (search_param.refine_ratio > 1.0) {
    this->index_refine_           = std::make_unique<faiss::IndexRefineFlat>(this->index_.get());
    this->index_refine_->k_factor = search_param.refine_ratio;
  }
}

void save(const std::string& file) const override
{
this->template save_<faiss::gpu::GpuIndexIVFFlat, faiss::IndexIVFFlat>(file);
Expand Down Expand Up @@ -280,6 +287,23 @@ class FaissGpuIVFPQ : public FaissGpu<T> {
config);
}

/**
 * Translate the benchmark search parameters into FAISS IVF-PQ search parameters.
 *
 * Stores a faiss::IVFPQSearchParameters carrying `nprobe` in `search_params_` and, when
 * `refine_ratio` is greater than 1, (re)creates `index_refine_` as an IndexRefineFlat over
 * `index_` with `k_factor` set to that ratio.
 *
 * @param param Must refer to a FaissGpu<T>::SearchParam; the reference dynamic_cast throws
 *              std::bad_cast otherwise.
 */
void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
{
  // Bind by reference: a plain `auto` would deduce a value type and copy the parameter object.
  const auto& search_param = dynamic_cast<const typename FaissGpu<T>::SearchParam&>(param);
  const int nprobe         = search_param.nprobe;
  // nlist_ belongs to the dependent base FaissGpu<T>, so it has to be reached through this->;
  // the unqualified name does not compile whenever NDEBUG is not defined.
  assert(nprobe <= this->nlist_);

  faiss::IVFPQSearchParameters faiss_search_params;
  faiss_search_params.nprobe = nprobe;

  this->search_params_ = std::make_unique<faiss::IVFPQSearchParameters>(faiss_search_params);

  if (search_param.refine_ratio > 1.0) {
    this->index_refine_           = std::make_unique<faiss::IndexRefineFlat>(this->index_.get());
    this->index_refine_->k_factor = search_param.refine_ratio;
  }
}

void save(const std::string& file) const override
{
this->template save_<faiss::gpu::GpuIndexIVFPQ, faiss::IndexIVFPQ>(file);
Expand All @@ -293,6 +317,8 @@ class FaissGpuIVFPQ : public FaissGpu<T> {
template <typename T>
class FaissGpuIVFSQ : public FaissGpu<T> {
public:
using typename FaissGpu<T>::AnnSearchParam;
using typename FaissGpu<T>::SearchParam;
// Build-time options for this wrapper: everything in FaissGpu<T>::BuildParam plus
// `quantizer_type`.
// NOTE(review): quantizer_type presumably names the scalar-quantizer variant to build with --
// confirm against the constructor.
struct BuildParam : FaissGpu<T>::BuildParam {
  std::string quantizer_type;
};
Expand All @@ -315,6 +341,23 @@ class FaissGpuIVFSQ : public FaissGpu<T> {
&(this->gpu_resource_), dim, param.nlist, qtype, this->metric_type_, true, config);
}

/**
 * Translate the benchmark search parameters into FAISS IVF search parameters.
 *
 * Stores a faiss::IVFSearchParameters carrying `nprobe` in `search_params_` and, when
 * `refine_ratio` is greater than 1, (re)creates `index_refine_` as an IndexRefineFlat over
 * `index_` with `k_factor` set to that ratio.
 *
 * @param param Must refer to a SearchParam; the reference dynamic_cast throws std::bad_cast
 *              otherwise.
 */
void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
{
  // Bind by reference: a plain `auto` would deduce a value type and copy the parameter object.
  const auto& search_param = dynamic_cast<const SearchParam&>(param);
  const int nprobe         = search_param.nprobe;
  // nlist_ belongs to the dependent base FaissGpu<T>, so it has to be reached through this->;
  // the unqualified name does not compile whenever NDEBUG is not defined.
  assert(nprobe <= this->nlist_);

  faiss::IVFSearchParameters faiss_search_params;
  faiss_search_params.nprobe = nprobe;

  this->search_params_ = std::make_unique<faiss::IVFSearchParameters>(faiss_search_params);

  if (search_param.refine_ratio > 1.0) {
    this->index_refine_           = std::make_unique<faiss::IndexRefineFlat>(this->index_.get());
    this->index_refine_->k_factor = search_param.refine_ratio;
  }
}

void save(const std::string& file) const override
{
this->template save_<faiss::gpu::GpuIndexIVFScalarQuantizer, faiss::IndexIVFScalarQuantizer>(
Expand All @@ -339,9 +382,14 @@ class FaissGpuFlat : public FaissGpu<T> {
&(this->gpu_resource_), dim, this->metric_type_, config);
}

// FaissGpu is modelled on IVF indexes (its search parameters carry nprobe), so the flat index
// needs special treatment here: it accepts the parameters and deliberately ignores them.
void set_search_param(const typename ANN<T>::AnnSearchParam&) override {}
/**
 * Install default FAISS search parameters for the flat index.
 *
 * The incoming parameters are only validated; `search_params_` is always reset to a
 * default-constructed faiss::SearchParameters.
 *
 * @param param Must refer to a FaissGpu<T>::SearchParam; the reference dynamic_cast throws
 *              std::bad_cast otherwise.
 */
void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
{
  // Bind by reference: a plain `auto` would deduce a value type and copy the parameter object.
  const auto& search_param = dynamic_cast<const typename FaissGpu<T>::SearchParam&>(param);
  const int nprobe         = search_param.nprobe;
  // nlist_ belongs to the dependent base FaissGpu<T>, so it has to be reached through this->;
  // the unqualified name does not compile whenever NDEBUG is not defined.
  // NOTE(review): nprobe is not otherwise used for the flat index -- confirm this check is wanted.
  assert(nprobe <= this->nlist_);
  static_cast<void>(nprobe);  // keeps release builds (assert compiled out) free of unused warnings

  this->search_params_ = std::make_unique<faiss::SearchParameters>();
}
void save(const std::string& file) const override
{
this->template save_<faiss::gpu::GpuIndexFlat, faiss::IndexFlat>(file);
Expand Down

0 comments on commit fcd029f

Please sign in to comment.