Skip to content

Commit

Permalink
utilize reordered index for rbc knn
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Mar 5, 2024
1 parent b3e998a commit 0129748
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ template <typename value_idx,
int thread_q = 2,
int tpb = 128,
int col_q = 2>
RAFT_KERNEL compute_final_dists_registers(const value_t* X_index,
RAFT_KERNEL compute_final_dists_registers(const value_t* X_reordered,
const value_t* X,
const value_int n_cols,
bitset_type* bitset,
Expand Down Expand Up @@ -238,7 +238,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index,
// the closest k neighbors, compute it and add to k-select
value_t dist = std::numeric_limits<value_t>::max();
if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i));
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_y_ptr[j] = y_ptr[j];
Expand Down Expand Up @@ -267,7 +267,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index,
// the closest k neighbors, compute it and add to k-select
value_t dist = std::numeric_limits<value_t>::max();
if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i));
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_y_ptr[j] = y_ptr[j];
Expand Down Expand Up @@ -313,7 +313,7 @@ template <typename value_idx = std::int64_t,
int col_q = 2,
typename value_int = std::uint32_t,
typename distance_func>
RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index,
RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_reordered,
const value_t* X,
value_int n_cols, // n_cols should be 2 or 3 dims
const value_idx* R_knn_inds,
Expand Down Expand Up @@ -408,7 +408,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index,
value_t dist = std::numeric_limits<value_t>::max();

if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i));
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_y_ptr[j] = y_ptr[j];
Expand All @@ -433,7 +433,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index,
value_t dist = std::numeric_limits<value_t>::max();

if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i));
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_y_ptr[j] = y_ptr[j];
Expand Down Expand Up @@ -1013,7 +1013,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
if (k <= 32)
block_rbc_kernel_registers<value_idx, value_t, 32, 2, 128, dims, value_int>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1033,7 +1033,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 64)
block_rbc_kernel_registers<value_idx, value_t, 64, 3, 128, 2, value_int>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1052,7 +1052,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 128)
block_rbc_kernel_registers<value_idx, value_t, 128, 3, 128, dims, value_int>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1072,7 +1072,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 256)
block_rbc_kernel_registers<value_idx, value_t, 256, 4, 128, dims, value_int>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1092,7 +1092,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 512)
block_rbc_kernel_registers<value_idx, value_t, 512, 8, 64, dims, value_int>
<<<n_query_rows, 64, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1112,7 +1112,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 1024)
block_rbc_kernel_registers<value_idx, value_t, 1024, 8, 64, dims, value_int>
<<<n_query_rows, 64, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand Down Expand Up @@ -1182,7 +1182,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
128,
dims>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1208,7 +1208,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
128,
dims>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1234,7 +1234,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
128,
dims>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1260,7 +1260,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
128,
dims>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1285,7 +1285,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
8,
64,
dims><<<n_query_rows, 64, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1310,7 +1310,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
8,
64,
dims><<<n_query_rows, 64, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand Down

0 comments on commit 0129748

Please sign in to comment.