Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional Distance Metrics for CAGRA #2187

Closed
Closed
6 changes: 5 additions & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <rmm/cuda_stream_view.hpp>

#include "factory.cuh"
#include "raft/distance/distance_types.hpp"
#include "raft/util/cudart_utils.hpp"
#include "search_plan.cuh"
#include "search_single_cta.cuh"

Expand Down Expand Up @@ -129,7 +131,7 @@ void search_main(raft::resources const& res,
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>> plan =
factory<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>::create(
res, params, index.dim(), index.graph_degree(), topk);
res, params, index.dim(), index.graph_degree(), topk, index.metric());

plan->check(topk);

Expand Down Expand Up @@ -171,10 +173,12 @@ void search_main(raft::resources const& res,
_num_executed_iterations,
topk,
set_offset(sample_filter, qid));
raft::print_device_vector("topk_distances_ptr", _topk_distances_ptr, 10, std::cout);
}

static_assert(std::is_same_v<DistanceT, float>,
"only float distances are supported at the moment");
if (index.metric() != distance::InnerProduct) {
float* dist_out = distances.data_handle();
const DistanceT* dist_in = distances.data_handle();
// We're converting the data from T to DistanceT during distance computation
Expand Down
33 changes: 24 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#pragma once

#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

#include "device_common.hpp"
Expand Down Expand Up @@ -61,7 +62,8 @@ struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, fal

__device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr,
const std::uint32_t dataset_dim,
const bool valid)
const bool valid,
raft::distance::DistanceType metric)
{
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
constexpr unsigned vlen = get_vlen<LOAD_T, DATA_T>();
Expand All @@ -87,8 +89,13 @@ struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, fal
const uint32_t kv = k + v;
// if (kv >= dataset_dim) break;
DISTANCE_T diff = query_buffer[device::swizzling(kv)];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
if (metric == raft::distance::L2Expanded) {
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
} else {
diff *= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 -= diff;
}
}
}
}
Expand Down Expand Up @@ -130,7 +137,8 @@ struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, tru

__device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr,
const std::uint32_t dataset_dim,
const bool valid)
const bool valid,
raft::distance::DistanceType metric)
{
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
constexpr unsigned vlen = get_vlen<LOAD_T, DATA_T>();
Expand All @@ -155,8 +163,13 @@ struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, tru
DISTANCE_T diff;
const unsigned ev = (vlen * e) + v;
diff = query_frags[ev];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
if (metric == raft::distance::L2Expanded) {
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
} else {
diff *= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 -= diff;
}
}
}
}
Expand Down Expand Up @@ -188,6 +201,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
const uint32_t num_seeds,
INDEX_T* const visited_hash_ptr,
const uint32_t hash_bitlen,
raft::distance::DistanceType metric,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand Down Expand Up @@ -215,7 +229,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
}
}

const auto norm2 = dist_op(dataset_ptr + dataset_ld * seed_index, dataset_dim, valid_i);
const auto norm2 = dist_op(dataset_ptr + dataset_ld * seed_index, dataset_dim, valid_i, metric);

if (valid_i && (norm2 < best_norm2_team_local)) {
best_norm2_team_local = norm2;
Expand Down Expand Up @@ -259,7 +273,8 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in
const std::uint32_t hash_bitlen,
const INDEX_T* const parent_indices,
const INDEX_T* const internal_topk_list,
const std::uint32_t search_width)
const std::uint32_t search_width,
raft::distance::DistanceType metric)
{
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();
Expand Down Expand Up @@ -302,7 +317,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in
if (valid_i) { child_id = result_child_indices_ptr[i]; }

DISTANCE_T norm2 =
dist_op(dataset_ptr + child_id * dataset_ld, dataset_dim, child_id != invalid_index);
dist_op(dataset_ptr + child_id * dataset_ld, dataset_dim, child_id != invalid_index, metric);

// Store the distance
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
Expand Down
11 changes: 6 additions & 5 deletions cpp/include/raft/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ class factory {
search_params const& params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
uint32_t topk,
distance::DistanceType metric)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, graph_degree, topk, metric);
switch (plan.dataset_block_dim) {
case 128:
switch (plan.team_size) {
Expand Down Expand Up @@ -74,17 +75,17 @@ class factory {
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>(
new single_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>(
new multi_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else {
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>(
new multi_kernel_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "compute_distance.hpp"
#include "device_common.hpp"
#include "hashmap.hpp"
#include "raft/distance/distance_types.hpp"
#include "search_multi_cta_kernel.cuh"
#include "search_plan.cuh"
#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible
Expand Down Expand Up @@ -95,9 +96,10 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILT
search_params params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
uint32_t topk,
distance::DistanceType metric)
: search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>(
res, params, dim, graph_degree, topk),
res, params, dim, graph_degree, topk, metric),
intermediate_indices(0, resource::get_cuda_stream(res)),
intermediate_distances(0, resource::get_cuda_stream(res)),
topk_workspace(0, resource::get_cuda_stream(res))
Expand Down Expand Up @@ -230,6 +232,7 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILT
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_strid
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
distance::DistanceType metric,
cudaStream_t stream) RAFT_EXPLICIT;
#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

Expand Down Expand Up @@ -84,6 +85,7 @@ void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_strid
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
distance::DistanceType metric, \
cudaStream_t stream);

instantiate_kernel_selection(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "compute_distance.hpp"
#include "device_common.hpp"
#include "hashmap.hpp"
#include "raft/distance/distance_types.hpp"
#include "search_plan.cuh"
#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible
#include "utils.hpp"
Expand Down Expand Up @@ -153,7 +154,8 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
const uint32_t min_iteration,
const uint32_t max_iteration,
uint32_t* const num_executed_iterations, /* stats */
SAMPLE_FILTER_T sample_filter)
SAMPLE_FILTER_T sample_filter,
raft::distance::DistanceType metric)
{
const auto num_queries = gridDim.y;
const auto query_id = blockIdx.y;
Expand Down Expand Up @@ -240,6 +242,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
num_seeds,
local_visited_hashmap_ptr,
hash_bitlen,
metric,
block_id,
num_blocks);
__syncthreads();
Expand Down Expand Up @@ -286,7 +289,8 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
hash_bitlen,
parent_indices_buffer,
result_indices_buffer,
search_width);
search_width,
metric);
_CLK_REC(clk_compute_distance);
__syncthreads();

Expand Down Expand Up @@ -338,7 +342,13 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(

for (uint32_t i = threadIdx.x; i < itopk_size; i += blockDim.x) {
uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id)));
if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; }
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();

if (result_distances_ptr != nullptr) {
if (metric == distance::InnerProduct && result_indices_buffer[i] != invalid_index) {
result_distances_ptr[j] = -result_distances_buffer[i];
} else {
result_distances_ptr[j] = result_distances_buffer[i]; }}
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;

result_indices_ptr[j] =
Expand Down Expand Up @@ -483,6 +493,7 @@ void select_and_run( // raft::resources const& res,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
distance::DistanceType metric,
cudaStream_t stream)
{
auto kernel =
Expand Down Expand Up @@ -527,7 +538,8 @@ void select_and_run( // raft::resources const& res,
min_iterations,
max_iterations,
num_executed_iterations,
sample_filter);
sample_filter,
metric);
}

} // namespace multi_cta_search
Expand Down
Loading
Loading