Skip to content

Commit

Permalink
Various fixes to reproducible benchmarks (#1800)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Artem M. Chirkin (https://github.com/achirkin)
  - Divye Gala (https://github.com/divyegala)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #1800
  • Loading branch information
cjnolet authored Sep 11, 2023
1 parent 12480cf commit c59c9d1
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 23 deletions.
3 changes: 1 addition & 2 deletions cpp/bench/ann/src/ggnn/ggnn_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ template <typename T>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::Ggnn<T>::BuildParam& param)
{
param.dataset_size = conf.at("dataset_size");
param.k = conf.at("k");
param.k = conf.at("k");

if (conf.contains("k_build")) { param.k_build = conf.at("k_build"); }
if (conf.contains("segment_size")) { param.segment_size = conf.at("segment_size"); }
Expand Down
22 changes: 4 additions & 18 deletions cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ class Ggnn : public ANN<T> {
int num_layers{4}; // L
float tau{0.5};
int refine_iterations{2};

size_t dataset_size;
int k; // GGNN requires to know k during building
};

Expand Down Expand Up @@ -182,24 +180,17 @@ GgnnImpl<T, measure, D, KBuild, KQuery, S>::GgnnImpl(Metric metric,
}

if (dim != D) { throw std::runtime_error("mis-matched dim"); }

int device;
RAFT_CUDA_TRY(cudaGetDevice(&device));

ggnn_ = std::make_unique<GGNNGPUInstance>(
device, build_param_.dataset_size, build_param_.num_layers, true, build_param_.tau);
}

template <typename T, DistanceMeasure measure, int D, int KBuild, int KQuery, int S>
void GgnnImpl<T, measure, D, KBuild, KQuery, S>::build(const T* dataset,
size_t nrow,
cudaStream_t stream)
{
if (nrow != build_param_.dataset_size) {
throw std::runtime_error(
"build_param_.dataset_size = " + std::to_string(build_param_.dataset_size) +
" , but nrow = " + std::to_string(nrow));
}
int device;
RAFT_CUDA_TRY(cudaGetDevice(&device));
ggnn_ = std::make_unique<GGNNGPUInstance>(
device, nrow, build_param_.num_layers, true, build_param_.tau);

ggnn_->set_base_data(dataset);
ggnn_->set_stream(stream);
Expand All @@ -212,11 +203,6 @@ void GgnnImpl<T, measure, D, KBuild, KQuery, S>::build(const T* dataset,
// Hands the dataset to be searched over to the underlying GGNN instance.
//
// As listed, the body compares `nrow` with `build_param_.dataset_size`, throws
// std::runtime_error carrying both values when they differ, and otherwise forwards `dataset` to
// `ggnn_->set_base_data()`.
//
// NOTE(review): this listing looks like a rendered diff without +/- markers. The size check
// below is presumably the part the commit deletes (the `dataset_size` member is dropped from
// the GGNN build parameters in the same change), which would leave only the `set_base_data`
// call -- confirm against the actual file.
// NOTE(review): `ggnn_` is dereferenced without a null check; it presumably has to be created
// (e.g. by `build()`) before this is called -- verify against callers.
template <typename T, DistanceMeasure measure, int D, int KBuild, int KQuery, int S>
void GgnnImpl<T, measure, D, KBuild, KQuery, S>::set_search_dataset(const T* dataset, size_t nrow)
{
if (nrow != build_param_.dataset_size) {
throw std::runtime_error(
"build_param_.dataset_size = " + std::to_string(build_param_.dataset_size) +
" , but nrow = " + std::to_string(nrow));
}
ggnn_->set_base_data(dataset);
}

Expand Down
6 changes: 4 additions & 2 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <utility>
#include <vector>

#include <omp.h>

#include "../common/ann_types.hpp"
#include <hnswlib.h>

Expand Down Expand Up @@ -164,13 +166,13 @@ class HnswLib : public ANN<T> {
// Parameters supplied when constructing an HnswLib index.
// `M` and `ef_construction` have no default initializer, so the caller must set them.
//
// NOTE(review): `num_threads` is declared twice below. This listing looks like a rendered diff
// without +/- markers: the `{1}` line is presumably the removed default and the
// `omp_get_num_procs()` line its replacement -- confirm against the actual header.
struct BuildParam {
int M;
int ef_construction;
int num_threads{1};
int num_threads = omp_get_num_procs();
};

using typename ANN<T>::AnnSearchParam;
// Parameters supplied per search; derives from AnnSearchParam.
// `ef` has no default initializer, so the caller must set it.
//
// NOTE(review): `num_threads` is declared twice below. This listing looks like a rendered diff
// without +/- markers: the `{1}` line is presumably the removed default and the
// `omp_get_num_procs()` line its replacement -- confirm against the actual header.
struct SearchParam : public AnnSearchParam {
int ef;
int num_threads{1};
int num_threads = omp_get_num_procs();
};

HnswLib(Metric metric, int dim, const BuildParam& param);
Expand Down
2 changes: 1 addition & 1 deletion docs/source/ann_benchmarks_param_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,4 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of
| `ef` | `search_params` | Y | Positive Integer >0 | | Size of the dynamic list for the nearest neighbors used for search. Higher value leads to more accurate but slower search. Cannot be lower than `k`. |
| `numThreads` | `search_params` | N | Positive Integer >0 | 1 | Number of threads to use for queries. |

Please refer to [HNSW algorithm parameters guide](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md) from `hnswlib` to learn more about these arguments.
Please refer to [HNSW algorithm parameters guide](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md) from `hnswlib` to learn more about these arguments.

0 comments on commit c59c9d1

Please sign in to comment.