Commit

Merge branch 'branch-24.10' into rhdong/transpose-ut
rhdong authored Aug 27, 2024
2 parents 2b1774a + e17dcdf commit 2def2c7
Showing 19 changed files with 994 additions and 82 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -99,7 +99,7 @@ repos:
hooks:
- id: check-json
- repo: https://github.com/rapidsai/pre-commit-hooks
rev: v0.3.1
rev: v0.4.0
hooks:
- id: verify-copyright
files: |
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-118_arch-aarch64.yaml
@@ -39,7 +39,7 @@ dependencies:
- nccl>=2.9.9
- ninja
- numba>=0.57
- numpy>=1.23,<2.0a0
- numpy>=1.23,<3.0a0
- numpydoc
- nvcc_linux-aarch64=11.8
- pre-commit
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -39,7 +39,7 @@ dependencies:
- nccl>=2.9.9
- ninja
- numba>=0.57
- numpy>=1.23,<2.0a0
- numpy>=1.23,<3.0a0
- numpydoc
- nvcc_linux-64=11.8
- pre-commit
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-125_arch-aarch64.yaml
@@ -36,7 +36,7 @@ dependencies:
- nccl>=2.9.9
- ninja
- numba>=0.57
- numpy>=1.23,<2.0a0
- numpy>=1.23,<3.0a0
- numpydoc
- pre-commit
- pydata-sphinx-theme
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-125_arch-x86_64.yaml
@@ -36,7 +36,7 @@ dependencies:
- nccl>=2.9.9
- ninja
- numba>=0.57
- numpy>=1.23,<2.0a0
- numpy>=1.23,<3.0a0
- numpydoc
- pre-commit
- pydata-sphinx-theme
2 changes: 1 addition & 1 deletion conda/recipes/pylibraft/meta.yaml
@@ -65,7 +65,7 @@ requirements:
{% endif %}
- libraft {{ version }}
- libraft-headers {{ version }}
- numpy >=1.23,<2.0a0
- numpy >=1.23,<3.0a0
- python x.x
- rmm ={{ minor_version }}

48 changes: 30 additions & 18 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
@@ -26,6 +26,7 @@
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/init.cuh>
#include <raft/matrix/slice.cuh>
#include <raft/neighbors/detail/cagra/device_common.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
@@ -344,7 +345,9 @@ struct GnndGraph {
~GnndGraph();
};

template <typename Data_t = float, typename Index_t = int, typename epilogue_op = raft::identity_op>
template <typename Data_t = float,
typename Index_t = int,
typename epilogue_op = DistEpilogue<Index_t, Data_t>>
class GNND {
public:
GNND(raft::resources const& res, const BuildConfig& build_config);
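Note on the change above and in the member-function defaults that follow: the default epilogue type moves from raft::identity_op to DistEpilogue<Index_t, Data_t>. DistEpilogue's definition is not part of this diff; as a point of reference only, the previous default, raft::identity_op, returns its first argument and ignores the rest, so an identity-style epilogue functor could be sketched as follows (hypothetical, not the actual DistEpilogue):

// Hypothetical sketch only; DistEpilogue is defined elsewhere in RAFT and is not
// shown in this commit. This mirrors the behavior of the previous default,
// raft::identity_op: return the computed distance untouched, ignore extra arguments.
template <typename Index_t, typename Data_t>
struct identity_dist_epilogue {
  template <typename... Args>
  __host__ __device__ Data_t operator()(Data_t dist, Args&&...) const
  {
    return dist;
  }
};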
@@ -356,17 +359,19 @@ class GNND {
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances,
epilogue_op distance_epilogue = raft::identity_op());
epilogue_op distance_epilogue = DistEpilogue<Index_t, Data_t>());
~GNND() = default;
using ID_t = InternalID_t<Index_t>;
void reset(raft::resources const& res);

private:
void add_reverse_edges(Index_t* graph_ptr,
Index_t* h_rev_graph_ptr,
Index_t* d_rev_graph_ptr,
int2* list_sizes,
cudaStream_t stream = 0);
void local_join(cudaStream_t stream = 0, epilogue_op distance_epilogue = raft::identity_op());
void local_join(cudaStream_t stream = 0,
epilogue_op distance_epilogue = DistEpilogue<Index_t, Data_t>());

raft::resources const& res;

@@ -701,7 +706,7 @@ __device__ __forceinline__ void remove_duplicates(
// is 1024 and 1536 respectively, which means the bounds don't work anymore
template <typename Index_t,
typename ID_t = InternalID_t<Index_t>,
typename epilogue_op = raft::identity_op>
typename epilogue_op = DistEpilogue<Index_t, DistData_t>>
RAFT_KERNEL
#ifdef __CUDA_ARCH__
#if (__CUDA_ARCH__) == 750 || ((__CUDA_ARCH__) >= 860 && (__CUDA_ARCH__) <= 890)
@@ -1183,18 +1188,23 @@ GNND<Data_t, Index_t, epilogue_op>::GNND(raft::resources const& res,
d_list_sizes_old_{raft::make_device_vector<int2, size_t>(res, nrow_)}
{
static_assert(NUM_SAMPLES <= 32);

thrust::fill(thrust::device,
dists_buffer_.data_handle(),
dists_buffer_.data_handle() + dists_buffer_.size(),
std::numeric_limits<float>::max());
thrust::fill(thrust::device,
reinterpret_cast<Index_t*>(graph_buffer_.data_handle()),
reinterpret_cast<Index_t*>(graph_buffer_.data_handle()) + graph_buffer_.size(),
std::numeric_limits<Index_t>::max());
thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0);
raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits<float>::max());
auto graph_buffer_view = raft::make_device_matrix_view<Index_t, int64_t>(
reinterpret_cast<Index_t*>(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE);
raft::matrix::fill(res, graph_buffer_view, std::numeric_limits<Index_t>::max());
raft::matrix::fill(res, d_locks_.view(), 0);
};

template <typename Data_t, typename Index_t, typename epilogue_op>
void GNND<Data_t, Index_t, epilogue_op>::reset(raft::resources const& res)
{
raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits<float>::max());
auto graph_buffer_view = raft::make_device_matrix_view<Index_t, int64_t>(
reinterpret_cast<Index_t*>(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE);
raft::matrix::fill(res, graph_buffer_view, std::numeric_limits<Index_t>::max());
raft::matrix::fill(res, d_locks_.view(), 0);
}

template <typename Data_t, typename Index_t, typename epilogue_op>
void GNND<Data_t, Index_t, epilogue_op>::add_reverse_edges(Index_t* graph_ptr,
Index_t* h_rev_graph_ptr,
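In the hunk above, the constructor now initializes dists_buffer_, graph_buffer_, and d_locks_ with raft::matrix::fill (via the newly included raft/matrix/init.cuh) instead of raw thrust::fill over data handles, and the new reset() method reuses the same calls so the buffers can be reinitialized between builds. A minimal usage sketch of that primitive, assuming a working RAFT/CUDA build environment and purely illustrative shapes:

#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/init.cuh>

#include <limits>

void fill_sketch(raft::resources const& res)
{
  // Allocate a small device matrix and set every element to the float sentinel,
  // mirroring how dists_buffer_ is initialized in the constructor and in reset().
  auto dists = raft::make_device_matrix<float, int64_t>(res, 4, 8);
  raft::matrix::fill(res, dists.view(), std::numeric_limits<float>::max());
}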
@@ -1246,6 +1256,7 @@ void GNND<Data_t, Index_t, epilogue_op>::build(Data_t* data,

cudaStream_t stream = raft::resource::get_cuda_stream(res);
nrow_ = nrow;
graph_.nrow = nrow;
graph_.h_graph = (InternalID_t<Index_t>*)output_graph;

cudaPointerAttributes data_ptr_attr;
@@ -1384,6 +1395,7 @@ void GNND<Data_t, Index_t, epilogue_op>::build(Data_t* data,
static_cast<int64_t>(build_config_.output_graph_degree)};
raft::matrix::slice<DistData_t, int64_t, raft::row_major>(
res, raft::make_const_mdspan(graph_d_dists.view()), output_dist_view, coords);
raft::resource::sync_stream(res);
}
Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle();
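The added raft::resource::sync_stream(res) makes the host wait for the device-side slice into the output distances to complete before that memory is handed back to the caller. For illustration only, assuming a standard RAFT resources handle, the pattern is:

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

void wait_for_results(raft::resources const& res)
{
  // ... kernels and copies enqueued on raft::resource::get_cuda_stream(res) ...
  // Block the calling thread until all work on that stream has completed.
  raft::resource::sync_stream(res);
}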
@@ -1414,14 +1426,14 @@ void GNND<Data_t, Index_t, epilogue_op>::build(Data_t* data,
template <typename T,
typename IdxT = uint32_t,
typename epilogue_op = raft::identity_op,
typename epilogue_op = DistEpilogue<IdxT, T>,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
void build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
index<IdxT>& idx,
epilogue_op distance_epilogue = raft::identity_op())
epilogue_op distance_epilogue = DistEpilogue<IdxT, T>())
{
RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits<int>::max() - 1,
"The dataset size for GNND should be less than %d",
@@ -1491,13 +1503,13 @@ void build(raft::resources const& res,
template <typename T,
typename IdxT = uint32_t,
typename epilogue_op = raft::identity_op,
typename epilogue_op = DistEpilogue<IdxT, T>,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
index<IdxT> build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
epilogue_op distance_epilogue = raft::identity_op())
epilogue_op distance_epilogue = DistEpilogue<IdxT, T>())
{
size_t intermediate_degree = params.intermediate_graph_degree;
size_t graph_degree = params.graph_degree;
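With the new defaults above, callers of these detail-level build() overloads that never passed an epilogue keep compiling unchanged, while a custom functor can still be supplied explicitly. The sketch below is hypothetical (scaled_epilogue is not part of RAFT or this commit) and simply mirrors raft::identity_op's shape of taking the distance first and ignoring trailing arguments:

// Hypothetical functor for illustration only.
struct scaled_epilogue {
  float factor;

  template <typename... Args>
  __host__ __device__ float operator()(float dist, Args&&...) const
  {
    return dist * factor;  // e.g. rescale every computed neighbor distance
  }
};

// Usage sketch, assuming res, params, and a row-major dataset mdspan already exist:
//   auto idx = build(res, params, dataset_view, scaled_epilogue{0.5f});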
