CAGRA graph pruning: fix 32/64-bit int arithmetics #2197

Closed
61 changes: 30 additions & 31 deletions cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -33,6 +33,8 @@

#include <raft/util/bitonic_sort.cuh>
#include <raft/util/cuda_rt_essentials.hpp>
+ #include <raft/util/device_atomics.cuh>
+ #include <raft/util/warp_primitives.cuh>

#include "utils.hpp"

@@ -95,11 +97,11 @@ RAFT_KERNEL kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, data
dataset[d + static_cast<uint64_t>(dataset_dim) * dstNode]);
dist += diff * diff;
}
- dist += __shfl_xor_sync(0xffffffff, dist, 1);
- dist += __shfl_xor_sync(0xffffffff, dist, 2);
- dist += __shfl_xor_sync(0xffffffff, dist, 4);
- dist += __shfl_xor_sync(0xffffffff, dist, 8);
- dist += __shfl_xor_sync(0xffffffff, dist, 16);
+ dist += raft::shfl_xor(dist, 1);
+ dist += raft::shfl_xor(dist, 2);
+ dist += raft::shfl_xor(dist, 4);
+ dist += raft::shfl_xor(dist, 8);
+ dist += raft::shfl_xor(dist, 16);
if (lane_id == (k % raft::WarpSize)) {
my_keys[k / raft::WarpSize] = dist;
my_vals[k / raft::WarpSize] = dstNode;
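The five shuffle steps above are a butterfly (XOR) reduction: after masks 1, 2, 4, 8, and 16 every lane of the warp holds the full 32-lane sum, so the lane that owns neighbor k can write the distance directly. A minimal sketch of the same pattern with the raw CUDA intrinsic (raft::shfl_xor from raft/util/warp_primitives.cuh is assumed here to be a thin wrapper over it):

// Sketch only: butterfly sum across one warp using the raw intrinsic.
__device__ float warp_sum(float v)
{
  for (int mask = 1; mask < 32; mask <<= 1) {
    v += __shfl_xor_sync(0xffffffff, v, mask);  // each step folds in the partner lane
  }
  return v;  // every lane now holds the same 32-lane total
}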
@@ -126,7 +128,7 @@ RAFT_KERNEL kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, data

template <int MAX_DEGREE, class IdxT>
RAFT_KERNEL kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree]
- const uint32_t graph_size,
+ const int64_t graph_size,
const uint32_t graph_degree,
const uint32_t degree,
const uint32_t batch_size,
@@ -139,25 +141,23 @@ RAFT_KERNEL kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph
uint64_t* const num_retain = stats;
uint64_t* const num_full = stats + 1;

- const uint64_t nid = blockIdx.x + (batch_size * batch_id);
- if (nid >= graph_size) { return; }
+ const int64_t iA = static_cast<int64_t>(batch_size) * static_cast<int64_t>(batch_id) +
+                    static_cast<int64_t>(blockIdx.x);
+ if (iA >= graph_size) { return; }
for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) {
smem_num_detour[k] = 0;
}
__syncthreads();

- const uint64_t iA = nid;
- if (iA >= graph_size) { return; }
-
// count number of detours (A->D->B)
for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) {
- const uint64_t iD = knn_graph[kAD + (graph_degree * iA)];
+ const int64_t iD = knn_graph[iA * graph_degree + kAD];
for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) {
- const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)];
+ const int64_t iB_candidate = knn_graph[iD * graph_degree + kDB];
for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) {
// if ( kDB < kAB )
{
- const uint64_t iB = knn_graph[kAB + (graph_degree * iA)];
+ const int64_t iB = knn_graph[iA * graph_degree + kAB];
if (iB == iB_candidate) {
atomicAdd(smem_num_detour + kAB, 1);
break;
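The core of the fix is visible in the iA computation above: in the old code, batch_size * batch_id was a 32-bit multiply, so for graphs large enough that the product exceeds 2^32 the value wrapped before the widening assignment. Casting each operand to int64_t first keeps the whole expression, and the knn_graph row offsets derived from it, in 64 bits. A small host-side illustration with hypothetical values (not taken from the PR):

#include <cstdint>
#include <cstdio>

int main()
{
  // Hypothetical batch index for a very large graph; batch_size matches the
  // 256 * 1024 chosen in optimize().
  uint32_t batch_size = 256 * 1024;
  uint32_t batch_id   = 20000;

  // The 32-bit multiply wraps modulo 2^32 before the widening assignment.
  uint64_t wrapped = static_cast<uint64_t>(batch_size * batch_id);
  // Casting the operands first keeps the arithmetic in 64 bits.
  int64_t correct = static_cast<int64_t>(batch_size) * static_cast<int64_t>(batch_id);

  // Prints "wrapped = 947912704, correct = 5242880000".
  printf("wrapped = %llu, correct = %lld\n",
         (unsigned long long)wrapped, (long long)correct);
  return 0;
}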
@@ -170,20 +170,20 @@ RAFT_KERNEL kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph

uint32_t num_edges_no_detour = 0;
for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) {
- detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255);
+ detour_count[iA * graph_degree + k] = std::min<uint32_t>(smem_num_detour[k], 255);
if (smem_num_detour[k] == 0) { num_edges_no_detour++; }
}
- num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1);
- num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2);
- num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4);
- num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8);
- num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16);
+ num_edges_no_detour += raft::shfl_xor(num_edges_no_detour, 1);
+ num_edges_no_detour += raft::shfl_xor(num_edges_no_detour, 2);
+ num_edges_no_detour += raft::shfl_xor(num_edges_no_detour, 4);
+ num_edges_no_detour += raft::shfl_xor(num_edges_no_detour, 8);
+ num_edges_no_detour += raft::shfl_xor(num_edges_no_detour, 16);
num_edges_no_detour = min(num_edges_no_detour, degree);

if (threadIdx.x == 0) {
num_no_detour_edges[iA] = num_edges_no_detour;
- atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour);
- if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); }
+ atomicAdd<uint64_t>(num_retain, num_edges_no_detour);
+ if (num_edges_no_detour >= degree) { atomicAdd<uint64_t>(num_full, 1); }
}
}
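The stats counters are uint64_t, for which CUDA provides no native atomicAdd overload; the old code therefore cast to unsigned long long int, while the new code relies on the typed overload pulled in via raft/util/device_atomics.cuh. A sketch of what the cast-based fallback amounts to (both types are 64-bit wide on the platforms RAFT targets, which is what makes the reinterpretation safe):

// Sketch only: the cast pattern the typed overload replaces.
__device__ void add_u64(uint64_t* counter, uint64_t v)
{
  static_assert(sizeof(uint64_t) == sizeof(unsigned long long int),
                "the cast relies on equal widths");
  atomicAdd(reinterpret_cast<unsigned long long int*>(counter),
            static_cast<unsigned long long int>(v));
}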

@@ -332,7 +332,7 @@ void optimize(raft::resources const& res,
const uint32_t output_graph_degree = new_graph.extent(1);
auto input_graph_ptr = knn_graph.data_handle();
auto output_graph_ptr = new_graph.data_handle();
- const IdxT graph_size = new_graph.extent(0);
+ const int64_t graph_size = new_graph.extent(0);

{
//
@@ -384,9 +384,8 @@ void optimize(raft::resources const& res,
input_graph_degree,
1024);
}
- const uint32_t batch_size =
-   std::min(static_cast<uint32_t>(graph_size), static_cast<uint32_t>(256 * 1024));
- const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size;
+ const uint32_t batch_size = std::min(graph_size, static_cast<int64_t>(256 * 1024));
+ const uint32_t num_batch = raft::div_rounding_up_safe<int64_t>(graph_size, batch_size);
const dim3 threads_prune(32, 1, 1);
const dim3 blocks_prune(batch_size, 1, 1);
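raft::div_rounding_up_safe replaces the hand-written (graph_size + batch_size - 1) / batch_size: the quotient-plus-remainder form avoids the addition, which can itself overflow, and doing it on int64_t operands keeps the batch count correct once graph_size no longer fits in 32 bits. A minimal sketch of the assumed behaviour (see raft/util/integer_utils.hpp for the actual helper):

#include <cstdint>

// Ceiling division without the "(a + b - 1)" addition.
inline int64_t ceil_div(int64_t a, int64_t b)
{
  return a / b + (a % b != 0 ? 1 : 0);
}

// e.g. ceil_div(5'000'000'000, 256 * 1024) == 19074 batches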

@@ -423,7 +422,7 @@ void optimize(raft::resources const& res,
// Create pruned kNN graph
uint32_t max_detour = 0;
#pragma omp parallel for reduction(max : max_detour)
- for (uint64_t i = 0; i < graph_size; i++) {
+ for (int64_t i = 0; i < graph_size; i++) {
uint64_t pk = 0;
for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) {
if (max_detour < num_detour) { max_detour = num_detour; /* stats */ }
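This and the following host-side loops switch their counters from uint64_t to int64_t to match the new type of graph_size, so the loop condition and any index arithmetic stay in one signed 64-bit domain instead of mixing signedness. A tiny sketch of the resulting loop shape (hypothetical helper, not part of the PR):

#include <cstdint>

// The index shares graph_size's signed 64-bit type: i * degree is a 64-bit
// multiply and "i < graph_size" is not a signed/unsigned comparison.
void zero_first_column(int64_t graph_size, uint32_t degree, uint32_t* graph)
{
#pragma omp parallel for
  for (int64_t i = 0; i < graph_size; i++) {
    graph[i * degree] = 0;
  }
}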
@@ -477,7 +476,7 @@ void optimize(raft::resources const& res,

for (uint64_t k = 0; k < output_graph_degree; k++) {
#pragma omp parallel for
- for (uint64_t i = 0; i < graph_size; i++) {
+ for (int64_t i = 0; i < graph_size; i++) {
dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)];
}
resource::sync_stream(res);
@@ -527,7 +526,7 @@ void optimize(raft::resources const& res,

constexpr int _omp_chunk = 1024;
#pragma omp parallel for schedule(dynamic, _omp_chunk)
- for (uint64_t j = 0; j < graph_size; j++) {
+ for (int64_t j = 0; j < graph_size; j++) {
uint64_t k = std::min(rev_graph_count.data_handle()[j], output_graph_degree);
while (k) {
k--;
@@ -556,7 +555,7 @@ void optimize(raft::resources const& res,
/* stats */
uint64_t num_replaced_edges = 0;
#pragma omp parallel for reduction(+ : num_replaced_edges)
- for (uint64_t i = 0; i < graph_size; i++) {
+ for (int64_t i = 0; i < graph_size; i++) {
for (uint64_t k = 0; k < output_graph_degree; k++) {
const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)];
const uint64_t pos =