diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index cdfb9d9931..ad16f3e11d 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -340,7 +341,7 @@ struct GnndGraph { ~GnndGraph(); }; -template +template class GNND { public: GNND(raft::resources const& res, const BuildConfig& build_config); @@ -351,7 +352,8 @@ class GNND { const Index_t nrow, Index_t* output_graph, bool return_distances, - DistData_t* output_distances); + DistData_t* output_distances, + epilogue_op distance_epilogue = raft::identity_op()); ~GNND() = default; using ID_t = InternalID_t; @@ -361,7 +363,7 @@ class GNND { Index_t* d_rev_graph_ptr, int2* list_sizes, cudaStream_t stream = 0); - void local_join(cudaStream_t stream = 0); + void local_join(cudaStream_t stream = 0, epilogue_op distance_epilogue = raft::identity_op()); raft::resources const& res; @@ -694,7 +696,9 @@ __device__ __forceinline__ void remove_duplicates( // MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048 // For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM // is 1024 and 1536 respectively, which means the bounds don't work anymore -template > +template , + typename epilogue_op = raft::identity_op> RAFT_KERNEL #ifdef __CUDA_ARCH__ #if (__CUDA_ARCH__) == 750 || ((__CUDA_ARCH__) >= 860 && (__CUDA_ARCH__) <= 890) @@ -716,7 +720,8 @@ __launch_bounds__(BLOCK_SIZE, 4) DistData_t* dists, int graph_width, int* locks, - DistData_t* l2_norms) + DistData_t* l2_norms, + epilogue_op distance_epilogue) { #if (__CUDA_ARCH__ >= 700) using namespace nvcuda; @@ -826,14 +831,17 @@ __launch_bounds__(BLOCK_SIZE, 4) __syncthreads(); for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { - if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size && - i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + auto row_idx = i % SKEWED_MAX_NUM_BI_SAMPLES; + auto col_idx = i / SKEWED_MAX_NUM_BI_SAMPLES; + if (row_idx < list_new_size && col_idx < list_new_size) { + auto r = new_neighbors[row_idx]; + auto c = new_neighbors[col_idx]; if (l2_norms == nullptr) { - s_distances[i] = -s_distances[i]; + auto dist_val = -s_distances[i]; + s_distances[i] = distance_epilogue(dist_val, r, c); } else { - s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + - l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - - 2.0 * s_distances[i]; + auto dist_val = l2_norms[r] + l2_norms[c] - 2.0 * s_distances[i]; + s_distances[i] = distance_epilogue(dist_val, r, c); } } else { s_distances[i] = std::numeric_limits::max(); @@ -905,14 +913,17 @@ __launch_bounds__(BLOCK_SIZE, 4) __syncthreads(); for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { - if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size && - i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + auto row_idx = i % SKEWED_MAX_NUM_BI_SAMPLES; + auto col_idx = i / SKEWED_MAX_NUM_BI_SAMPLES; + if (row_idx < list_old_size && col_idx < list_new_size) { + auto r = old_neighbors[row_idx]; + auto c = new_neighbors[col_idx]; if (l2_norms == nullptr) { - s_distances[i] = -s_distances[i]; + auto dist_val = -s_distances[i]; + s_distances[i] = distance_epilogue(dist_val, r, c); } else { - s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + - l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - - 2.0 * s_distances[i]; + auto dist_val = l2_norms[r] + l2_norms[c] - 2.0 * s_distances[i]; + s_distances[i] = distance_epilogue(dist_val, r, c); } } else { s_distances[i] = std::numeric_limits::max(); @@ -1140,8 +1151,9 @@ GnndGraph::~GnndGraph() assert(h_graph == nullptr); } -template -GNND::GNND(raft::resources const& res, const BuildConfig& build_config) +template +GNND::GNND(raft::resources const& res, + const BuildConfig& build_config) : res(res), build_config_(build_config), graph_(build_config.max_dataset_size, @@ -1180,12 +1192,12 @@ GNND::GNND(raft::resources const& res, const BuildConfig& build thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); }; -template -void GNND::add_reverse_edges(Index_t* graph_ptr, - Index_t* h_rev_graph_ptr, - Index_t* d_rev_graph_ptr, - int2* list_sizes, - cudaStream_t stream) +template +void GNND::add_reverse_edges(Index_t* graph_ptr, + Index_t* h_rev_graph_ptr, + Index_t* d_rev_graph_ptr, + int2* list_sizes, + cudaStream_t stream) { add_rev_edges_kernel<<>>( graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes); @@ -1193,8 +1205,9 @@ void GNND::add_reverse_edges(Index_t* graph_ptr, h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, raft::resource::get_cuda_stream(res)); } -template -void GNND::local_join(cudaStream_t stream) +template +void GNND::local_join(cudaStream_t stream, + epilogue_op distance_epilogue) { thrust::fill(thrust::device.on(stream), dists_buffer_.data_handle(), @@ -1214,15 +1227,17 @@ void GNND::local_join(cudaStream_t stream) dists_buffer_.data_handle(), DEGREE_ON_DEVICE, d_locks_.data_handle(), - l2_norms_.data_handle()); + l2_norms_.data_handle(), + distance_epilogue); } -template -void GNND::build(Data_t* data, - const Index_t nrow, - Index_t* output_graph, - bool return_distances, - DistData_t* output_distances) +template +void GNND::build(Data_t* data, + const Index_t nrow, + Index_t* output_graph, + bool return_distances, + DistData_t* output_distances, + epilogue_op distance_epilogue) { using input_t = typename std::remove_const::type; @@ -1318,7 +1333,7 @@ void GNND::build(Data_t* data, raft::util::arch::SM_range(raft::util::arch::SM_70(), raft::util::arch::SM_future()); if (wmma_range.contains(runtime_arch)) { - local_join(stream); + local_join(stream, distance_epilogue); } else { THROW("NN_DESCENT cannot be run for __CUDA_ARCH__ < 700"); } @@ -1385,13 +1400,15 @@ void GNND::build(Data_t* data, } template , memory_type::host>> void build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset, - index& idx) + index& idx, + epilogue_op distance_epilogue = raft::identity_op()) { RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, "The dataset size for GNND should be less than %d", @@ -1433,7 +1450,7 @@ void build(raft::resources const& res, .termination_threshold = params.termination_threshold, .output_graph_degree = params.graph_degree}; - GNND nnd(res, build_config); + GNND nnd(res, build_config); if (idx.distances().has_value() || !params.return_distances) { nnd.build(dataset.data_handle(), @@ -1442,7 +1459,8 @@ void build(raft::resources const& res, params.return_distances, idx.distances() .value_or(raft::make_device_matrix(res, 0, 0).view()) - .data_handle()); + .data_handle(), + distance_epilogue); } else { RAFT_EXPECTS(!params.return_distances, "Distance view not allocated. Using return_distances set to true requires " @@ -1459,12 +1477,14 @@ void build(raft::resources const& res, } template , memory_type::host>> index build(raft::resources const& res, const index_params& params, - mdspan, row_major, Accessor> dataset) + mdspan, row_major, Accessor> dataset, + epilogue_op distance_epilogue = raft::identity_op()) { size_t intermediate_degree = params.intermediate_graph_degree; size_t graph_degree = params.graph_degree; @@ -1481,7 +1501,7 @@ index build(raft::resources const& res, index idx{ res, dataset.extent(0), static_cast(graph_degree), params.return_distances}; - build(res, params, dataset, idx); + build(res, params, dataset, idx, distance_epilogue); return idx; } diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh index ceb5ae5643..a46a2006d6 100644 --- a/cpp/include/raft/neighbors/nn_descent.cuh +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,19 +48,22 @@ namespace raft::neighbors::experimental::nn_descent { * * @tparam T data-type of the input dataset * @tparam IdxT data-type for the output index + * @tparam epilogue_op epilogue operation type for distances * @param[in] res raft::resources is an object mangaging resources * @param[in] params an instance of nn_descent::index_params that are parameters * to run the nn-descent algorithm * @param[in] dataset raft::device_matrix_view input dataset expected to be located * in device memory + * @param[in] distance_epilogue epilogue operation for distances * @return index index containing all-neighbors knn graph in host memory */ -template +template index build(raft::resources const& res, index_params const& params, - raft::device_matrix_view dataset) + raft::device_matrix_view dataset, + epilogue_op distance_epilogue = raft::identity_op()) { - return detail::build(res, params, dataset); + return detail::build(res, params, dataset, distance_epilogue); } /** @@ -85,6 +88,7 @@ index build(raft::resources const& res, * * @tparam T data-type of the input dataset * @tparam IdxT data-type for the output index + * @tparam epilogue_op epilogue operation type for distances * @param res raft::resources is an object mangaging resources * @param[in] params an instance of nn_descent::index_params that are parameters * to run the nn-descent algorithm @@ -92,14 +96,16 @@ index build(raft::resources const& res, * in device memory * @param[out] idx raft::neighbors::experimental::nn_descentindex containing all-neighbors knn graph * in host memory + * @param[in] distance_epilogue epilogue operation for distances */ -template +template void build(raft::resources const& res, index_params const& params, raft::device_matrix_view dataset, - index& idx) + index& idx, + epilogue_op distance_epilogue = raft::identity_op()) { - detail::build(res, params, dataset, idx); + detail::build(res, params, dataset, idx, distance_epilogue); } /** @@ -122,19 +128,22 @@ void build(raft::resources const& res, * * @tparam T data-type of the input dataset * @tparam IdxT data-type for the output index + * @tparam epilogue_op epilogue operation type for distances * @param res raft::resources is an object mangaging resources * @param[in] params an instance of nn_descent::index_params that are parameters * to run the nn-descent algorithm * @param[in] dataset raft::host_matrix_view input dataset expected to be located * in host memory + * @param[in] distance_epilogue epilogue operation for distances * @return index index containing all-neighbors knn graph in host memory */ -template +template index build(raft::resources const& res, index_params const& params, - raft::host_matrix_view dataset) + raft::host_matrix_view dataset, + epilogue_op distance_epilogue = raft::identity_op()) { - return detail::build(res, params, dataset); + return detail::build(res, params, dataset, distance_epilogue); } /** @@ -159,6 +168,7 @@ index build(raft::resources const& res, * * @tparam T data-type of the input dataset * @tparam IdxT data-type for the output index + * @tparam epilogue_op epilogue operation type for distances * @param[in] res raft::resources is an object mangaging resources * @param[in] params an instance of nn_descent::index_params that are parameters * to run the nn-descent algorithm @@ -166,14 +176,16 @@ index build(raft::resources const& res, * in host memory * @param[out] idx raft::neighbors::experimental::nn_descentindex containing all-neighbors knn graph * in host memory + * @param[in] distance_epilogue epilogue operation for distances */ -template +template void build(raft::resources const& res, index_params const& params, raft::host_matrix_view dataset, - index& idx) + index& idx, + epilogue_op distance_epilogue = raft::identity_op()) { - detail::build(res, params, dataset, idx); + detail::build(res, params, dataset, idx, distance_epilogue); } /** @} */ // end group nn-descent