use batch load iterator
divyegala committed Sep 14, 2023
1 parent aa4f6cb commit 4f0e425
Showing 1 changed file with 9 additions and 23 deletions.
--- a/cpp/include/raft/neighbors/detail/nn_descent.cuh
+++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh
@@ -40,6 +40,7 @@
 #include <raft/core/resource/cuda_stream.hpp>
 #include <raft/core/resources.hpp>
 #include <raft/neighbors/detail/cagra/device_common.hpp>
+#include <raft/spatial/knn/detail/ann_utils.cuh>
 #include <raft/util/arch.cuh>  // raft::util::arch::SM_*
 #include <raft/util/cuda_dev_essentials.cuh>
 #include <raft/util/cuda_rt_essentials.hpp>
@@ -206,11 +207,6 @@ __device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane
   return;
 }
 
-enum class Metric_t {
-  METRIC_INNER_PRODUCT = 0,
-  METRIC_L2            = 1,
-};
-
 struct BuildConfig {
   size_t max_dataset_size;
   size_t dataset_dim;
@@ -219,7 +215,6 @@ struct BuildConfig {
   // If internal_node_degree == 0, the value of node_degree will be assigned to it
   size_t max_iterations{50};
   float termination_threshold{0.0001};
-  Metric_t metric_type{Metric_t::METRIC_INNER_PRODUCT};
 };
 
 template <typename Index_t>
@@ -1212,36 +1207,28 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
 template <typename Data_t, typename Index_t>
 void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* output_graph)
 {
-  using input_t = typename std::remove_const<Data_t>::type;
-
   cudaStream_t stream = raft::resource::get_cuda_stream(res);
   nrow_               = nrow;
   graph_.h_graph      = (InternalID_t<Index_t>*)output_graph;
 
   cudaPointerAttributes data_ptr_attr;
   RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data));
   if (data_ptr_attr.type == cudaMemoryTypeUnregistered) {
-    int batch_size  = 100000;
-    auto input_data = raft::make_device_matrix<input_t, Index_t, raft::row_major>(
-      res, batch_size, build_config_.dataset_dim);
-    for (int step = 0; step < ceildiv(nrow_, batch_size); step++) {
-      int list_offset = step * batch_size;
-      int num_lists = step != ceildiv(nrow_, batch_size) - 1 ? batch_size : nrow_ - list_offset;
-      raft::copy(input_data.data_handle(),
-                 data + list_offset * build_config_.dataset_dim,
-                 num_lists * build_config_.dataset_dim,
-                 raft::resource::get_cuda_stream(res));
-      preprocess_data_kernel<<<num_lists,
+    size_t batch_size = 100000;
+    raft::spatial::knn::detail::utils::batch_load_iterator vec_batches{
+      data, static_cast<size_t>(nrow_), build_config_.dataset_dim, batch_size, stream};
+    for (auto const& batch : vec_batches) {
+      preprocess_data_kernel<<<batch.size(),
                               raft::warp_size(),
                               sizeof(Data_t) *
                                 ceildiv(build_config_.dataset_dim,
                                         static_cast<size_t>(raft::warp_size())) *
                                 raft::warp_size(),
-                              stream>>>(input_data.data_handle(),
+                              stream>>>(batch.data(),
                                         d_data_.data_handle(),
                                         build_config_.dataset_dim,
                                         l2_norms_.data_handle(),
-                                        list_offset);
+                                        batch.offset());
     }
   } else {
     preprocess_data_kernel<<<
@@ -1421,8 +1408,7 @@ index<IdxT> build(raft::resources const& res,
                            .node_degree           = extended_graph_degree,
                            .internal_node_degree  = extended_intermediate_degree,
                            .max_iterations        = params.max_iterations,
-                           .termination_threshold = params.termination_threshold,
-                           .metric_type           = Metric_t::METRIC_L2};
+                           .termination_threshold = params.termination_threshold};
   GNND<const T, int> nnd(res, build_config);
   nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle());
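Note on the pattern this commit adopts: batch_load_iterator hides the chunked host-to-device staging behind a range-based for loop, so the caller only sees batch.data(), batch.size(), and batch.offset() instead of maintaining list_offset and num_lists by hand and issuing an explicit raft::copy per chunk. The sketch below is a self-contained, host-only toy that models only the chunk geometry of that interface; it is not raft's implementation (the real iterator also stages each chunk into device memory on the supplied stream), and every name in it is hypothetical.

// Toy sketch of the batch-load-iterator pattern (hypothetical names; not
// raft's implementation). Batch exposes data()/size()/offset(), mirroring
// the batch.data()/batch.size()/batch.offset() calls in the diff above.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

struct Batch {
  const float* data() const { return ptr_; }    // first row of this chunk
  std::size_t size() const { return n_rows_; }  // rows in this chunk
  std::size_t offset() const { return off_; }   // global index of first row
  const float* ptr_;
  std::size_t n_rows_, off_;
};

class toy_batch_iterator {
 public:
  toy_batch_iterator(const float* src, std::size_t n_rows, std::size_t dim,
                     std::size_t batch_size)
    : src_(src), n_rows_(n_rows), dim_(dim), batch_size_(batch_size) {}

  struct iterator {
    const toy_batch_iterator* parent;
    std::size_t off;
    Batch operator*() const {
      // The final chunk may be short, exactly like the removed num_lists logic.
      std::size_t n = std::min(parent->batch_size_, parent->n_rows_ - off);
      return Batch{parent->src_ + off * parent->dim_, n, off};
    }
    iterator& operator++() { off += parent->batch_size_; return *this; }
    bool operator!=(const iterator& other) const { return off < other.off; }
  };
  iterator begin() const { return {this, 0}; }
  iterator end() const { return {this, n_rows_}; }

 private:
  const float* src_;
  std::size_t n_rows_, dim_, batch_size_;
};

int main() {
  std::vector<float> host(10 * 4);  // 10 rows, dim = 4
  for (auto const& batch : toy_batch_iterator{host.data(), 10, 4, 3}) {
    // In the real code this is where the preprocess_data_kernel launch goes;
    // here we just report each chunk's geometry: 3, 3, 3, then 1 rows.
    std::printf("offset=%zu size=%zu\n", batch.offset(), batch.size());
  }
}

Folding the tail-chunk arithmetic and the staging copy into the iterator is what lets the commit delete 23 lines while adding 9.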

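The dynamic shared-memory argument in both kernel launches allocates room for one dataset row, rounded up to a whole number of warps. A worked example with illustrative values (not taken from the code): dim = 100, float data, warp size 32 gives ceildiv(100, 32) = 4, so 4 * 32 = 128 elements, i.e. 512 bytes.

// Worked example of the shared-memory size expression used in the launches
// above. dataset_dim = 100 and float are illustrative assumptions.
#include <cstddef>
#include <cstdio>

constexpr std::size_t ceildiv(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

int main() {
  constexpr std::size_t warp_size   = 32;   // raft::warp_size() on NVIDIA GPUs
  constexpr std::size_t dataset_dim = 100;  // illustrative
  std::size_t smem = sizeof(float) * ceildiv(dataset_dim, warp_size) * warp_size;
  std::printf("%zu bytes\n", smem);         // 4 * 4 * 32 = 512 bytes
}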
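The branch that gates the new loop (unchanged by this commit) decides whether data needs staging at all: only pointers the CUDA runtime reports as cudaMemoryTypeUnregistered, i.e. plain pageable host memory, take the batched path; device-resident data goes straight to the else branch. A minimal sketch of that check using the standard CUDA runtime API (assumes CUDA 11+, where cudaPointerGetAttributes succeeds on unregistered host pointers):

// Minimal sketch of the pointer-type check that gates the batched path
// (standard CUDA runtime API; error handling shortened for brevity).
#include <cuda_runtime.h>
#include <cstdio>

bool is_unregistered_host_ptr(const void* p) {
  cudaPointerAttributes attr{};
  if (cudaPointerGetAttributes(&attr, p) != cudaSuccess) return false;
  // cudaMemoryTypeUnregistered means plain host memory the CUDA runtime has
  // never seen; it must be staged to the device before kernels can read it.
  return attr.type == cudaMemoryTypeUnregistered;
}

int main() {
  float host_buf[4] = {0, 1, 2, 3};
  std::printf("host array unregistered: %d\n", is_unregistered_host_ptr(host_buf));
}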
