Skip to content

Commit

Permalink
Merge branch 'branch-24.04' into rhdong/bitmap
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong authored Mar 11, 2024
2 parents 7d960bd + dddb05a commit 8a3b759
Show file tree
Hide file tree
Showing 8 changed files with 631 additions and 223 deletions.
4 changes: 2 additions & 2 deletions cpp/include/raft/neighbors/ball_cover-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ void eps_nn(raft::resources const& handle,
query.extent(0),
adj.data_handle(),
vd.data_handle(),
spatial::knn::detail::EuclideanFunc<value_t, int_t>());
spatial::knn::detail::EuclideanSqFunc<value_t, int_t>());
}

/**
Expand Down Expand Up @@ -392,7 +392,7 @@ void eps_nn(raft::resources const& handle,
adj_ia.data_handle(),
adj_ja.data_handle(),
vd.data_handle(),
spatial::knn::detail::EuclideanFunc<value_t, int_t>());
spatial::knn::detail::EuclideanSqFunc<value_t, int_t>());
}

/**
Expand Down
12 changes: 12 additions & 0 deletions cpp/include/raft/neighbors/ball_cover_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class BallCoverIndex {
R_1nn_dists(raft::make_device_vector<value_t, matrix_idx>(handle, m_)),
R_closest_landmark_dists(raft::make_device_vector<value_t, matrix_idx>(handle, m_)),
R(raft::make_device_matrix<value_t, matrix_idx>(handle, sqrt(m_), n_)),
X_reordered(raft::make_device_matrix<value_t, matrix_idx>(handle, m_, n_)),
R_radius(raft::make_device_vector<value_t, matrix_idx>(handle, sqrt(m_))),
index_trained(false)
{
Expand All @@ -93,6 +94,8 @@ class BallCoverIndex {
R_1nn_dists(raft::make_device_vector<value_t, matrix_idx>(handle, X_.extent(0))),
R_closest_landmark_dists(raft::make_device_vector<value_t, matrix_idx>(handle, X_.extent(0))),
R(raft::make_device_matrix<value_t, matrix_idx>(handle, sqrt(X_.extent(0)), X_.extent(1))),
X_reordered(
raft::make_device_matrix<value_t, matrix_idx>(handle, X_.extent(0), X_.extent(1))),
R_radius(raft::make_device_vector<value_t, matrix_idx>(handle, sqrt(X_.extent(0)))),
index_trained(false)
{
Expand Down Expand Up @@ -122,6 +125,10 @@ class BallCoverIndex {
{
return R_closest_landmark_dists.view();
}
auto get_X_reordered() const -> raft::device_matrix_view<const value_t, matrix_idx, row_major>
{
return X_reordered.view();
}

raft::device_vector_view<value_idx, matrix_idx> get_R_indptr() { return R_indptr.view(); }
raft::device_vector_view<value_idx, matrix_idx> get_R_1nn_cols() { return R_1nn_cols.view(); }
Expand All @@ -132,6 +139,10 @@ class BallCoverIndex {
{
return R_closest_landmark_dists.view();
}
raft::device_matrix_view<value_t, matrix_idx, row_major> get_X_reordered()
{
return X_reordered.view();
}
raft::device_matrix_view<const value_t, matrix_idx, row_major> get_X() const { return X; }

raft::distance::DistanceType get_metric() const { return metric; }
Expand Down Expand Up @@ -162,6 +173,7 @@ class BallCoverIndex {
raft::device_vector<value_t, matrix_idx> R_radius;

raft::device_matrix<value_t, matrix_idx, row_major> R;
raft::device_matrix<value_t, matrix_idx, row_major> X_reordered;

protected:
bool index_trained;
Expand Down
34 changes: 10 additions & 24 deletions cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ void construct_landmark_1nn(raft::resources const& handle,
index.get_R_indptr().data_handle(),
index.n_landmarks + 1,
resource::get_cuda_stream(handle));

// reorder X to allow aligned access
raft::matrix::copy_rows<value_t, value_idx>(
handle, index.get_X(), index.get_X_reordered(), index.get_R_1nn_cols());
}

/**
Expand Down Expand Up @@ -337,12 +341,6 @@ void perform_rbc_query(raft::resources const& handle,
/**
* Perform eps-select
*
* a. Map 1 row to each warp/block
* b. Add closest k R points to heap
* c. Iterate through batches of R, having each thread in the warp load a set
* of distances y from R (only if d(q, r) < 3 * distance to closest r) and
* marking the distance to be computed between x, y only
* if knn[k].distance >= d(x_i, R_k) + d(R_k, y)
*/
template <typename value_idx,
typename value_t,
Expand All @@ -355,7 +353,7 @@ void perform_rbc_eps_nn_query(
const value_t* query,
value_int n_query_pts,
value_t eps,
const value_t* landmark_dists,
const value_t* landmarks,
dist_func dfunc,
bool* adj,
value_idx* vd)
Expand All @@ -367,7 +365,7 @@ void perform_rbc_eps_nn_query(
resource::sync_stream(handle);

rbc_eps_pass<value_idx, value_t, value_int, matrix_idx>(
handle, index, query, n_query_pts, eps, landmark_dists, dfunc, adj, vd);
handle, index, query, n_query_pts, eps, landmarks, dfunc, adj, vd);

resource::sync_stream(handle);
}
Expand All @@ -384,14 +382,14 @@ void perform_rbc_eps_nn_query(
value_int n_query_pts,
value_t eps,
value_int* max_k,
const value_t* landmark_dists,
const value_t* landmarks,
dist_func dfunc,
value_idx* adj_ia,
value_idx* adj_ja,
value_idx* vd)
{
rbc_eps_pass<value_idx, value_t, value_int, matrix_idx>(
handle, index, query, n_query_pts, eps, max_k, landmark_dists, dfunc, adj_ia, adj_ja, vd);
handle, index, query, n_query_pts, eps, max_k, landmarks, dfunc, adj_ia, adj_ja, vd);

resource::sync_stream(handle);
}
Expand Down Expand Up @@ -664,15 +662,9 @@ void rbc_eps_nn_query(raft::resources const& handle,
{
ASSERT(index.is_index_trained(), "index must be previously trained");

auto R_dists =
raft::make_device_matrix<value_t, matrix_idx>(handle, index.n_landmarks, n_query_pts);

// find all landmarks that might have points in range
compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle());

// query all points and write to adj
perform_rbc_eps_nn_query(
handle, index, query, n_query_pts, eps, R_dists.data_handle(), dfunc, adj, vd);
handle, index, query, n_query_pts, eps, index.get_R().data_handle(), dfunc, adj, vd);
}

template <typename value_idx = std::int64_t,
Expand All @@ -693,20 +685,14 @@ void rbc_eps_nn_query(raft::resources const& handle,
{
ASSERT(index.is_index_trained(), "index must be previously trained");

auto R_dists =
raft::make_device_matrix<value_t, matrix_idx>(handle, index.n_landmarks, n_query_pts);

// find all landmarks that might have points in range
compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle());

// query all points and write to adj
perform_rbc_eps_nn_query(handle,
index,
query,
n_query_pts,
eps,
max_k,
R_dists.data_handle(),
index.get_R().data_handle(),
dfunc,
adj_ia,
adj_ja,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(
std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::DistFunc);

instantiate_raft_spatial_knn_detail_rbc_eps_pass(
std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc);
std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanSqFunc);

#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two
#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one
Expand Down
Loading

0 comments on commit 8a3b759

Please sign in to comment.