From 4f0e425466e0092e86dfb070bde232551092e4b2 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 14 Sep 2023 13:12:54 -0700 Subject: [PATCH] use batch load iterator --- .../raft/neighbors/detail/nn_descent.cuh | 32 ++++++------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index b1a51381e0..da22d9ff2d 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -40,6 +40,7 @@ #include #include #include +#include #include // raft::util::arch::SM_* #include #include @@ -206,11 +207,6 @@ __device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane return; } -enum class Metric_t { - METRIC_INNER_PRODUCT = 0, - METRIC_L2 = 1, -}; - struct BuildConfig { size_t max_dataset_size; size_t dataset_dim; @@ -219,7 +215,6 @@ struct BuildConfig { // If internal_node_degree == 0, the value of node_degree will be assigned to it size_t max_iterations{50}; float termination_threshold{0.0001}; - Metric_t metric_type{Metric_t::METRIC_INNER_PRODUCT}; }; template @@ -1212,8 +1207,6 @@ void GNND::local_join(cudaStream_t stream) template void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph) { - using input_t = typename std::remove_const::type; - cudaStream_t stream = raft::resource::get_cuda_stream(res); nrow_ = nrow; graph_.h_graph = (InternalID_t*)output_graph; @@ -1221,27 +1214,21 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out cudaPointerAttributes data_ptr_attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); if (data_ptr_attr.type == cudaMemoryTypeUnregistered) { - int batch_size = 100000; - auto input_data = raft::make_device_matrix( - res, batch_size, build_config_.dataset_dim); - for (int step = 0; step < ceildiv(nrow_, batch_size); step++) { - int list_offset = step * batch_size; - int num_lists = step != ceildiv(nrow_, batch_size) - 1 ? batch_size : nrow_ - list_offset; - raft::copy(input_data.data_handle(), - data + list_offset * build_config_.dataset_dim, - num_lists * build_config_.dataset_dim, - raft::resource::get_cuda_stream(res)); - preprocess_data_kernel<<(nrow_), build_config_.dataset_dim, batch_size, stream}; + for (auto const& batch : vec_batches) { + preprocess_data_kernel<<(raft::warp_size())) * raft::warp_size(), - stream>>>(input_data.data_handle(), + stream>>>(batch.data(), d_data_.data_handle(), build_config_.dataset_dim, l2_norms_.data_handle(), - list_offset); + batch.offset()); } } else { preprocess_data_kernel<<< @@ -1421,8 +1408,7 @@ index build(raft::resources const& res, .node_degree = extended_graph_degree, .internal_node_degree = extended_intermediate_degree, .max_iterations = params.max_iterations, - .termination_threshold = params.termination_threshold, - .metric_type = Metric_t::METRIC_L2}; + .termination_threshold = params.termination_threshold}; GNND nnd(res, build_config); nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle());