diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 0337f6891f..71f9f7c492 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -157,7 +157,7 @@ template -RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, +RAFT_KERNEL compute_final_dists_registers(const value_t* X_reordered, const value_t* X, const value_int n_cols, bitset_type* bitset, @@ -238,7 +238,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -267,7 +267,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -313,7 +313,7 @@ template -RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_reordered, const value_t* X, value_int n_cols, // n_cols should be 2 or 3 dims const value_idx* R_knn_inds, @@ -408,7 +408,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -433,7 +433,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -1013,7 +1013,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, if (k <= 32) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1033,7 +1033,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 64) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1052,7 +1052,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 128) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1072,7 +1072,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 256) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1092,7 +1092,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 512) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1112,7 +1112,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 1024) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -1182,7 +1182,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1208,7 +1208,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1234,7 +1234,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1260,7 +1260,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1285,7 +1285,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 8, 64, dims><<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1310,7 +1310,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 8, 64, dims><<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(),