Skip to content

Commit

Permalink
review suggestions
Browse files (browse the repository at this point in the history)
  • Loading branch information
mfoerste4 committed Feb 22, 2024
1 parent a1f1f53 commit c42167c
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 17 deletions.
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ void search_main(raft::resources const& res,
distances.extent(0),
distances.extent(1),
kScale,
true,
resource::get_cuda_stream(res));
}
/** @} */ // end group cagra
Expand Down
11 changes: 6 additions & 5 deletions cpp/include/raft/neighbors/detail/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,15 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
uint32_t n_queries,
uint32_t topk,
float scaling_factor,
bool account_for_max_close,
rmm::cuda_stream_view stream)
{
constexpr bool needs_cast = !std::is_same<ScoreInT, ScoreOutT>::value;
size_t len = size_t(n_queries) * size_t(topk);
switch (metric) {
case distance::DistanceType::L2Unexpanded:
case distance::DistanceType::L2Expanded: {
if (scaling_factor != 0) {
if (scaling_factor != 1.0) {
linalg::unaryOp(
out,
in,
Expand All @@ -235,7 +236,7 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
} break;
case distance::DistanceType::L2SqrtUnexpanded:
case distance::DistanceType::L2SqrtExpanded: {
if (scaling_factor != 0) {
if (scaling_factor != 1.0) {
linalg::unaryOp(out,
in,
len,
Expand All @@ -251,13 +252,13 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
}
} break;
case distance::DistanceType::InnerProduct: {
if (scaling_factor != 0) {
float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor;
if (factor != 1.0) {
linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::mul_const_op<ScoreOutT>{-scaling_factor * scaling_factor},
raft::cast_op<ScoreOutT>{}),
raft::compose_op(raft::mul_const_op<ScoreOutT>{factor}, raft::cast_op<ScoreOutT>{}),
stream);
} else if (needs_cast) {
linalg::unaryOp(out, in, len, raft::cast_op<ScoreOutT>{}, stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -899,8 +899,6 @@ void launch_kernel(Lambda lambda,
raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide<float, IdxT>(
kThreadsPerBlock / kSubwarpSize, k);
smem_size += std::max<int>(smem_size, block_merge_mem);
} else {
smem_size += smem_size;
}

// power-of-two less than cuda limit (for better addr alignment)
Expand Down
18 changes: 9 additions & 9 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ void search_impl(raft::resources const& handle,

// Optional structures if postprocessing is required
// The topk distance value of candidate vectors from each cluster(list)
rmm::device_uvector<AccT> refined_distances_dev(0, stream, search_mr);
rmm::device_uvector<AccT> distances_tmp_dev(0, stream, search_mr);
// The topk index of candidate vectors from each cluster(list)
rmm::device_uvector<IdxT> refined_indices_dev(0, stream, search_mr);
rmm::device_uvector<IdxT> indices_tmp_dev(0, stream, search_mr);
// Number of samples for each query
rmm::device_uvector<uint32_t> num_samples(0, stream, search_mr);
// Offsets per probe for each query
Expand Down Expand Up @@ -193,11 +193,11 @@ void search_impl(raft::resources const& handle,

auto target_size = std::size_t(n_queries) * (manage_local_topk ? grid_dim_x * k : max_samples);

refined_distances_dev.resize(target_size, stream);
if (manage_local_topk) refined_indices_dev.resize(target_size, stream);
distances_tmp_dev.resize(target_size, stream);
if (manage_local_topk) indices_tmp_dev.resize(target_size, stream);

distances_dev_ptr = refined_distances_dev.data();
indices_dev_ptr = refined_indices_dev.data();
distances_dev_ptr = distances_tmp_dev.data();
indices_dev_ptr = indices_tmp_dev.data();
}

ivfflat_interleaved_scan<T, typename utils::config<T>::value_t, IdxT, IvfSampleFilterT>(
Expand All @@ -224,8 +224,8 @@ void search_impl(raft::resources const& handle,
// Merge topk values from different blocks
if (!manage_local_topk || grid_dim_x > 1) {
matrix::detail::select_k<AccT, IdxT>(handle,
refined_distances_dev.data(),
refined_indices_dev.data(),
distances_tmp_dev.data(),
indices_tmp_dev.data(),
n_queries,
manage_local_topk ? (k * grid_dim_x) : max_samples,
k,
Expand All @@ -236,7 +236,7 @@ void search_impl(raft::resources const& handle,
if (!manage_local_topk) {
// post process distances && neighbor IDs
ivf::detail::postprocess_distances(
distances, distances, index.metric(), n_queries, k, 0, stream);
distances, distances, index.metric(), n_queries, k, 1.0, false, stream);
ivf::detail::postprocess_neighbors(neighbors,
neighbors,
index.inds_ptrs().data_handle(),
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ void ivfpq_search_worker(raft::resources const& handle,

// Postprocessing
ivf::detail::postprocess_distances(
distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, stream);
distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, true, stream);
ivf::detail::postprocess_neighbors(neighbors,
neighbors_uint32,
index.inds_ptrs().data_handle(),
Expand Down

0 comments on commit c42167c

Please sign in to comment.