Skip to content

Commit

Permalink
Fix medoid output and address other comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bkarsin committed Sep 27, 2024
1 parent 91f8089 commit 8a1c8bf
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 31 deletions.
18 changes: 13 additions & 5 deletions cpp/include/cuvs/neighbors/vamana.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ namespace cuvs::neighbors::vamana {
*
*/
struct index_params : cuvs::neighbors::index_params {
/** Degree of output graph. */
/** Maximum degree of the output graph; corresponds to the R parameter in the original
 * Vamana literature. */
uint32_t graph_degree = 32;
/** Maximum number of visited nodes per search **/
/** Maximum number of visited nodes per search; corresponds to the L parameter in the
 * Vamana literature. */
uint32_t visited_size = 64;
/** Number of Vamana iterations. */
/** Number of Vamana vector insertion iterations (each iteration inserts all vectors). */
uint32_t vamana_iters = 1;
/** Alpha for pruning parameter */
float alpha = 1.2;
Expand Down Expand Up @@ -119,6 +121,9 @@ struct index : cuvs::neighbors::index {
return graph_view_;
}

/** Return the id of the dataset vector selected as the medoid; it is written out as the
 * graph's start point when the index is serialized. */
[[nodiscard]] inline auto medoid() const noexcept -> IdxT { return medoid_id_; }

// Copying the index is disallowed for performance reasons (it would duplicate the
// underlying dataset/graph storage); moving is cheap and allowed.
index(const index&) = delete;
index(index&&) = default;
Expand All @@ -144,11 +149,13 @@ struct index : cuvs::neighbors::index {
cuvs::distance::DistanceType metric,
raft::mdspan<const T, raft::matrix_extent<int64_t>, raft::row_major, data_accessor> dataset,
raft::mdspan<const IdxT, raft::matrix_extent<int64_t>, raft::row_major, graph_accessor>
vamana_graph)
vamana_graph,
IdxT medoid_id)
: cuvs::neighbors::index(),
metric_(metric),
graph_(raft::make_device_matrix<IdxT, int64_t>(res, 0, 0)),
dataset_(make_aligned_dataset(res, dataset, 16))
dataset_(make_aligned_dataset(res, dataset, 16)),
medoid_id_(medoid_id)
{
RAFT_EXPECTS(dataset.extent(0) == vamana_graph.extent(0),
"Dataset and vamana_graph must have equal number of rows");
Expand Down Expand Up @@ -197,6 +204,7 @@ struct index : cuvs::neighbors::index {
raft::device_matrix<IdxT, int64_t, raft::row_major> graph_;
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
IdxT medoid_id_;
};
/**
* @}
Expand Down
12 changes: 8 additions & 4 deletions cpp/src/neighbors/detail/vamana/vamana_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void batched_insert_vamana(
const index_params& params,
raft::mdspan<const T, raft::matrix_extent<int64_t>, raft::row_major, Accessor> dataset,
raft::host_matrix_view<IdxT, int64_t> graph,
IdxT* medoid_id,
cuvs::distance::DistanceType metric,
int dim)
{
Expand Down Expand Up @@ -186,7 +187,9 @@ void batched_insert_vamana(
}

// Random medoid has minor impact on recall
int medoid_id = rand() % N;
// TODO: use heuristic for better medoid selection, issue:
// https://github.com/rapidsai/cuvs/issues/355
*medoid_id = rand() % N;

// size of current batch of inserts, increases logarithmically until max_batchsize
int step_size = 1;
Expand All @@ -213,7 +216,7 @@ void batched_insert_vamana(
dataset,
query_list_ptr.data_handle(),
step_size,
medoid_id,
*medoid_id,
degree,
dataset.extent(0),
visited_size,
Expand Down Expand Up @@ -381,12 +384,13 @@ index<T, IdxT> build(

cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded;

IdxT medoid_id;
batched_insert_vamana<T, float, IdxT, Accessor>(
res, params, dataset, vamana_graph.view(), metric, dim);
res, params, dataset, vamana_graph.view(), &medoid_id, metric, dim);

try {
return index<T, IdxT>(
res, params.metric, dataset, raft::make_const_mdspan(vamana_graph.view()));
res, params.metric, dataset, raft::make_const_mdspan(vamana_graph.view()), medoid_id);
} catch (std::bad_alloc& e) {
RAFT_LOG_DEBUG("Insufficient GPU memory to construct VAMANA index with dataset on GPU");
// We just add the graph. User is expected to update dataset separately (e.g allocating in
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/vamana/vamana_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void serialize(raft::resources const& res,
index_of.seekp(file_offset, index_of.beg);
uint32_t max_degree = 0;
size_t index_size = 24; // Starting metadata
uint32_t start = 0;
uint32_t start = static_cast<uint32_t>(index_.medoid());
size_t num_frozen_points = 0;
uint32_t max_observed_degree = 0;

Expand Down
22 changes: 1 addition & 21 deletions cpp/src/neighbors/detail/vamana/vamana_structs.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace cuvs::neighbors::vamana::detail {

// Bitmask with all 32 bits set — presumably used as a full-warp lane mask for
// warp-level primitives; verify at call sites.
#define FULL_BITMASK 0xFFFFFFFF

// Currently supported values for graph_degree.
// constexpr (rather than const) makes the table usable in constant expressions
// and guarantees compile-time initialization.
static constexpr int DEGREE_SIZES[4] = {32, 64, 128, 256};

// Object used to store id,distance combination graph construction operations
Expand Down Expand Up @@ -114,27 +115,6 @@ class Point {
id = other.id;
return *this;
}

/*
// Computes Cosine dist. Uses 2 registers to increase pipeline efficiency and ILP
// Assumes coordinates are normalized so each vector is of unit length. This lets us
// perform a dot-product instead of the full cosine distance computation.
// __device__ SUMTYPE cosine(Point<T,SUMTYPE,Dim>* other, bool test) {return NULL;}
__device__ SUMTYPE cosine(Point<T,SUMTYPE>* other) {
SUMTYPE total[2]={0,0};
for(int i=0; i<Dim; i+=2) {
total[0] += ((SUMTYPE)((SUMTYPE)coords[i] * (SUMTYPE)other->coords[i]));
total[1] += ((SUMTYPE)((SUMTYPE)coords[i+1] * (SUMTYPE)other->coords[i+1]));
}
return (SUMTYPE)1.0 - (total[0]+total[1]);
}
__forceinline__ __device__ SUMTYPE dist(Point<T,SUMTYPE>* other, int metric) {
if(metric == 0) return l2(other);
else return cosine(other);
}
*/
};

/* L2 fallback for low dimension when ILP is not possible */
Expand Down

0 comments on commit 8a1c8bf

Please sign in to comment.