Turning off faiss refinement for the time being.
cjnolet committed Nov 1, 2023
1 parent c6ec247 commit b71ef73
Showing 6 changed files with 48 additions and 49 deletions.
4 changes: 2 additions & 2 deletions cpp/bench/ann/src/common/ann_types.hpp
@@ -120,11 +120,11 @@ class ANN : public AnnBase {
// The advantage of this way is that index has smaller size
// and many indices can share one dataset.
//
// AlgoProperty::need_dataset_when_search of such algorithm should be true,
// SearchParam::needs_dataset() of such algorithm should be true,
// and set_search_dataset() should save the passed-in pointer somewhere.
// The client code should call set_search_dataset() before searching,
// and should not release dataset before searching is finished.
virtual void set_search_dataset(const T* /*dataset*/, size_t /*nrow*/) { printf("Setting \n"); };
virtual void set_search_dataset(const T* /*dataset*/, size_t /*nrow*/){};
};

} // namespace raft::bench::ann
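The updated comment ties the dataset-sharing contract to SearchParam::needs_dataset(). As a rough, self-contained sketch of that contract (illustrative only, not part of this diff; ToyIndex stands in for an ANN implementation):

#include <cstddef>
#include <vector>

// The index keeps only a pointer to the search dataset, never a copy, so the
// caller must keep the buffer alive until all searching has finished.
struct ToyIndex {
  const float* dataset = nullptr;
  std::size_t  nrow    = 0;
  void set_search_dataset(const float* d, std::size_t n) { dataset = d; nrow = n; }
};

int main()
{
  std::vector<float> base(1000 * 128);      // must outlive every search call
  ToyIndex index;
  index.set_search_dataset(base.data(), 1000);
  // ... run searches here; release `base` only after the last one returns ...
  return 0;
}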
1 change: 1 addition & 0 deletions cpp/bench/ann/src/common/benchmark.cpp
@@ -88,6 +88,7 @@ template <typename T>
std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search_param(
const std::string& algo, const nlohmann::json& conf)
{
printf("INside create_search_param\n");
static auto fname = get_fun_name(reinterpret_cast<void*>(&create_search_param<T>));
auto handle = load_lib(algo);
auto fun_addr = dlsym(handle, fname.c_str());
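For context, create_search_param() resolves the per-algorithm factory out of a plugin library with dlsym. A minimal sketch of that dlopen/dlsym lookup pattern (the symbol name and signature below are placeholders, not the benchmark's real ones):

#include <dlfcn.h>
#include <stdexcept>

// Resolve a C-linkage factory symbol from an already-loaded shared library and call it.
using factory_fn = void* (*)();

void* call_factory(void* lib_handle, const char* symbol_name)
{
  void* fun_addr = dlsym(lib_handle, symbol_name);
  if (fun_addr == nullptr) { throw std::runtime_error(dlerror()); }
  return reinterpret_cast<factory_fn>(fun_addr)();
}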
20 changes: 8 additions & 12 deletions cpp/bench/ann/src/common/benchmark.hpp
@@ -172,7 +172,6 @@ void bench_search(::benchmark::State& state,
std::ptrdiff_t batch_offset = 0;
std::size_t queries_processed = 0;

printf("Starting benchmark search\n");
double total_time = 0;

const auto& sp_json = index.search_params[search_param_ix];
@@ -202,7 +201,6 @@
index_file = index.file;
}

printf("Loading index from file\n");
std::unique_ptr<typename ANN<T>::AnnSearchParam> search_param;
ANN<T>* algo;
try {
@@ -217,25 +215,15 @@ void bench_search(::benchmark::State& state,
search_param->metric_objective = metric_objective;
} catch (const std::exception& e) {
state.SkipWithError("Failed to create an algo: " + std::string(e.what()));
}

printf("Set search params\n");
try {
algo->set_search_param(*search_param);
} catch (const std::exception& ex) {
state.SkipWithError("An error occurred setting search parameters: " + std::string(ex.what()));
return;
}

printf("Setting search dataset\n");
auto algo_property = parse_algo_property(algo->get_preference(), sp_json);
current_algo_props = std::make_shared<AlgoProperty>(algo_property.dataset_memory_type,
algo_property.query_memory_type);

printf("AFTER!\n");
if (search_param->needs_dataset()) {
try {
printf("About to set search datast\n");
algo->set_search_dataset(dataset->base_set(current_algo_props->dataset_memory_type),
dataset->base_set_size());
} catch (const std::exception& ex) {
@@ -245,6 +233,14 @@ void bench_search(::benchmark::State& state,
return;
}
}

try {
algo->set_search_param(*search_param);

} catch (const std::exception& ex) {
state.SkipWithError("An error occurred setting search parameters: " + std::string(ex.what()));
return;
}
}

const auto algo_property = *current_algo_props;
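Net effect of the hunks above: the debug printfs are gone and the setup order changes so the search dataset is attached before the search parameters are applied. Condensed from the new code (try/catch blocks and error reporting elided):

// Simplified per-search setup in bench_search() after this change:
search_param->metric_objective = metric_objective;

auto algo_property = parse_algo_property(algo->get_preference(), sp_json);
current_algo_props = std::make_shared<AlgoProperty>(algo_property.dataset_memory_type,
                                                    algo_property.query_memory_type);

if (search_param->needs_dataset()) {
  algo->set_search_dataset(dataset->base_set(current_algo_props->dataset_memory_type),
                           dataset->base_set_size());
}
algo->set_search_param(*search_param);   // now runs after the dataset is attached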
60 changes: 31 additions & 29 deletions cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
@@ -35,6 +35,9 @@
#include <faiss/index_io.h>
#include <omp.h>

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/stream_view.hpp>

#include <cassert>
#include <memory>
#include <stdexcept>
@@ -102,6 +105,7 @@ class FaissGpu : public ANN<T> {
RAFT_CUDA_TRY(cudaGetDevice(&device_));
RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming));
faiss_default_stream_ = gpu_resource_.getDefaultStream(device_);
raft::resource::set_cuda_stream(handle_, faiss_default_stream_);
}

virtual ~FaissGpu() noexcept { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); }
@@ -110,7 +114,7 @@ class FaissGpu : public ANN<T> {

virtual void set_search_param(const FaissGpu<T>::AnnSearchParam& param) {}

virtual void set_search_dataset(const T* dataset, size_t nrow) {}
void set_search_dataset(const T* dataset, size_t nrow) override { dataset_ = dataset; }

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
@@ -126,7 +130,7 @@ class FaissGpu : public ANN<T> {
AlgoProperty property;
// to enable building big dataset which is larger than GPU memory
property.dataset_memory_type = MemoryType::Host;
property.query_memory_type = MemoryType::Device;
property.query_memory_type = MemoryType::Host;
return property;
}

@@ -145,14 +149,17 @@

mutable faiss::gpu::StandardGpuResources gpu_resource_;
std::unique_ptr<faiss::gpu::GpuIndex> index_;
std::unique_ptr<faiss::IndexRefineFlat> index_refine_;
std::unique_ptr<faiss::IndexRefineFlat> index_refine_{nullptr};
faiss::MetricType metric_type_;
int nlist_;
int device_;
cudaEvent_t sync_{nullptr};
cudaStream_t faiss_default_stream_{nullptr};
double training_sample_fraction_;
std::unique_ptr<faiss::SearchParameters> search_params_;
const T* dataset_;
raft::device_resources handle_;
float refine_ratio_ = 1.0;
};

template <typename T>
@@ -198,12 +205,23 @@ void FaissGpu<T>::search(const T* queries,
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");

if (index_refine_->k_factor > 1) {
printf("Using refine!\n");
index_refine_->search(
batch_size, queries, k, distances, reinterpret_cast<faiss::idx_t*>(neighbors));
if (this->refine_ratio_ > 1.0) {
// TODO: FAISS changed their search APIs to accept the search parameters as a struct object
// but their refine API doesn't allow the struct to be passed in. Once this is fixed, we
// need to re-enable refinement below
// index_refine_->search(batch_size, queries, k, distances,
// reinterpret_cast<faiss::idx_t*>(neighbors), this->search_params_.get()); Related FAISS issue:
// https://github.com/facebookresearch/faiss/issues/3118
throw std::runtime_error(
"FAISS doesn't support refinement in their new APIs so this feature is disabled in the "
"benchmarks for the time being.");
} else {
index_->search(batch_size, queries, k, distances, reinterpret_cast<faiss::idx_t*>(neighbors));
index_->search(batch_size,
queries,
k,
distances,
reinterpret_cast<faiss::idx_t*>(neighbors),
this->search_params_.get());
}
stream_wait(stream);
}
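The practical consequence of this hunk: asking for refine_ratio > 1 now raises an error instead of running a refined search, while refine_ratio == 1 goes straight to the plain index. A small self-contained sketch of that guard (illustrative only, not the wrapper's actual code):

#include <cstdio>
#include <stdexcept>

// Mimics the guard added above: reject refinement requests, otherwise "search".
void run_search(double refine_ratio)
{
  if (refine_ratio > 1.0) {
    throw std::runtime_error("FAISS refinement is disabled in the benchmarks for the time being");
  }
  std::printf("searching the plain (un-refined) index\n");
}

int main()
{
  try {
    run_search(2.0);                           // e.g. a config that still requests refinement
  } catch (const std::runtime_error& e) {
    std::printf("rejected: %s\n", e.what());
  }
  run_search(1.0);                             // the only supported setting for now
  return 0;
}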
@@ -258,6 +276,7 @@ class FaissGpuIVFFlat : public FaissGpu<T> {
faiss::IVFSearchParameters faiss_search_params;
faiss_search_params.nprobe = nprobe;
this->search_params_ = std::make_unique<faiss::IVFSearchParameters>(faiss_search_params);
this->refine_ratio_ = search_param.refine_ratio;
}

void save(const std::string& file) const override
@@ -296,19 +315,12 @@ class FaissGpuIVFPQ : public FaissGpu<T> {
config);
}

void set_search_dataset(const T* dataset, size_t nrow) override
{
printf("Setting search ataset for refine\n");
dataset_ = dataset;
}

void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
{
printf("Setting ivfpq search params\n");
auto search_param = dynamic_cast<const typename FaissGpu<T>::SearchParam&>(param);
int nprobe = search_param.nprobe;
assert(nprobe <= nlist_);

this->refine_ratio_ = search_param.refine_ratio;
faiss::IVFPQSearchParameters faiss_search_params;
faiss_search_params.nprobe = nprobe;

@@ -329,8 +341,6 @@ class FaissGpuIVFPQ : public FaissGpu<T> {
{
this->template load_<faiss::gpu::GpuIndexIVFPQ, faiss::IndexIVFPQ>(file);
}

const T* dataset_;
};

// TODO: Enable this in cmake
@@ -342,12 +352,6 @@ class FaissGpuIVFSQ : public FaissGpu<T> {
std::string quantizer_type;
};

struct SearchParam : public FaissGpu<T>::SearchParam {
int nprobe;
float refine_ratio = 1.0;
auto needs_dataset() const -> bool override { return true; }
};

FaissGpuIVFSQ(Metric metric, int dim, const BuildParam& param) : FaissGpu<T>(metric, dim, param)
{
faiss::ScalarQuantizer::QuantizerType qtype;
@@ -366,7 +370,6 @@ class FaissGpuIVFSQ : public FaissGpu<T> {
&(this->gpu_resource_), dim, param.nlist, qtype, this->metric_type_, true, config);
}

void set_search_dataset(const T* dataset, size_t nrow) override { this->dataset_ = dataset; }
void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
{
auto search_param = dynamic_cast<const typename FaissGpu<T>::SearchParam&>(param);
@@ -377,9 +380,10 @@ class FaissGpuIVFSQ : public FaissGpu<T> {
faiss_search_params.nprobe = nprobe;

this->search_params_ = std::make_unique<faiss::IVFSearchParameters>(faiss_search_params);

this->refine_ratio_ = search_param.refine_ratio;
if (search_param.refine_ratio > 1.0) {
this->index_refine_ = std::make_unique<faiss::IndexRefineFlat>(this->index_.get(), dataset_);
this->index_refine_ =
std::make_unique<faiss::IndexRefineFlat>(this->index_.get(), this->dataset_);
this->index_refine_.get()->k_factor = search_param.refine_ratio;
}
}
@@ -394,8 +398,6 @@ class FaissGpuIVFSQ : public FaissGpu<T> {
this->template load_<faiss::gpu::GpuIndexIVFScalarQuantizer, faiss::IndexIVFScalarQuantizer>(
file);
}

const T* dataset_;
};

template <typename T>
@@ -2,9 +2,9 @@ name: faiss_gpu_ivf_flat
groups:
base:
build:
nlists: [1024, 2048, 4096, 8192, 16000, 32000]
ratio: [1, 10, 25]
useFloat16: [True, False]
nlist: [2048]
ratio: [1, 4, 10]
useFloat16: [False]
search:
numProbes: [1, 5, 10, 50, 100, 200, 500, 1000, 2000]
refine_ratio: [1, 2, 4, 10]
nprobe: [2048]
refine_ratio: [1]
@@ -9,7 +9,7 @@ groups:
useFloat16: [False]
search:
nprobe: [1, 5, 10, 50, 100, 200]
refine_ratio: [1, 2, 4]
refine_ratio: [1]
test:
build:
nlist: [1024]
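The trimmed sweeps above feed the wrapper's search parameters; roughly, each remaining search entry maps onto something like the following (field names taken from the SearchParam struct visible in faiss_gpu_wrapper.h; values are the ones pinned in the faiss_gpu_ivf_flat group above):

// Sketch only: one search configuration from the YAML, expressed as the wrapper's params.
struct SearchParam {
  int   nprobe       = 2048;   // "nprobe: [2048]"
  float refine_ratio = 1.0f;   // "refine_ratio: [1]"; anything above 1 currently throws
};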
