Make all cuda kernels have hidden visibility (#1898)
Effect on binary size of libraft.a
23.12: 133361630
pr: 129748904

Effect on binary size of libraft.so
23.12: 83603224
pr: 83873088
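(Sizes are in bytes: the static library shrinks by roughly 3.6 MB, about 2.7%, while the shared library grows by roughly 0.27 MB, about 0.3%.)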

Authors:
  - Robert Maynard (https://github.com/robertmaynard)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1898
robertmaynard authored Oct 13, 2023
1 parent cf08558 commit 27dcf7b
Showing 147 changed files with 1,322 additions and 1,316 deletions.
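The change itself is mechanical: kernels previously declared as __global__ void are now declared through the new RAFT_KERNEL macro added in cpp/include/raft/core/detail/macros.hpp (shown further down), which keeps the instantiated symbol local to the library that compiled it. A minimal sketch of the pattern, using a hypothetical kernel name not taken from the diff:

#include <raft/core/detail/macros.hpp>  // provides RAFT_KERNEL

// Before this PR a kernel template would have been written as
//
//   template <typename T>
//   __global__ void scale_kernel(T* data, T factor, int n) { ... }
//
// which gives every instantiation default (weak) visibility. With the macro,
// the same kernel is emitted as a hidden/internal symbol instead:
template <typename T>
RAFT_KERNEL scale_kernel(T* data, T factor, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { data[i] *= factor; }
}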
8 changes: 4 additions & 4 deletions cpp/bench/prims/distance/masked_nn.cu
@@ -46,10 +46,10 @@ struct Params {
AdjacencyPattern pattern;
}; // struct Params

__global__ void init_adj(AdjacencyPattern pattern,
int n,
raft::device_matrix_view<bool, int, raft::layout_c_contiguous> adj,
raft::device_vector_view<int, int, raft::layout_c_contiguous> group_idxs)
RAFT_KERNEL init_adj(AdjacencyPattern pattern,
int n,
raft::device_matrix_view<bool, int, raft::layout_c_contiguous> adj,
raft::device_vector_view<int, int, raft::layout_c_contiguous> group_idxs)
{
int m = adj.extent(0);
int num_groups = adj.extent(1);
2 changes: 1 addition & 1 deletion cpp/bench/prims/sparse/convert_csr.cu
@@ -30,7 +30,7 @@ struct bench_param {
};

template <typename index_t>
__global__ void init_adj_kernel(bool* adj, index_t num_rows, index_t num_cols, index_t divisor)
RAFT_KERNEL init_adj_kernel(bool* adj, index_t num_rows, index_t num_cols, index_t divisor)
{
index_t r = blockDim.y * blockIdx.y + threadIdx.y;
index_t c = blockDim.x * blockIdx.x + threadIdx.x;
16 changes: 7 additions & 9 deletions cpp/include/raft/cluster/detail/agglomerative.cuh
@@ -155,9 +155,7 @@ void build_dendrogram_host(raft::resources const& handle,
}

template <typename value_idx>
__global__ void write_levels_kernel(const value_idx* children,
value_idx* parents,
value_idx n_vertices)
RAFT_KERNEL write_levels_kernel(const value_idx* children, value_idx* parents, value_idx n_vertices)
{
value_idx tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < n_vertices) {
@@ -179,12 +177,12 @@ __global__ void write_levels_kernel(const value_idx* children,
* @param labels
*/
template <typename value_idx>
__global__ void inherit_labels(const value_idx* children,
const value_idx* levels,
std::size_t n_leaves,
value_idx* labels,
int cut_level,
value_idx n_vertices)
RAFT_KERNEL inherit_labels(const value_idx* children,
const value_idx* levels,
std::size_t n_leaves,
value_idx* labels,
int cut_level,
value_idx n_vertices)
{
value_idx tid = blockDim.x * blockIdx.x + threadIdx.x;

2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/connectivities.cuh
@@ -107,7 +107,7 @@ struct distance_graph_impl<raft::cluster::LinkageDistance::KNN_GRAPH, value_idx,
};

template <typename value_idx>
__global__ void fill_indices2(value_idx* indices, size_t m, size_t nnz)
RAFT_KERNEL fill_indices2(value_idx* indices, size_t m, size_t nnz)
{
value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (tid >= nnz) return;
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/kmeans_balanced.cuh
@@ -434,7 +434,7 @@ template <uint32_t BlockDimY,
typename LabelT,
typename CounterT,
typename MappingOpT>
__global__ void __launch_bounds__((WarpSize * BlockDimY))
__launch_bounds__((WarpSize * BlockDimY)) RAFT_KERNEL
adjust_centers_kernel(MathT* centers, // [n_clusters, dim]
IdxT n_clusters,
IdxT dim,
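One consequence of the macro is visible in the hunk above: since RAFT_KERNEL already expands to "... __global__ void", a __launch_bounds__ qualifier now has to precede the macro rather than follow __global__ void. A minimal sketch with a hypothetical kernel name:

template <int BlockSize>
__launch_bounds__(BlockSize) RAFT_KERNEL copy_kernel(const float* in, float* out, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] = in[i]; }
}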
46 changes: 23 additions & 23 deletions cpp/include/raft/cluster/detail/kmeans_deprecated.cuh
@@ -92,12 +92,12 @@ constexpr unsigned int BSIZE_DIV_WSIZE = (BLOCK_SIZE / WARP_SIZE);
* initialized to zero.
*/
template <typename index_type_t, typename value_type_t>
static __global__ void computeDistances(index_type_t n,
index_type_t d,
index_type_t k,
const value_type_t* __restrict__ obs,
const value_type_t* __restrict__ centroids,
value_type_t* __restrict__ dists)
RAFT_KERNEL computeDistances(index_type_t n,
index_type_t d,
index_type_t k,
const value_type_t* __restrict__ obs,
const value_type_t* __restrict__ centroids,
value_type_t* __restrict__ dists)
{
// Loop index
index_type_t i;
@@ -173,11 +173,11 @@ static __global__ void computeDistances(index_type_t n,
* cluster. Entries must be initialized to zero.
*/
template <typename index_type_t, typename value_type_t>
static __global__ void minDistances(index_type_t n,
index_type_t k,
value_type_t* __restrict__ dists,
index_type_t* __restrict__ codes,
index_type_t* __restrict__ clusterSizes)
RAFT_KERNEL minDistances(index_type_t n,
index_type_t k,
value_type_t* __restrict__ dists,
index_type_t* __restrict__ codes,
index_type_t* __restrict__ clusterSizes)
{
// Loop index
index_type_t i, j;
@@ -233,11 +233,11 @@ static __global__ void minDistances(index_type_t n,
* @param code_new Index associated with new centroid.
*/
template <typename index_type_t, typename value_type_t>
static __global__ void minDistances2(index_type_t n,
value_type_t* __restrict__ dists_old,
const value_type_t* __restrict__ dists_new,
index_type_t* __restrict__ codes_old,
index_type_t code_new)
RAFT_KERNEL minDistances2(index_type_t n,
value_type_t* __restrict__ dists_old,
const value_type_t* __restrict__ dists_new,
index_type_t* __restrict__ codes_old,
index_type_t code_new)
{
// Loop index
index_type_t i = threadIdx.x + blockIdx.x * blockDim.x;
@@ -275,9 +275,9 @@ static __global__ void minDistances2(index_type_t n,
* cluster. Entries must be initialized to zero.
*/
template <typename index_type_t>
static __global__ void computeClusterSizes(index_type_t n,
const index_type_t* __restrict__ codes,
index_type_t* __restrict__ clusterSizes)
RAFT_KERNEL computeClusterSizes(index_type_t n,
const index_type_t* __restrict__ codes,
index_type_t* __restrict__ clusterSizes)
{
index_type_t i = threadIdx.x + blockIdx.x * blockDim.x;
while (i < n) {
@@ -308,10 +308,10 @@ static __global__ void computeClusterSizes(index_type_t n,
* column is the mean position of a cluster).
*/
template <typename index_type_t, typename value_type_t>
static __global__ void divideCentroids(index_type_t d,
index_type_t k,
const index_type_t* __restrict__ clusterSizes,
value_type_t* __restrict__ centroids)
RAFT_KERNEL divideCentroids(index_type_t d,
index_type_t k,
const index_type_t* __restrict__ clusterSizes,
value_type_t* __restrict__ centroids)
{
// Global indices
index_type_t gidx, gidy;
4 changes: 2 additions & 2 deletions cpp/include/raft/common/detail/scatter.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@
namespace raft::detail {

template <typename DataT, int VecLen, typename Lambda, typename IdxT>
__global__ void scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op)
RAFT_KERNEL scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op)
{
typedef TxN_t<DataT, VecLen> DataVec;
typedef TxN_t<IdxT, VecLen> IdxVec;
4 changes: 2 additions & 2 deletions cpp/include/raft/core/detail/copy.hpp
@@ -329,8 +329,8 @@ __device__ auto increment_indices(IdxType* indices,
* parameters.
*/
template <typename DstType, typename SrcType>
__global__ mdspan_copyable_with_kernel_t<DstType, SrcType> mdspan_copy_kernel(DstType dst,
SrcType src)

RAFT_KERNEL mdspan_copy_kernel(DstType dst, SrcType src)
{
using config = mdspan_copyable<true, DstType, SrcType>;

32 changes: 32 additions & 0 deletions cpp/include/raft/core/detail/macros.hpp
@@ -86,6 +86,38 @@
// as a weak symbol rather than a global."
#define RAFT_WEAK_FUNCTION __attribute__((weak))

// The RAFT_HIDDEN_FUNCTION macro specifies that the function will be hidden
// and therefore not callable by consumers of raft when raft is compiled as
// a shared library.
//
// Hidden visibility also ensures that the linker doesn't de-duplicate the
// symbol across multiple `.so` files. This allows multiple libraries to embed
// raft without issue.
#define RAFT_HIDDEN_FUNCTION __attribute__((visibility("hidden")))

// The RAFT_KERNEL macro specifies that a kernel has hidden visibility.
//
// Raft needs to ensure that its __global__ function templates have hidden
// visibility (the default is weak visibility).
//
// When kernels have weak visibility, if two dynamic libraries both contain
// identical instantiations of a RAFT template, then the linker will discard
// one of the two instantiations and use only one of them.
//
// Due to the unique requirements of how CUDA works, this de-duplication can
// lead to the wrong kernels being called (SM version being wrong), silently
// no kernel being called at all, or CUDA runtime errors being thrown.
//
// https://github.com/rapidsai/raft/issues/1722
#if defined(__CUDACC_RDC__)
#define RAFT_KERNEL RAFT_HIDDEN_FUNCTION __global__ void
#elif defined(_RAFT_HAS_CUDA)
#define RAFT_KERNEL static __global__ void
#else
#define RAFT_KERNEL static void
#endif

/**
* Some macro magic to remove optional parentheses of a macro argument.
* See https://stackoverflow.com/a/62984543
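For reference, a declaration written with the new macro, and how each branch defined above resolves it (the kernel name is illustrative, not part of the diff):

RAFT_KERNEL fill_ones(float* out, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] = 1.0f; }
}

// Under __CUDACC_RDC__ (separable compilation) this preprocesses to
//   __attribute__((visibility("hidden"))) __global__ void fill_ones(float* out, int n)
// under whole-program CUDA compilation (_RAFT_HAS_CUDA) to
//   static __global__ void fill_ones(float* out, int n)
// and in a host-only translation unit to
//   static void fill_ones(float* out, int n)
// In every case the symbol is no longer exported with default (weak) visibility,
// so two shared libraries that each embed raft keep their own copies of the kernel.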
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/compress_to_bits.cuh
@@ -35,7 +35,7 @@ namespace raft::distance::detail {
* Note: the division (`/`) is a ceilDiv.
*/
template <typename T = uint64_t, typename = std::enable_if_t<std::is_integral<T>::value>>
__global__ void compress_to_bits_kernel(
RAFT_KERNEL compress_to_bits_kernel(
raft::device_matrix_view<const bool, int, raft::layout_c_contiguous> in,
raft::device_matrix_view<T, int, raft::layout_c_contiguous> out)
{
30 changes: 15 additions & 15 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -87,7 +87,7 @@ struct MinReduceOpImpl {
};

template <typename DataT, typename OutT, typename IdxT, typename ReduceOpT>
__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp)
RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp)
{
auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x;
if (tid < m) { redOp.init(min + tid, maxVal); }
@@ -139,20 +139,20 @@ template <typename DataT,
typename KVPReduceOpT,
typename OpT,
typename FinalLambda>
__global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
DataT maxVal,
int* mutex,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
OpT distance_op,
FinalLambda fin_op)
__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedL2NNkernel(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
DataT maxVal,
int* mutex,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
OpT distance_op,
FinalLambda fin_op)
{
// compile only if below non-ampere arch.
#if __CUDA_ARCH__ < 800
10 changes: 5 additions & 5 deletions cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh
@@ -36,7 +36,7 @@ namespace raft::distance::kernels::detail {
* @param offset
*/
template <typename math_t, typename exp_t>
__global__ void polynomial_kernel_nopad(
RAFT_KERNEL polynomial_kernel_nopad(
math_t* inout, size_t len, exp_t exponent, math_t gain, math_t offset)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len;
@@ -56,7 +56,7 @@ __global__ void polynomial_kernel_nopad(
* @param offset
*/
template <typename math_t, typename exp_t>
__global__ void polynomial_kernel(
RAFT_KERNEL polynomial_kernel(
math_t* inout, int ld, int rows, int cols, exp_t exponent, math_t gain, math_t offset)
{
for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols;
@@ -75,7 +75,7 @@ __global__ void polynomial_kernel(
* @param offset
*/
template <typename math_t>
__global__ void tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t offset)
RAFT_KERNEL tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t offset)
{
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len;
tid += blockDim.x * gridDim.x) {
@@ -93,7 +93,7 @@ __global__ void tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t
* @param offset
*/
template <typename math_t>
__global__ void tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t gain, math_t offset)
RAFT_KERNEL tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t gain, math_t offset)
{
for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols;
tidy += blockDim.y * gridDim.y)
@@ -121,7 +121,7 @@ __global__ void tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t ga
* @param gain
*/
template <typename math_t>
__global__ void rbf_kernel_expanded(
RAFT_KERNEL rbf_kernel_expanded(
math_t* inout, int ld, int rows, int cols, math_t* norm_x, math_t* norm_y, math_t gain)
{
for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols;
36 changes: 18 additions & 18 deletions cpp/include/raft/distance/detail/masked_nn.cuh
@@ -40,24 +40,24 @@ template <typename DataT,
typename KVPReduceOpT,
typename CoreLambda,
typename FinalLambda>
__global__ __launch_bounds__(P::Nthreads, 2) void masked_l2_nn_kernel(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
const uint64_t* adj,
const IdxT* group_idxs,
IdxT num_groups,
IdxT m,
IdxT n,
IdxT k,
bool sqrt,
DataT maxVal,
int* mutex,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
CoreLambda core_op,
FinalLambda fin_op)
__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL masked_l2_nn_kernel(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
const uint64_t* adj,
const IdxT* group_idxs,
IdxT num_groups,
IdxT m,
IdxT n,
IdxT k,
bool sqrt,
DataT maxVal,
int* mutex,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
CoreLambda core_op,
FinalLambda fin_op)
{
extern __shared__ char smem[];

@@ -31,8 +31,8 @@ template <typename Policy,
typename DataT,
typename OutT,
typename FinOpT>
__global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(
OpT distance_op, pairwise_matrix_params<IdxT, DataT, OutT, FinOpT> params)
__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL
pairwise_matrix_kernel(OpT distance_op, pairwise_matrix_params<IdxT, DataT, OutT, FinOpT> params)
{
// Early exit to minimize the size of the kernel when it is not supposed to be compiled.
constexpr SM_compat_t sm_compat_range{};
16 changes: 8 additions & 8 deletions cpp/include/raft/label/detail/classlabels.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -119,13 +119,13 @@ void getOvrlabels(
// +/-1, return array with the new class labels and corresponding indices.

template <typename Type, int TPB_X, typename Lambda>
__global__ void map_label_kernel(Type* map_ids,
size_t N_labels,
Type* in,
Type* out,
size_t N,
Lambda filter_op,
bool zero_based = false)
RAFT_KERNEL map_label_kernel(Type* map_ids,
size_t N_labels,
Type* in,
Type* out,
size_t N,
Lambda filter_op,
bool zero_based = false)
{
int tid = threadIdx.x + blockIdx.x * TPB_X;
if (tid < N) {