From cf085586ee6713b855475773d76406f5c4affc43 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 13 Oct 2023 22:30:23 +0200 Subject: [PATCH 1/6] End-to-end vector search tutorial in docs (#1776) Closes https://github.com/rapidsai/raft/issues/1745 Authors: - Corey J. Nolet (https://github.com/cjnolet) - William Hicks (https://github.com/wphicks) - Divye Gala (https://github.com/divyegala) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1776 --- docs/source/index.rst | 7 +- .../{using_comms.rst => using_raft_comms.rst} | 0 docs/source/vector_search_tutorial.md | 343 ++++++++++++++++++ 3 files changed, 347 insertions(+), 3 deletions(-) rename docs/source/{using_comms.rst => using_raft_comms.rst} (100%) create mode 100644 docs/source/vector_search_tutorial.md diff --git a/docs/source/index.rst b/docs/source/index.rst index b5d6abbbab..ee89aed5a6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -35,14 +35,14 @@ While not exhaustive, the following general categories help summarize the accele * - Category - Examples + * - Nearest Neighbors + - pairwise distances, vector search, epsilon neighborhoods, neighborhood graph construction * - Data Formats - sparse & dense, conversions, data generation * - Dense Operations - linear algebra, matrix and vector operations, slicing, norms, factorization, least squares, svd & eigenvalue problems * - Sparse Operations - linear algebra, eigenvalue problems, slicing, norms, reductions, factorization, symmetrization, components & labeling - * - Nearest Neighbors - - pairwise distances, vector search, epsilon neighborhoods, neighborhood graph construction * - Basic Clustering - spectral clustering, hierarchical clustering, k-means * - Solvers @@ -61,9 +61,10 @@ While not exhaustive, the following general categories help summarize the accele cpp_api.rst pylibraft_api.rst using_libraft.md + vector_search_tutorial.md raft_ann_benchmarks.md raft_dask_api.rst - using_comms.rst + using_raft_comms.rst developer_guide.md contributing.md diff --git a/docs/source/using_comms.rst b/docs/source/using_raft_comms.rst similarity index 100% rename from docs/source/using_comms.rst rename to docs/source/using_raft_comms.rst diff --git a/docs/source/vector_search_tutorial.md b/docs/source/vector_search_tutorial.md new file mode 100644 index 0000000000..126ac534c6 --- /dev/null +++ b/docs/source/vector_search_tutorial.md @@ -0,0 +1,343 @@ +# Vector Search in C++ Tutorial + +RAFT has several important algorithms for performing vector search on the GPU and this tutorial walks through the primary vector search APIs from start to finish to provide a reference for quick setup and C++ API usage. + +This tutorial assumes RAFT has been installed and/or added to your build so that you are able to compile and run RAFT code. If not done already, please follow the [build and install instructions](build.md) and consider taking a look at the [example c++ template project](https://github.com/rapidsai/raft/tree/HEAD/cpp/template) for ready-to-go examples that you can immediately build and start playing with. Also take a look at RAFT's library of [reproducible vector search benchmarks](raft_ann_benchmarks.md) to run benchmarks that compare RAFT against other state-of-the-art nearest neighbors algorithms at scale. 
+

For more information about the various APIs demonstrated in this tutorial, along with comprehensive usage examples of all the APIs offered by RAFT, please refer to [RAFT's C++ API Documentation](https://docs.rapids.ai/api/raft/nightly/cpp_api/).

## Step 1: Starting off with RAFT

### CUDA Development?

If you are reading this tutorial then you probably know about CUDA and its relationship to general-purpose GPU computing (GPGPU). You probably also know about NVIDIA GPUs but might not necessarily be familiar with the programming model or with GPU computing in general. The good news is that extensive knowledge of CUDA and GPUs is not needed in order to get started with or build applications with RAFT. RAFT hides away most of the complexity behind simple, single-threaded, stateless functions that are inherently asynchronous, meaning the result of a computation isn't necessarily ready to be used when the function returns control to the user. The functions can, however, be chained together in a sequence of calls without waiting for earlier computations to complete. In fact, the only time you need to wait for a computation to complete is when you are ready to use its result.

A common structure you will encounter when using RAFT is a `raft::device_resources` object. This object is a container for important resources for a single GPU that might be needed during computation. If communicating with multiple GPUs, multiple `device_resources` objects might be needed, one for each GPU. `device_resources` contains several methods for managing its state but most commonly, you'll call `sync_stream()` to guarantee that all recently submitted computation has completed (as mentioned above).

A simple example of using `raft::device_resources` in RAFT:

```c++
#include <raft/core/device_resources.hpp>

raft::device_resources res;
// Call a bunch of RAFT functions in sequence...
res.sync_stream();
```

### Host vs Device Memory

We differentiate between two different types of memory. `host` memory is your traditional RAM, primarily accessible by applications on the CPU. `device` memory, on the other hand, is the special memory on the GPU, which is not directly accessible from the CPU. In order to access host memory from the GPU, it needs to be explicitly copied to the device, and in order to access device memory from the CPU, it needs to be explicitly copied to the host. We have several mechanisms available for allocating and managing the lifetime of device memory on the stack so that we don't need to explicitly allocate and free pointers on the heap. For example, instead of a `std::vector` for host memory, we can use `rmm::device_uvector` on the device. The following snippet will copy an array from host memory to device memory:

```c++
#include <raft/core/device_resources.hpp>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>

#include <vector>

raft::device_resources res;

std::vector<int> my_host_vector = {0, 1, 2, 3, 4};
rmm::device_uvector<int> my_device_vector(my_host_vector.size(), res.get_stream());

raft::copy(my_device_vector.data(), my_host_vector.data(), my_host_vector.size(), res.get_stream());
```

Since a stream is involved in the copy operation above, RAFT functions can be invoked immediately afterwards so long as the same `device_resources` instance is used (or, more specifically, the same main stream from the `device_resources`). As you might notice in the example above, `res.get_stream()` can be used to extract the main stream from a `device_resources` instance.
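To see the asynchronous execution model end to end, the following is a minimal sketch (continuing with `res` and `my_device_vector` from the example above) that copies the device data back to host memory and only synchronizes when the values actually need to be read on the CPU:

```c++
#include <iostream>
#include <vector>

std::vector<int> result(my_device_vector.size());

// Schedule the device-to-host copy on the main stream; it may not have
// completed by the time this call returns.
raft::copy(result.data(), my_device_vector.data(), my_device_vector.size(), res.get_stream());

// Block the CPU thread until all work submitted on the main stream is done.
res.sync_stream();

// Now it's safe to read the result on the host.
for (int v : result) {
  std::cout << v << std::endl;
}
```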
+

### Multi-dimensional data representation

`rmm::device_uvector` is a great mechanism for allocating and managing a chunk of device memory. While it's possible to use a single array to represent objects in higher dimensions like matrices, it lacks the means to pass that information along. For example, in addition to knowing that we have a 2d structure, we would need to know the number of rows, the number of columns, and even whether we read the columns or rows first (referred to as column- or row-major respectively).

For this reason, RAFT relies on the `mdspan` standard, which was composed specifically for this purpose. Better yet, `mdspan` itself doesn't actually allocate or own any data on host or device because it's just a view over existing memory on the host or device. The `mdspan` simply gives us a way to represent multi-dimensional data so we can pass along the needed metadata to our APIs. Even more powerful is that we can design functions that only accept a matrix of `float` in device memory that is laid out in row-major format.

The memory-owning counterpart to the `mdspan` is the `mdarray`. The `mdarray` can allocate memory on device or host and carry along with it the metadata about its shape and layout. An `mdspan` can be produced from an `mdarray` for invoking RAFT APIs with `mdarray.view()`. They also follow paradigms similar to the STL, where we represent an immutable `mdspan` of `int` using `mdspan<const int>` instead of `const mdspan<int>`, to ensure it's the element type carried along by the `mdspan` that's not allowed to change.

Many RAFT functions require `mdspan<const T>` to represent immutable input data, and since there's no implicit conversion between `mdspan<T>` and `mdspan<const T>`, we use `raft::make_const_mdspan()` to alleviate the pain of constructing a new `mdspan` when invoking these functions.

The following example demonstrates how to create `mdarray` matrices in both device and host memory, copy one to the other, and create mdspans out of them:

```c++
#include <raft/core/device_resources.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/copy.cuh>

raft::device_resources res;

int n_rows = 10;
int n_cols = 10;

auto device_matrix = raft::make_device_matrix<float>(res, n_rows, n_cols);
auto host_matrix = raft::make_host_matrix<float>(res, n_rows, n_cols);

// Set the diagonal to 1
for(int i = 0; i < n_rows; i++) {
  host_matrix(i, i) = 1;
}

raft::copy(res, device_matrix.view(), host_matrix.view());
```

## Step 2: Generate some data

Let's build upon the fundamentals from the prior section and actually invoke some of RAFT's computational APIs on the device. A good starting point is data generation.

```c++
#include <raft/core/device_resources.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/random/make_blobs.cuh>

raft::device_resources res;

int n_rows = 10000;
int n_cols = 10000;

auto dataset = raft::make_device_matrix<float, int>(res, n_rows, n_cols);
auto labels = raft::make_device_vector<int, int>(res, n_rows);

raft::random::make_blobs(res, dataset.view(), labels.view());
```

That's it. We've now generated a random 10kx10k matrix with points that cleanly separate into Gaussian clusters, along with a vector of cluster labels for each of the data points. Notice the `cuh` extension in the header file include for `make_blobs`. This signifies to us that this file contains CUDA device functions like kernel code, so the CUDA compiler, `nvcc`, is needed in order to compile any code that uses it. Generally, any source files that include headers with a `cuh` extension use the `.cu` extension instead of `.cpp`. The rule here is that `cpp` source files contain code which can be compiled with a C++ compiler like `g++` while `cu` files require the CUDA compiler.
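One more note on `mdspan` before moving on: because it's a non-owning view, it can also wrap device memory we've allocated ourselves, such as the `rmm::device_uvector` from earlier, and hand it to RAFT APIs without any copies. A minimal sketch, assuming a row-major layout:

```c++
#include <raft/core/device_resources.hpp>
#include <raft/core/device_mdspan.hpp>
#include <rmm/device_uvector.hpp>

raft::device_resources res;

int n_rows = 100;
int n_cols = 10;

// Memory we allocate and own ourselves...
rmm::device_uvector<float> buffer(n_rows * n_cols, res.get_stream());

// ...viewed as a row-major n_rows x n_cols matrix, without copying anything
auto matrix_view = raft::make_device_matrix_view<float, int>(buffer.data(), n_rows, n_cols);
```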
+

Since the `make_blobs` code generates the random dataset on the GPU device, we didn't need to do any host to device copies in this one. `make_blobs` is also asynchronous, so if we don't need to copy and use the data in host memory right away, we can continue calling RAFT functions with the `device_resources` instance and the data transformations will all be scheduled on the same stream.

## Step 3: Using brute-force indexes

### Build brute-force index

Consider the `(10k, 10k)` shaped random matrix we generated in the previous step. We want to be able to find the k-nearest neighbors for all points of the matrix, or what we refer to as the all-neighbors graph, which means finding the neighbors of all data points within the same matrix.

```c++
#include <raft/neighbors/brute_force.cuh>

raft::device_resources res;

// set number of neighbors to search for
int const k = 64;

auto bfknn_index = raft::neighbors::brute_force::build(res,
                                                       raft::make_const_mdspan(dataset.view()));
```

### Query brute-force index

```c++
// using matrix `dataset` from previous example
auto search = raft::make_const_mdspan(dataset.view());

// Indices and Distances are of dimensions (n, k)
// where n is number of rows in the search matrix
auto reference_indices = raft::make_device_matrix<int, int>(res, search.extent(0), k); // stores index of neighbors
auto reference_distances = raft::make_device_matrix<float, int>(res, search.extent(0), k); // stores distance to neighbors

raft::neighbors::brute_force::search(res,
                                     bfknn_index,
                                     search,
                                     reference_indices.view(),
                                     reference_distances.view());
```

We have established several things here by building a flat index. Now we know the exact 64 neighbors of all points in the matrix, and this algorithm can be generally useful in several ways:
1. Creating a baseline to compare against when building an approximate nearest neighbors index.
2. Directly using the brute-force algorithm when accuracy is more important than speed of computation. Don't worry, our implementation is still best in class and will not only provide significant speedups over other brute-force methods, but will also be relatively quick when the matrices are small!


## Step 4: Using the ANN indexes

### Build a CAGRA index

Next we'll train an ANN index. We'll use our graph-based CAGRA algorithm for this example but the other index types use a very similar pattern.

```c++
#include <raft/neighbors/cagra.cuh>

using namespace raft::neighbors;

raft::device_resources res;

// use default index parameters
cagra::index_params index_params;

auto index = cagra::build(res, index_params, raft::make_const_mdspan(dataset.view()));
```

### Query the CAGRA index

Now that we've trained a CAGRA index, we can query it by first allocating our output `mdarray` objects and passing the trained index model into the search function.

```c++
// create output arrays
auto indices = raft::make_device_matrix<uint32_t>(res, n_rows, k);
auto distances = raft::make_device_matrix<float>(res, n_rows, k);

// use default search parameters
cagra::search_params search_params;

// search K nearest neighbors
cagra::search(res, search_params, index, search, indices.view(), distances.view());
```

## Step 5: Evaluate neighborhood quality

In step 3 we built a flat index and queried for exact neighbors, while in step 4 we built an ANN index and queried for approximate neighbors. How do you quickly figure out the quality of the approximate neighbors and whether it's in an acceptable range based on your needs? Just compute the `neighborhood_recall`, which gives a single value in the range [0, 1].
The closer the value is to 1, the higher the quality of the approximation.

```c++
#include <raft/stats/neighborhood_recall.cuh>

raft::device_resources res;

// Assuming the `mdarray` outputs from the previous steps:
// indices : approximate neighbor indices
// reference_indices : exact neighbor indices
// distances : approximate neighbor distances
// reference_distances : exact neighbor distances

// We want our `neighborhood_recall` value in host memory
float const recall_scalar = 0.0;
auto recall_value = raft::make_host_scalar(recall_scalar);

raft::stats::neighborhood_recall(res,
                                 raft::make_const_mdspan(indices.view()),
                                 raft::make_const_mdspan(reference_indices.view()),
                                 recall_value.view(),
                                 raft::make_const_mdspan(distances.view()),
                                 raft::make_const_mdspan(reference_distances.view()));

res.sync_stream();
```

Notice we can invoke the functions for index build and search for both algorithms, one right after the other, because we don't need to access any outputs from the algorithms in host memory. We will need to synchronize the stream on the `raft::device_resources` instance before we can read the result of the `neighborhood_recall` computation, though.

Similar to a NumPy array, when we use a `host_scalar`, we are really using a multi-dimensional structure that contains only a single dimension and, further, a single element. We can use element indexing to access the resulting element directly.

```c++
std::cout << recall_value(0) << std::endl;
```

While it may seem like unnecessary additional work to wrap the result in a `host_scalar` mdspan, this API choice is made intentionally to support the possibility of also receiving the result as a `device_scalar` so that it can be used directly on the device for follow-on computations without having to incur the synchronization or transfer cost of bringing the result to host. This pattern becomes even more important when the result is being computed in a loop, such as an iterative solver, and the cost of synchronization and device-to-host (d2h) transfer becomes very expensive.

## Advanced features

The following sections present some advanced features that we have found can be useful for squeezing more utilization out of GPU hardware. As you've seen in this tutorial, RAFT provides several very useful tools and building blocks for developing accelerated applications beyond vector search capabilities.

### Stream pools

Within each CPU thread, CUDA uses `streams` to submit asynchronous work. You can think of a stream as a queue. Each stream can submit work to the GPU independently of other streams, but work submitted within each stream is queued and executed in the order in which it was submitted. Similar to how we can use thread pools to bound the parallelism of CPU threads, we can use CUDA stream pools to bound the amount of concurrent asynchronous work that can be scheduled on a GPU. Each instance of `device_resources` has a main stream, but can also create a stream pool. For a single CPU thread, multiple different instances of `device_resources` can be created with different main streams and used to invoke a series of RAFT functions concurrently on the same or different GPU devices, so long as the target devices have available resources to perform the work. Once a device is saturated, work queued on the streams will wait for a chance to execute.
While the streams are waiting, the CPU thread will still continue its own execution asynchronously unless `sync_stream_pool()` is called, which blocks the thread until all work submitted to the stream pool has completed.

Also, beware that before splitting GPU work onto multiple different concurrent streams, it is often important to wait for the main stream in the `device_resources`. This can be done with `wait_stream_pool_on_stream()`.

To summarize, if wanting to execute multiple different streams in parallel, we would often use a stream pool like this:

```c++
#include <raft/core/device_resources.hpp>

#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_pool.hpp>

#include <memory>

int n_streams = 5;

rmm::cuda_stream stream;
auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(n_streams);
raft::device_resources res(stream.view(), stream_pool);

// Submit some work on the main stream...

res.wait_stream_pool_on_stream();
for(int i = 0; i < n_streams; ++i) {
  rmm::cuda_stream_view stream_from_pool = res.get_next_usable_stream();
  raft::device_resources pool_res(stream_from_pool);
  // Submit some work with pool_res...
}

res.sync_stream_pool();
```

### Device resources manager

In multi-threaded applications, it is often useful to create a set of `raft::device_resources` objects on startup to avoid the overhead of re-initializing underlying resources every time a `raft::device_resources` object is needed. To help simplify this common initialization logic, RAFT provides a `raft::device_resources_manager` to handle this for downstream applications. On startup, the application can specify certain limits on the total resource consumption of the `raft::device_resources` objects that will be generated:

```c++
#include <raft/core/device_resources_manager.hpp>

void initialize_application() {
  // Set the total number of CUDA streams to use on each GPU across all CPU
  // threads. If this method is not called, the default stream per thread
  // will be used.
  raft::device_resources_manager::set_streams_per_device(16);

  // Create a memory pool with given max size in bytes. Passing std::nullopt will allow
  // the pool to grow to the available memory of the device.
  raft::device_resources_manager::set_max_mem_pool_size(std::nullopt);

  // Set the initial size of the memory pool in bytes.
  raft::device_resources_manager::set_init_mem_pool_size(16000000);

  // If neither of the above two methods is called, no memory pool will be used
}
```

While this example shows some commonly used settings, `raft::device_resources_manager` provides support for several other resource options and constraints, including options to initialize entire stream pools that can be used by an individual `raft::device_resources` object. After this initialization method is called, the following function could be called from any CPU thread:

```c++
void foo() {
  raft::device_resources const& res = raft::device_resources_manager::get_device_resources();
  // Submit some work with res
  res.sync_stream();
}
```

If any `raft::device_resources_manager` setters are called _after_ the first call to `raft::device_resources_manager::get_device_resources()`, these new settings are ignored, and a warning will be logged. If a thread calls `raft::device_resources_manager::get_device_resources()` multiple times, it is guaranteed to access the same underlying `raft::device_resources` object every time. This can be useful for chaining work in different calls on the same thread without keeping a persistent reference to the resources object.
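Putting this together, here is a minimal multi-threaded sketch, reusing the `initialize_application` function from the example above and with a hypothetical `run_query` function standing in for real work, of several CPU threads sharing the pre-initialized resources:

```c++
#include <raft/core/device_resources_manager.hpp>

#include <thread>
#include <vector>

// Hypothetical per-thread workload
void run_query() {
  // Each thread obtains a raft::device_resources object from the manager;
  // repeated calls from the same thread return the same underlying object.
  raft::device_resources const& res = raft::device_resources_manager::get_device_resources();
  // Submit some work with res...
  res.sync_stream();
}

int main() {
  initialize_application();

  // Fan the work out over several CPU threads, each drawing its
  // resources from the manager rather than constructing its own.
  std::vector<std::thread> workers;
  for (int i = 0; i < 8; ++i) {
    workers.emplace_back(run_query);
  }
  for (auto& worker : workers) {
    worker.join();
  }
}
```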
+

### Device memory resources

The RAPIDS software ecosystem makes heavy use of the [RAPIDS Memory Manager](https://github.com/rapidsai/rmm) (RMM) to enable zero-copy sharing of device memory across various GPU-enabled libraries such as PyTorch, JAX, TensorFlow, and FAISS. A really powerful feature of RMM is the ability to set a memory resource, such as a pooled memory resource that allocates a block of memory up front to speed up subsequent smaller allocations, and have all the libraries in the GPU ecosystem recognize and use that same memory resource for all of their memory allocations.

As an example, the following code snippet creates a `pool_memory_resource` and sets it as the default memory resource, which means all other libraries that use RMM will now allocate their device memory from this same pool:

```c++
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

rmm::mr::cuda_memory_resource cuda_mr;
// Construct a resource that uses a coalescing best-fit pool allocator
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{&cuda_mr};
rmm::mr::set_current_device_resource(&pool_mr); // Updates the current device resource pointer to `pool_mr`
```

The `raft::device_resources` object will now also use the `rmm::current_device_resource`. This isn't limited to C++, however. Often a user will be interacting with PyTorch, RAPIDS, or TensorFlow through Python, and so they can set and use RMM's `current_device_resource` [right in Python](https://github.com/rapidsai/rmm#using-rmm-in-python-code).

### Workspace memory resource

As mentioned above, `raft::device_resources` will use `rmm::current_device_resource` by default for all memory allocations. However, there are times when a particular algorithm might benefit from using a different memory resource such as a `managed_memory_resource`, which creates a unified memory space between device and host memory, paging memory in and out of the device as needed. Most of RAFT's algorithms allocate temporary memory as needed to perform their computations and we can control the memory resource used for these temporary allocations through the `workspace_resource` in the `raft::device_resources` instance.

For some applications, the `managed_memory_resource` can enable a memory space that is larger than the GPU, thus allowing a natural spilling to host memory when needed. This isn't always the best way to use managed memory, though, as it can quickly lead to thrashing and severely impact performance. Still, when it can be used, it provides a very powerful tool that can also avoid out-of-memory errors when enough host memory is available.

The following creates a managed memory resource and sets it as the `workspace_resource` of the `raft::device_resources` instance:

```c++
#include <raft/core/device_resources.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>

auto managed_resource = std::make_shared<rmm::mr::managed_memory_resource>();
// constructor arguments: stream, stream pool, workspace resource
raft::device_resources res(rmm::cuda_stream_per_thread, {nullptr}, managed_resource);
```

The `workspace_resource` uses an `rmm::mr::limiting_resource_adaptor`, which limits the total amount of allocation possible. This allows RAFT algorithms to work within the confines of the memory constraints imposed by the user so that things like batch sizes can be automatically set to reasonable values without exceeding the allotted memory. By default, this limit restricts the memory allocation space for temporary workspace buffers to the memory available on the device.
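RAFT wires this adaptor up internally, but as a standalone illustration of the mechanism, the following sketch (with a hypothetical 1GB cap) shows how `rmm::mr::limiting_resource_adaptor` wraps an upstream resource and refuses allocations beyond a configured limit:

```c++
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/limiting_resource_adaptor.hpp>

rmm::mr::cuda_memory_resource base_mr;

// Allocations are forwarded to `base_mr` until the 1GB cap is reached,
// after which further allocation requests fail with an exception.
rmm::mr::limiting_resource_adaptor<rmm::mr::cuda_memory_resource> limited_mr{
  &base_mr, 1024 * 1024 * 1024};
```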
+

The example below limits the total number of bytes that RAFT can use for temporary workspace allocations to 3GB:

```c++
#include <raft/core/device_resources.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>

#include <optional>

auto managed_resource = std::make_shared<rmm::mr::managed_memory_resource>();
// constructor arguments: stream, stream pool, workspace resource, allocation limit
raft::device_resources res(rmm::cuda_stream_per_thread,
                           {nullptr},
                           managed_resource,
                           std::make_optional<std::size_t>(3 * 1024 * 1024 * 1024ull));
```
\ No newline at end of file

From 27dcf7bd9be6e2eb3f0948e16cb98c6012e0aaf1 Mon Sep 17 00:00:00 2001
From: Robert Maynard
Date: Fri, 13 Oct 2023 18:17:18 -0400
Subject: [PATCH 2/6] Make all cuda kernels have hidden visibility (#1898)

Effect on binary size of libraft.a
23.12: 133361630
pr: 129748904

Effect on binary size of libraft.so
23.12: 83603224
pr: 83873088

Authors:
  - Robert Maynard (https://github.com/robertmaynard)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: https://github.com/rapidsai/raft/pull/1898
---
 cpp/bench/prims/distance/masked_nn.cu | 8 +-
 cpp/bench/prims/sparse/convert_csr.cu | 2 +-
 .../raft/cluster/detail/agglomerative.cuh | 16 +-
 .../raft/cluster/detail/connectivities.cuh | 2 +-
 .../raft/cluster/detail/kmeans_balanced.cuh | 2 +-
 .../raft/cluster/detail/kmeans_deprecated.cuh | 46 ++---
 cpp/include/raft/common/detail/scatter.cuh | 4 +-
 cpp/include/raft/core/detail/copy.hpp | 4 +-
 cpp/include/raft/core/detail/macros.hpp | 32 ++++
 .../raft/distance/detail/compress_to_bits.cuh | 2 +-
 .../raft/distance/detail/fused_l2_nn.cuh | 30 ++--
 .../detail/kernels/kernel_matrices.cuh | 10 +-
 .../raft/distance/detail/masked_nn.cuh | 36 ++--
 .../detail/pairwise_matrix/kernel_sm60.cuh | 4 +-
 cpp/include/raft/label/detail/classlabels.cuh | 16 +-
 .../raft/label/detail/merge_labels.cuh | 26 ++-
 cpp/include/raft/linalg/detail/add.cuh | 10 +-
 .../linalg/detail/coalesced_reduction-inl.cuh | 22 +--
 cpp/include/raft/linalg/detail/map.cuh | 2 +-
 .../raft/linalg/detail/map_then_reduce.cuh | 14 +-
 cpp/include/raft/linalg/detail/normalize.cuh | 22 +--
 .../raft/linalg/detail/reduce_cols_by_key.cuh | 6 +-
 .../raft/linalg/detail/reduce_rows_by_key.cuh | 52 +++---
 .../raft/linalg/detail/strided_reduction.cuh | 16 +-
 cpp/include/raft/linalg/detail/subtract.cuh | 10 +-
 .../raft/matrix/detail/columnWiseSort.cuh | 30 ++--
 cpp/include/raft/matrix/detail/gather.cuh | 16 +-
 .../raft/matrix/detail/linewise_op.cuh | 66 ++++---
 cpp/include/raft/matrix/detail/math.cuh | 4 +-
 cpp/include/raft/matrix/detail/matrix.cuh | 11 +-
 .../raft/matrix/detail/select_radix.cuh | 70 ++++----
 .../raft/matrix/detail/select_warpsort.cuh | 8 +-
 .../neighbors/detail/cagra/graph_core.cuh | 40 ++---
 .../cagra/search_multi_cta_kernel-inl.cuh | 12 +-
 .../detail/cagra/search_multi_kernel.cuh | 85 +++++----
 .../cagra/search_single_cta_kernel-inl.cuh | 50 +++---
 .../detail/cagra/topk_for_cagra/topk_core.cuh | 30 ++--
 .../raft/neighbors/detail/ivf_flat_build.cuh | 33 ++--
 .../detail/ivf_flat_interleaved_scan-inl.cuh | 2 +-
 .../raft/neighbors/detail/ivf_pq_build.cuh | 18 +-
 .../detail/ivf_pq_compute_similarity-ext.cuh | 42 ++---
 .../detail/ivf_pq_compute_similarity-inl.cuh | 42 ++---
 .../raft/neighbors/detail/ivf_pq_search.cuh | 30 ++--
 .../raft/neighbors/detail/knn_merge_parts.cuh | 20 +--
 .../raft/neighbors/detail/nn_descent.cuh | 20 +--
 .../neighbors/detail/selection_faiss-inl.cuh | 18 +-
 cpp/include/raft/random/detail/make_blobs.cuh | 20 +--
 .../raft/random/detail/make_regression.cuh | 4 +-
 .../random/detail/multi_variable_gaussian.cuh | 4 +-
 cpp/include/raft/random/detail/permute.cuh | 4 +-
 .../detail/rmat_rectangular_generator.cuh | 40 ++---
 cpp/include/raft/random/detail/rng_device.cuh | 17 +-
 cpp/include/raft/random/detail/rng_impl.cuh | 4 +-
.../raft/solver/detail/lap_kernels.cuh | 166 +++++++++--------- .../raft/sparse/convert/detail/adj_to_csr.cuh | 2 +- .../raft/sparse/convert/detail/coo.cuh | 10 +- .../raft/sparse/convert/detail/dense.cuh | 4 +- cpp/include/raft/sparse/detail/csr.cuh | 25 ++- cpp/include/raft/sparse/detail/utils.h | 2 +- .../sparse/distance/detail/bin_distance.cuh | 20 +-- .../distance/detail/coo_spmv_kernel.cuh | 40 ++--- .../coo_mask_row_iterators.cuh | 6 +- .../sparse/distance/detail/l2_distance.cuh | 44 ++--- cpp/include/raft/sparse/linalg/detail/add.cuh | 46 ++--- .../raft/sparse/linalg/detail/degree.cuh | 8 +- .../raft/sparse/linalg/detail/norm.cuh | 4 +- .../raft/sparse/linalg/detail/symmetrize.cuh | 56 +++--- cpp/include/raft/sparse/linalg/symmetrize.cuh | 36 ++-- .../neighbors/detail/cross_component_nn.cuh | 14 +- .../sparse/neighbors/detail/knn_graph.cuh | 4 +- cpp/include/raft/sparse/op/detail/filter.cuh | 24 +-- cpp/include/raft/sparse/op/detail/reduce.cuh | 24 +-- cpp/include/raft/sparse/op/detail/row_op.cuh | 4 +- .../raft/sparse/solver/detail/mst_kernels.cuh | 108 ++++++------ .../raft/spatial/knn/detail/ann_utils.cuh | 16 +- .../knn/detail/ball_cover/registers-inl.cuh | 86 ++++----- .../knn/detail/epsilon_neighborhood.cuh | 6 +- .../spatial/knn/detail/fused_l2_knn-inl.cuh | 32 ++-- .../spatial/knn/detail/haversine_distance.cuh | 12 +- .../raft/spectral/detail/matrix_wrappers.hpp | 10 +- .../raft/spectral/detail/spectral_util.cuh | 2 +- .../stats/detail/batched/silhouette_score.cuh | 30 ++-- .../raft/stats/detail/contingencyMatrix.cuh | 26 +-- cpp/include/raft/stats/detail/dispersion.cuh | 16 +- cpp/include/raft/stats/detail/histogram.cuh | 24 ++- cpp/include/raft/stats/detail/mean.cuh | 6 +- cpp/include/raft/stats/detail/meanvar.cuh | 8 +- cpp/include/raft/stats/detail/minmax.cuh | 28 +-- .../raft/stats/detail/mutual_info_score.cuh | 14 +- .../raft/stats/detail/neighborhood_recall.cuh | 2 +- cpp/include/raft/stats/detail/rand_index.cuh | 4 +- cpp/include/raft/stats/detail/scores.cuh | 4 +- .../raft/stats/detail/silhouette_score.cuh | 14 +- cpp/include/raft/stats/detail/stddev.cuh | 8 +- cpp/include/raft/stats/detail/sum.cuh | 6 +- .../stats/detail/trustworthiness_score.cuh | 14 +- .../thirdparty/mdspan/tests/offload_utils.hpp | 2 +- cpp/include/raft/util/cache_util.cuh | 53 +++--- cpp/include/raft/util/detail/scatter.cuh | 4 +- .../raft_internal/neighbors/naive_knn.cuh | 14 +- cpp/test/cluster/linkage.cu | 2 +- cpp/test/core/interruptible.cu | 3 +- cpp/test/core/math_device.cu | 2 +- cpp/test/core/operators_device.cu | 4 +- cpp/test/core/span.cu | 6 +- cpp/test/distance/dist_adj.cu | 16 +- cpp/test/distance/distance_base.cuh | 64 +++---- cpp/test/distance/fused_l2_nn.cu | 16 +- cpp/test/distance/masked_nn.cu | 33 ++-- .../distance/masked_nn_compress_to_bits.cu | 2 +- cpp/test/linalg/add.cuh | 4 +- cpp/test/linalg/axpy.cu | 2 +- cpp/test/linalg/binary_op.cuh | 4 +- cpp/test/linalg/divide.cu | 2 +- cpp/test/linalg/dot.cu | 2 +- cpp/test/linalg/eltwise.cu | 4 +- cpp/test/linalg/gemm_layout.cu | 2 +- cpp/test/linalg/gemv.cu | 14 +- cpp/test/linalg/map_then_reduce.cu | 2 +- cpp/test/linalg/matrix_vector_op.cuh | 36 ++-- cpp/test/linalg/mean_squared_error.cu | 2 +- cpp/test/linalg/norm.cu | 4 +- cpp/test/linalg/power.cu | 4 +- cpp/test/linalg/reduce.cuh | 38 ++-- cpp/test/linalg/reduce_rows_by_key.cu | 18 +- cpp/test/linalg/sqrt.cu | 2 +- cpp/test/linalg/subtract.cu | 4 +- cpp/test/linalg/unary_op.cuh | 2 +- cpp/test/matrix/math.cu | 6 +- cpp/test/neighbors/ann_cagra.cuh | 6 +- 
cpp/test/neighbors/ball_cover.cu | 16 +- cpp/test/neighbors/knn.cu | 4 +- cpp/test/random/make_blobs.cu | 18 +- cpp/test/random/multi_variable_gaussian.cu | 6 +- cpp/test/random/rmat_rectangular_generator.cu | 8 +- cpp/test/random/rng.cu | 2 +- cpp/test/random/rng_int.cu | 2 +- cpp/test/random/rng_pcg_host_api.cu | 10 +- cpp/test/sparse/convert_csr.cu | 2 +- cpp/test/sparse/neighbors/knn_graph.cu | 2 +- cpp/test/sparse/symmetrize.cu | 2 +- cpp/test/stats/histogram.cu | 2 +- cpp/test/stats/minmax.cu | 6 +- cpp/test/util/bitonic_sort.cu | 2 +- cpp/test/util/device_atomics.cu | 2 +- cpp/test/util/integer_utils.cu | 12 +- cpp/test/util/reduction.cu | 14 +- 147 files changed, 1322 insertions(+), 1316 deletions(-) diff --git a/cpp/bench/prims/distance/masked_nn.cu b/cpp/bench/prims/distance/masked_nn.cu index c804ecb3a1..19d78f4cd9 100644 --- a/cpp/bench/prims/distance/masked_nn.cu +++ b/cpp/bench/prims/distance/masked_nn.cu @@ -46,10 +46,10 @@ struct Params { AdjacencyPattern pattern; }; // struct Params -__global__ void init_adj(AdjacencyPattern pattern, - int n, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs) +RAFT_KERNEL init_adj(AdjacencyPattern pattern, + int n, + raft::device_matrix_view adj, + raft::device_vector_view group_idxs) { int m = adj.extent(0); int num_groups = adj.extent(1); diff --git a/cpp/bench/prims/sparse/convert_csr.cu b/cpp/bench/prims/sparse/convert_csr.cu index c9dcae6985..634c749a54 100644 --- a/cpp/bench/prims/sparse/convert_csr.cu +++ b/cpp/bench/prims/sparse/convert_csr.cu @@ -30,7 +30,7 @@ struct bench_param { }; template -__global__ void init_adj_kernel(bool* adj, index_t num_rows, index_t num_cols, index_t divisor) +RAFT_KERNEL init_adj_kernel(bool* adj, index_t num_rows, index_t num_cols, index_t divisor) { index_t r = blockDim.y * blockIdx.y + threadIdx.y; index_t c = blockDim.x * blockIdx.x + threadIdx.x; diff --git a/cpp/include/raft/cluster/detail/agglomerative.cuh b/cpp/include/raft/cluster/detail/agglomerative.cuh index 624e67b7fa..f2c83abdd3 100644 --- a/cpp/include/raft/cluster/detail/agglomerative.cuh +++ b/cpp/include/raft/cluster/detail/agglomerative.cuh @@ -155,9 +155,7 @@ void build_dendrogram_host(raft::resources const& handle, } template -__global__ void write_levels_kernel(const value_idx* children, - value_idx* parents, - value_idx n_vertices) +RAFT_KERNEL write_levels_kernel(const value_idx* children, value_idx* parents, value_idx n_vertices) { value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < n_vertices) { @@ -179,12 +177,12 @@ __global__ void write_levels_kernel(const value_idx* children, * @param labels */ template -__global__ void inherit_labels(const value_idx* children, - const value_idx* levels, - std::size_t n_leaves, - value_idx* labels, - int cut_level, - value_idx n_vertices) +RAFT_KERNEL inherit_labels(const value_idx* children, + const value_idx* levels, + std::size_t n_leaves, + value_idx* labels, + int cut_level, + value_idx n_vertices) { value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; diff --git a/cpp/include/raft/cluster/detail/connectivities.cuh b/cpp/include/raft/cluster/detail/connectivities.cuh index ef046ab4ff..49ac6ae704 100644 --- a/cpp/include/raft/cluster/detail/connectivities.cuh +++ b/cpp/include/raft/cluster/detail/connectivities.cuh @@ -107,7 +107,7 @@ struct distance_graph_impl -__global__ void fill_indices2(value_idx* indices, size_t m, size_t nnz) +RAFT_KERNEL fill_indices2(value_idx* indices, size_t m, size_t nnz) { value_idx tid = (blockIdx.x * blockDim.x) + 
threadIdx.x; if (tid >= nnz) return; diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh index ade3a6e348..593d7d8fa9 100644 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -434,7 +434,7 @@ template -__global__ void __launch_bounds__((WarpSize * BlockDimY)) +__launch_bounds__((WarpSize * BlockDimY)) RAFT_KERNEL adjust_centers_kernel(MathT* centers, // [n_clusters, dim] IdxT n_clusters, IdxT dim, diff --git a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh index 5a1479a81f..0b5dec4e19 100644 --- a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh @@ -92,12 +92,12 @@ constexpr unsigned int BSIZE_DIV_WSIZE = (BLOCK_SIZE / WARP_SIZE); * initialized to zero. */ template -static __global__ void computeDistances(index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const value_type_t* __restrict__ centroids, - value_type_t* __restrict__ dists) +RAFT_KERNEL computeDistances(index_type_t n, + index_type_t d, + index_type_t k, + const value_type_t* __restrict__ obs, + const value_type_t* __restrict__ centroids, + value_type_t* __restrict__ dists) { // Loop index index_type_t i; @@ -173,11 +173,11 @@ static __global__ void computeDistances(index_type_t n, * cluster. Entries must be initialized to zero. */ template -static __global__ void minDistances(index_type_t n, - index_type_t k, - value_type_t* __restrict__ dists, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes) +RAFT_KERNEL minDistances(index_type_t n, + index_type_t k, + value_type_t* __restrict__ dists, + index_type_t* __restrict__ codes, + index_type_t* __restrict__ clusterSizes) { // Loop index index_type_t i, j; @@ -233,11 +233,11 @@ static __global__ void minDistances(index_type_t n, * @param code_new Index associated with new centroid. */ template -static __global__ void minDistances2(index_type_t n, - value_type_t* __restrict__ dists_old, - const value_type_t* __restrict__ dists_new, - index_type_t* __restrict__ codes_old, - index_type_t code_new) +RAFT_KERNEL minDistances2(index_type_t n, + value_type_t* __restrict__ dists_old, + const value_type_t* __restrict__ dists_new, + index_type_t* __restrict__ codes_old, + index_type_t code_new) { // Loop index index_type_t i = threadIdx.x + blockIdx.x * blockDim.x; @@ -275,9 +275,9 @@ static __global__ void minDistances2(index_type_t n, * cluster. Entries must be initialized to zero. */ template -static __global__ void computeClusterSizes(index_type_t n, - const index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes) +RAFT_KERNEL computeClusterSizes(index_type_t n, + const index_type_t* __restrict__ codes, + index_type_t* __restrict__ clusterSizes) { index_type_t i = threadIdx.x + blockIdx.x * blockDim.x; while (i < n) { @@ -308,10 +308,10 @@ static __global__ void computeClusterSizes(index_type_t n, * column is the mean position of a cluster). 
*/
 template 
-static __global__ void divideCentroids(index_type_t d,
-                                       index_type_t k,
-                                       const index_type_t* __restrict__ clusterSizes,
-                                       value_type_t* __restrict__ centroids)
+RAFT_KERNEL divideCentroids(index_type_t d,
+                            index_type_t k,
+                            const index_type_t* __restrict__ clusterSizes,
+                            value_type_t* __restrict__ centroids)
 {
   // Global indices
   index_type_t gidx, gidy;
diff --git a/cpp/include/raft/common/detail/scatter.cuh b/cpp/include/raft/common/detail/scatter.cuh
index 87a8826aa6..6e7522853e 100644
--- a/cpp/include/raft/common/detail/scatter.cuh
+++ b/cpp/include/raft/common/detail/scatter.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ namespace raft::detail {
 
 template 
-__global__ void scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op)
+RAFT_KERNEL scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op)
 {
   typedef TxN_t DataVec;
   typedef TxN_t IdxVec;
diff --git a/cpp/include/raft/core/detail/copy.hpp b/cpp/include/raft/core/detail/copy.hpp
index b23660fefe..dd50f47786 100644
--- a/cpp/include/raft/core/detail/copy.hpp
+++ b/cpp/include/raft/core/detail/copy.hpp
@@ -329,8 +329,8 @@ __device__ auto increment_indices(IdxType* indices,
  * parameters.
  */
 template 
-__global__ mdspan_copyable_with_kernel_t mdspan_copy_kernel(DstType dst,
-                                                            SrcType src)
+
+RAFT_KERNEL mdspan_copy_kernel(DstType dst, SrcType src)
 {
   using config = mdspan_copyable;
diff --git a/cpp/include/raft/core/detail/macros.hpp b/cpp/include/raft/core/detail/macros.hpp
index bb4207938b..364914043e 100644
--- a/cpp/include/raft/core/detail/macros.hpp
+++ b/cpp/include/raft/core/detail/macros.hpp
@@ -86,6 +86,38 @@
 // as a weak symbol rather than a global."
 #define RAFT_WEAK_FUNCTION __attribute__((weak))
 
+// The RAFT_HIDDEN_FUNCTION specifies that the function will be hidden
+// and therefore not callable by consumers of raft when compiled as
+// a shared library.
+//
+// Hidden visibility also ensures that the linker doesn't de-duplicate the
+// symbol across multiple `.so`. This allows multiple libraries to embed raft
+// without issue
+#define RAFT_HIDDEN_FUNCTION __attribute__((visibility("hidden")))
+
+// The RAFT_KERNEL specifies that a kernel has hidden visibility
+//
+// Raft needs to ensure that the visibility of its __global__ function
+// templates have hidden visibility ( default is weak visibility).
+//
+// When kernels have weak visibility it means that if two dynamic libraries
+// both contain identical instantiations of a RAFT template, then the linker
+// will discard one of the two instantiations and use only one of them.
+//
+// Due to unique requirements of how CUDA works, this de-duplication
+// can lead to the wrong kernels being called ( SM version being wrong ),
+// silently no kernel being called at all, or cuda runtime errors being
+// thrown.
+//
+// https://github.com/rapidsai/raft/issues/1722
+#if defined(__CUDACC_RDC__)
+#define RAFT_KERNEL RAFT_HIDDEN_FUNCTION __global__ void
+#elif defined(_RAFT_HAS_CUDA)
+#define RAFT_KERNEL static __global__ void
+#else
+#define RAFT_KERNEL static void
+#endif
+
 /**
  * Some macro magic to remove optional parentheses of a macro argument.
* See https://stackoverflow.com/a/62984543 diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh index fa0df25461..5ffb717c42 100644 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -35,7 +35,7 @@ namespace raft::distance::detail { * Note: the division (`/`) is a ceilDiv. */ template ::value>> -__global__ void compress_to_bits_kernel( +RAFT_KERNEL compress_to_bits_kernel( raft::device_matrix_view in, raft::device_matrix_view out) { diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index f0f12acdb1..2468dcd740 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -87,7 +87,7 @@ struct MinReduceOpImpl { }; template -__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; if (tid < m) { redOp.init(min + tid, maxVal); } @@ -139,20 +139,20 @@ template -__global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - OpT distance_op, - FinalLambda fin_op) +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedL2NNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + OpT distance_op, + FinalLambda fin_op) { // compile only if below non-ampere arch. 
#if __CUDA_ARCH__ < 800 diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index f02e29c797..8d5b2c766e 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -36,7 +36,7 @@ namespace raft::distance::kernels::detail { * @param offset */ template -__global__ void polynomial_kernel_nopad( +RAFT_KERNEL polynomial_kernel_nopad( math_t* inout, size_t len, exp_t exponent, math_t gain, math_t offset) { for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; @@ -56,7 +56,7 @@ __global__ void polynomial_kernel_nopad( * @param offset */ template -__global__ void polynomial_kernel( +RAFT_KERNEL polynomial_kernel( math_t* inout, int ld, int rows, int cols, exp_t exponent, math_t gain, math_t offset) { for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; @@ -75,7 +75,7 @@ __global__ void polynomial_kernel( * @param offset */ template -__global__ void tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t offset) +RAFT_KERNEL tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t offset) { for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; tid += blockDim.x * gridDim.x) { @@ -93,7 +93,7 @@ __global__ void tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t * @param offset */ template -__global__ void tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t gain, math_t offset) +RAFT_KERNEL tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t gain, math_t offset) { for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; tidy += blockDim.y * gridDim.y) @@ -121,7 +121,7 @@ __global__ void tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t ga * @param gain */ template -__global__ void rbf_kernel_expanded( +RAFT_KERNEL rbf_kernel_expanded( math_t* inout, int ld, int rows, int cols, math_t* norm_x, math_t* norm_y, math_t gain) { for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 0e13783c19..4de9f4764a 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -40,24 +40,24 @@ template -__global__ __launch_bounds__(P::Nthreads, 2) void masked_l2_nn_kernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const uint64_t* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, - bool sqrt, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - CoreLambda core_op, - FinalLambda fin_op) +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL masked_l2_nn_kernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const uint64_t* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + bool sqrt, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + CoreLambda core_op, + FinalLambda fin_op) { extern __shared__ char smem[]; diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 2d0a98862e..5393bf7389 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -31,8 +31,8 @@ template -__global__ 
__launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( - OpT distance_op, pairwise_matrix_params params) +__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL + pairwise_matrix_kernel(OpT distance_op, pairwise_matrix_params params) { // Early exit to minimize the size of the kernel when it is not supposed to be compiled. constexpr SM_compat_t sm_compat_range{}; diff --git a/cpp/include/raft/label/detail/classlabels.cuh b/cpp/include/raft/label/detail/classlabels.cuh index 64d8b4bfae..6e432e050c 100644 --- a/cpp/include/raft/label/detail/classlabels.cuh +++ b/cpp/include/raft/label/detail/classlabels.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -119,13 +119,13 @@ void getOvrlabels( // +/-1, return array with the new class labels and corresponding indices. template -__global__ void map_label_kernel(Type* map_ids, - size_t N_labels, - Type* in, - Type* out, - size_t N, - Lambda filter_op, - bool zero_based = false) +RAFT_KERNEL map_label_kernel(Type* map_ids, + size_t N_labels, + Type* in, + Type* out, + size_t N, + Lambda filter_op, + bool zero_based = false) { int tid = threadIdx.x + blockIdx.x * TPB_X; if (tid < N) { diff --git a/cpp/include/raft/label/detail/merge_labels.cuh b/cpp/include/raft/label/detail/merge_labels.cuh index f93a97d52b..166bb2122a 100644 --- a/cpp/include/raft/label/detail/merge_labels.cuh +++ b/cpp/include/raft/label/detail/merge_labels.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,13 +32,12 @@ namespace detail { * For an additional cost we can build the graph with edges * E={(A[i], B[i]) | M[i]=1} and make this step faster */ template -__global__ void __launch_bounds__(TPB_X) - propagate_label_kernel(const value_idx* __restrict__ labels_a, - const value_idx* __restrict__ labels_b, - value_idx* __restrict__ R, - const bool* __restrict__ mask, - bool* __restrict__ m, - value_idx N) +RAFT_KERNEL __launch_bounds__(TPB_X) propagate_label_kernel(const value_idx* __restrict__ labels_a, + const value_idx* __restrict__ labels_b, + value_idx* __restrict__ R, + const bool* __restrict__ mask, + bool* __restrict__ m, + value_idx N) { value_idx tid = threadIdx.x + blockIdx.x * TPB_X; if (tid < N) { @@ -65,12 +64,11 @@ __global__ void __launch_bounds__(TPB_X) } template -__global__ void __launch_bounds__(TPB_X) - reassign_label_kernel(value_idx* __restrict__ labels_a, - const value_idx* __restrict__ labels_b, - const value_idx* __restrict__ R, - value_idx N, - value_idx MAX_LABEL) +RAFT_KERNEL __launch_bounds__(TPB_X) reassign_label_kernel(value_idx* __restrict__ labels_a, + const value_idx* __restrict__ labels_b, + const value_idx* __restrict__ R, + value_idx N, + value_idx MAX_LABEL) { value_idx tid = threadIdx.x + blockIdx.x * TPB_X; if (tid < N) { diff --git a/cpp/include/raft/linalg/detail/add.cuh b/cpp/include/raft/linalg/detail/add.cuh index bf9b2bd1d8..121ac10e24 100644 --- a/cpp/include/raft/linalg/detail/add.cuh +++ b/cpp/include/raft/linalg/detail/add.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,10 +38,10 @@ void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t st } template -__global__ void add_dev_scalar_kernel(OutT* outDev, - const InT* inDev, - const InT* singleScalarDev, - IdxType len) +RAFT_KERNEL add_dev_scalar_kernel(OutT* outDev, + const InT* inDev, + const InT* singleScalarDev, + IdxType len) { IdxType i = ((IdxType)blockIdx.x * (IdxType)blockDim.x) + threadIdx.x; if (i < len) { outDev[i] = inDev[i] + *singleScalarDev; } diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh index 5b01196cf4..f3c150cbee 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -40,7 +40,7 @@ template -__global__ void __launch_bounds__(Policy::ThreadsPerBlock) +RAFT_KERNEL __launch_bounds__(Policy::ThreadsPerBlock) coalescedReductionThinKernel(OutType* dots, const InType* data, IdxType D, @@ -137,15 +137,15 @@ template -__global__ void __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op, - FinalLambda final_op, - bool inplace = false) +RAFT_KERNEL __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op, + bool inplace = false) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -225,7 +225,7 @@ template -__global__ void __launch_bounds__(Policy::ThreadsPerBlock) +RAFT_KERNEL __launch_bounds__(Policy::ThreadsPerBlock) coalescedReductionThickKernel(OutType* buffer, const InType* data, IdxType D, diff --git a/cpp/include/raft/linalg/detail/map.cuh b/cpp/include/raft/linalg/detail/map.cuh index 0c79dec248..4ff3aa9754 100644 --- a/cpp/include/raft/linalg/detail/map.cuh +++ b/cpp/include/raft/linalg/detail/map.cuh @@ -65,7 +65,7 @@ __device__ __forceinline__ void map_kernel_mainloop( } template -__global__ void map_kernel(OutT* out_ptr, IdxT len, Func f, const InTs*... in_ptrs) +RAFT_KERNEL map_kernel(OutT* out_ptr, IdxT len, Func f, const InTs*... in_ptrs) { const IdxT tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (R <= 1) { diff --git a/cpp/include/raft/linalg/detail/map_then_reduce.cuh b/cpp/include/raft/linalg/detail/map_then_reduce.cuh index 6fae16117f..d1e211f8d2 100644 --- a/cpp/include/raft/linalg/detail/map_then_reduce.cuh +++ b/cpp/include/raft/linalg/detail/map_then_reduce.cuh @@ -52,13 +52,13 @@ template -__global__ void mapThenReduceKernel(OutType* out, - IdxType len, - OutType neutral, - MapOp map, - ReduceLambda op, - const InType* in, - Args... args) +RAFT_KERNEL mapThenReduceKernel(OutType* out, + IdxType len, + OutType neutral, + MapOp map, + ReduceLambda op, + const InType* in, + Args... args) { OutType acc = neutral; auto idx = (threadIdx.x + (blockIdx.x * blockDim.x)); diff --git a/cpp/include/raft/linalg/detail/normalize.cuh b/cpp/include/raft/linalg/detail/normalize.cuh index 78c773ab35..d1ca4816e5 100644 --- a/cpp/include/raft/linalg/detail/normalize.cuh +++ b/cpp/include/raft/linalg/detail/normalize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ template -__global__ void __launch_bounds__(Policy::ThreadsPerBlock) +RAFT_KERNEL __launch_bounds__(Policy::ThreadsPerBlock) coalesced_normalize_thin_kernel(Type* out, const Type* in, IdxType D, @@ -92,15 +92,15 @@ template -__global__ void __launch_bounds__(TPB) coalesced_normalize_medium_kernel(Type* out, - const Type* in, - IdxType D, - IdxType N, - Type init, - MainLambda main_op, - ReduceLambda reduce_op, - FinalLambda fin_op, - Type eps) +RAFT_KERNEL __launch_bounds__(TPB) coalesced_normalize_medium_kernel(Type* out, + const Type* in, + IdxType D, + IdxType N, + Type init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda fin_op, + Type eps) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; diff --git a/cpp/include/raft/linalg/detail/reduce_cols_by_key.cuh b/cpp/include/raft/linalg/detail/reduce_cols_by_key.cuh index a85e04acca..b726e3ea5a 100644 --- a/cpp/include/raft/linalg/detail/reduce_cols_by_key.cuh +++ b/cpp/include/raft/linalg/detail/reduce_cols_by_key.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ namespace detail { ///@todo: specialize this to support shared-mem based atomics template -__global__ void reduce_cols_by_key_direct_kernel( +RAFT_KERNEL reduce_cols_by_key_direct_kernel( const T* data, const KeyIteratorT keys, T* out, IdxType nrows, IdxType ncols, IdxType nkeys) { typedef typename std::iterator_traits::value_type KeyType; @@ -44,7 +44,7 @@ __global__ void reduce_cols_by_key_direct_kernel( } template -__global__ void reduce_cols_by_key_cached_kernel( +RAFT_KERNEL reduce_cols_by_key_cached_kernel( const T* data, const KeyIteratorT keys, T* out, IdxType nrows, IdxType ncols, IdxType nkeys) { typedef typename std::iterator_traits::value_type KeyType; diff --git a/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh index 572d6b738c..ce11825e12 100644 --- a/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -33,7 +33,7 @@ namespace detail { // template -void __global__ convert_array_kernel(IteratorT1 dst, IteratorT2 src, int n) +RAFT_KERNEL convert_array_kernel(IteratorT1 dst, IteratorT2 src, int n) { for (int idx = blockDim.x * blockIdx.x + threadIdx.x; idx < n; idx += gridDim.x * blockDim.x) { dst[idx] = src[idx]; @@ -95,14 +95,14 @@ struct quadSum { template __launch_bounds__(SUM_ROWS_SMALL_K_DIMX, 4) - __global__ void sum_rows_by_key_small_nkeys_kernel(const DataIteratorT d_A, - IdxT lda, - const char* d_keys, - const WeightT* d_weights, - IdxT nrows, - IdxT ncols, - IdxT nkeys, - SumsT* d_sums) + RAFT_KERNEL sum_rows_by_key_small_nkeys_kernel(const DataIteratorT d_A, + IdxT lda, + const char* d_keys, + const WeightT* d_weights, + IdxT nrows, + IdxT ncols, + IdxT nkeys, + SumsT* d_sums) { typedef typename std::iterator_traits::value_type DataType; typedef cub::BlockReduce, SUM_ROWS_SMALL_K_DIMX> BlockReduce; @@ -193,15 +193,15 @@ template -__global__ void sum_rows_by_key_large_nkeys_kernel_colmajor(const DataIteratorT d_A, - IdxT lda, - KeysIteratorT d_keys, - const WeightT* d_weights, - IdxT nrows, - IdxT ncols, - int key_offset, - IdxT nkeys, - SumsT* d_sums) +RAFT_KERNEL sum_rows_by_key_large_nkeys_kernel_colmajor(const DataIteratorT d_A, + IdxT lda, + KeysIteratorT d_keys, + const WeightT* d_weights, + IdxT nrows, + IdxT ncols, + int key_offset, + IdxT nkeys, + SumsT* d_sums) { typedef typename std::iterator_traits::value_type KeyType; typedef typename std::iterator_traits::value_type DataType; @@ -269,13 +269,13 @@ template -__global__ void sum_rows_by_key_large_nkeys_kernel_rowmajor(const DataIteratorT d_A, - IdxT lda, - const WeightT* d_weights, - KeysIteratorT d_keys, - IdxT nrows, - IdxT ncols, - SumsT* d_sums) +RAFT_KERNEL sum_rows_by_key_large_nkeys_kernel_rowmajor(const DataIteratorT d_A, + IdxT lda, + const WeightT* d_weights, + KeysIteratorT d_keys, + IdxT nrows, + IdxT ncols, + SumsT* d_sums) { IdxT gid = threadIdx.x + (blockDim.x * static_cast(blockIdx.x)); IdxT j = gid % ncols; diff --git a/cpp/include/raft/linalg/detail/strided_reduction.cuh b/cpp/include/raft/linalg/detail/strided_reduction.cuh index 42e79a9285..aef346bd4b 100644 --- a/cpp/include/raft/linalg/detail/strided_reduction.cuh +++ b/cpp/include/raft/linalg/detail/strided_reduction.cuh @@ -30,7 +30,7 @@ namespace detail { // of the matrix, i.e. reduce along columns for row major or reduce along rows // for column major layout template -__global__ void stridedSummationKernel( +RAFT_KERNEL stridedSummationKernel( Type* dots, const Type* data, int D, int N, Type init, MainLambda main_op) { // Thread reduction @@ -68,13 +68,13 @@ template -__global__ void stridedReductionKernel(OutType* dots, - const InType* data, - int D, - int N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op) +RAFT_KERNEL stridedReductionKernel(OutType* dots, + const InType* data, + int D, + int N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op) { // Thread reduction OutType thread_data = init; diff --git a/cpp/include/raft/linalg/detail/subtract.cuh b/cpp/include/raft/linalg/detail/subtract.cuh index 6df09df8ed..6519d58fa1 100644 --- a/cpp/include/raft/linalg/detail/subtract.cuh +++ b/cpp/include/raft/linalg/detail/subtract.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -38,10 +38,10 @@ void subtract(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream
 }
 
 template
-__global__ void subtract_dev_scalar_kernel(math_t* outDev,
-                                           const math_t* inDev,
-                                           const math_t* singleScalarDev,
-                                           IdxType len)
+RAFT_KERNEL subtract_dev_scalar_kernel(math_t* outDev,
+                                       const math_t* inDev,
+                                       const math_t* singleScalarDev,
+                                       IdxType len)
 {
   // TODO: kernel do not use shared memory in current implementation
   int i = ((IdxType)blockIdx.x * (IdxType)blockDim.x) + threadIdx.x;
diff --git a/cpp/include/raft/matrix/detail/columnWiseSort.cuh b/cpp/include/raft/matrix/detail/columnWiseSort.cuh
index 5df7ba3cdc..652c4fda0f 100644
--- a/cpp/include/raft/matrix/detail/columnWiseSort.cuh
+++ b/cpp/include/raft/matrix/detail/columnWiseSort.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -54,7 +54,7 @@ struct SmemPerBlock {
 };
 
 template
-__global__ void devLayoutIdx(InType* in, int n_cols, int totalElements)
+RAFT_KERNEL devLayoutIdx(InType* in, int n_cols, int totalElements)
 {
   int idx = threadIdx.x + blockDim.x * blockIdx.x;
   int n   = n_cols;
@@ -63,7 +63,7 @@ __global__ void devLayoutIdx(InType* in, int n_cols, int totalElements)
 }
 
 template
-__global__ void devOffsetKernel(T* in, T value, int n_times)
+RAFT_KERNEL devOffsetKernel(T* in, T value, int n_times)
 {
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < n_times) in[idx] = idx * value;
@@ -76,12 +76,12 @@ template <
          int BLOCK_SIZE,
          int ITEMS_PER_THREAD,
          typename std::enable_if::IsValid, InType>::type* = nullptr>
-__global__ void __launch_bounds__(1024, 1) devKeyValSortColumnPerRow(const InType* inputKeys,
-                                                                     InType* outputKeys,
-                                                                     OutType* inputVals,
-                                                                     int n_rows,
-                                                                     int n_cols,
-                                                                     InType MAX_VALUE)
+RAFT_KERNEL __launch_bounds__(1024, 1) devKeyValSortColumnPerRow(const InType* inputKeys,
+                                                                 InType* outputKeys,
+                                                                 OutType* inputVals,
+                                                                 int n_rows,
+                                                                 int n_cols,
+                                                                 InType MAX_VALUE)
 {
   typedef cub::BlockLoad BlockLoadTypeKey;
@@ -124,12 +124,12 @@ template <
          int BLOCK_SIZE,
          int ITEMS_PER_THREAD,
          typename std::enable_if::IsValid), InType>::type* = nullptr>
-__global__ void devKeyValSortColumnPerRow(const InType* inputKeys,
-                                          InType* outputKeys,
-                                          OutType* inputVals,
-                                          int n_rows,
-                                          int n_cols,
-                                          InType MAX_VALUE)
+RAFT_KERNEL devKeyValSortColumnPerRow(const InType* inputKeys,
+                                      InType* outputKeys,
+                                      OutType* inputVals,
+                                      int n_rows,
+                                      int n_cols,
+                                      InType MAX_VALUE)
 {
   // place holder function
   // so that compiler unrolls for all template types successfully
diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh
index 59fcf606c8..73072ec841 100644
--- a/cpp/include/raft/matrix/detail/gather.cuh
+++ b/cpp/include/raft/matrix/detail/gather.cuh
@@ -47,14 +47,14 @@ template
-__global__ void gather_kernel(const InputIteratorT in,
-                              IndexT D,
-                              IndexT len,
-                              const MapIteratorT map,
-                              StencilIteratorT stencil,
-                              OutputIteratorT out,
-                              PredicateOp pred_op,
-                              MapTransformOp transform_op)
+RAFT_KERNEL gather_kernel(const InputIteratorT in,
+                          IndexT D,
+                          IndexT len,
+                          const MapIteratorT map,
+                          StencilIteratorT stencil,
+                          OutputIteratorT out,
+                          PredicateOp pred_op,
+                          MapTransformOp transform_op)
 {
   typedef typename std::iterator_traits::value_type MapValueT;
   typedef typename std::iterator_traits::value_type StencilValueT;
diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh
index 514d0dc51b..6061fe6aee 100644
--- a/cpp/include/raft/matrix/detail/linewise_op.cuh
+++ b/cpp/include/raft/matrix/detail/linewise_op.cuh
@@ -260,7 +260,7 @@ template
-__global__ void __launch_bounds__(BlockSize)
+RAFT_KERNEL __launch_bounds__(BlockSize)
   matrixLinewiseVecColsMainKernel(Type* out,
                                   const Type* in,
                                   const IdxType arrOffset,
@@ -304,15 +304,14 @@ __global__ void __launch_bounds__(BlockSize)
  * @param [in] vecs pointers to the argument vectors
  */
 template
-__global__ void __launch_bounds__(MaxOffset, 2)
-  matrixLinewiseVecColsTailKernel(Type* out,
-                                  const Type* in,
-                                  const IdxType arrOffset,
-                                  const IdxType arrTail,
-                                  const IdxType rowLen,
-                                  const IdxType len,
-                                  Lambda op,
-                                  const Vecs*... vecs)
+RAFT_KERNEL __launch_bounds__(MaxOffset, 2) matrixLinewiseVecColsTailKernel(Type* out,
+                                                                            const Type* in,
+                                                                            const IdxType arrOffset,
+                                                                            const IdxType arrTail,
+                                                                            const IdxType rowLen,
+                                                                            const IdxType len,
+                                                                            Lambda op,
+                                                                            const Vecs*... vecs)
 {
   // Note, L::VecElems == 1
   typedef Linewise L;
@@ -370,14 +369,13 @@ template
-__global__ void __launch_bounds__(BlockSize)
-  matrixLinewiseVecRowsMainKernel(Type* out,
-                                  const Type* in,
-                                  const IdxType arrOffset,
-                                  const IdxType rowLen,
-                                  const IdxType len,
-                                  Lambda op,
-                                  const Vecs*... vecs)
+RAFT_KERNEL __launch_bounds__(BlockSize) matrixLinewiseVecRowsMainKernel(Type* out,
+                                                                         const Type* in,
+                                                                         const IdxType arrOffset,
+                                                                         const IdxType rowLen,
+                                                                         const IdxType len,
+                                                                         Lambda op,
+                                                                         const Vecs*... vecs)
 {
   typedef Linewise L;
   constexpr uint workSize = L::VecElems * BlockSize;
@@ -413,14 +411,13 @@ template
-__global__ void __launch_bounds__(BlockSize)
-  matrixLinewiseVecRowsSpanKernel(Type* out,
-                                  const Type* in,
-                                  const IdxType rowLen,
-                                  const IdxType rowLenPadded,
-                                  const IdxType lenPadded,
-                                  Lambda op,
-                                  const Vecs*... vecs)
+RAFT_KERNEL __launch_bounds__(BlockSize) matrixLinewiseVecRowsSpanKernel(Type* out,
+                                                                         const Type* in,
+                                                                         const IdxType rowLen,
+                                                                         const IdxType rowLenPadded,
+                                                                         const IdxType lenPadded,
+                                                                         Lambda op,
+                                                                         const Vecs*... vecs)
 {
   typedef Linewise L;
   constexpr uint workSize = L::VecElems * BlockSize;
@@ -457,15 +454,14 @@ __global__ void __launch_bounds__(BlockSize)
  * @param [in] vecs pointers to the argument vectors
  */
 template
-__global__ void __launch_bounds__(MaxOffset, 2)
-  matrixLinewiseVecRowsTailKernel(Type* out,
-                                  const Type* in,
-                                  const IdxType arrOffset,
-                                  const IdxType arrTail,
-                                  const IdxType rowLen,
-                                  const IdxType len,
-                                  Lambda op,
-                                  const Vecs*... vecs)
+RAFT_KERNEL __launch_bounds__(MaxOffset, 2) matrixLinewiseVecRowsTailKernel(Type* out,
+                                                                            const Type* in,
+                                                                            const IdxType arrOffset,
+                                                                            const IdxType arrTail,
+                                                                            const IdxType rowLen,
+                                                                            const IdxType len,
+                                                                            Lambda op,
+                                                                            const Vecs*... vecs)
 {
   // Note, L::VecElems == 1
   constexpr uint workSize = MaxOffset;
diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh
index d2707e1254..9e9d7f8b3b 100644
--- a/cpp/include/raft/matrix/detail/math.cuh
+++ b/cpp/include/raft/matrix/detail/math.cuh
@@ -331,7 +331,7 @@ void matrixVectorBinarySub(Type* data,
 
 // Computes an argmin/argmax column-wise in a DxN matrix
 template
-__global__ void argReduceKernel(const T* d_in, IdxT D, IdxT N, OutT* out)
+RAFT_KERNEL argReduceKernel(const T* d_in, IdxT D, IdxT N, OutT* out)
 {
   typedef cub::
     BlockReduce, TPB, cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
@@ -396,7 +396,7 @@ void argmax(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
 
 // Computes the argmax(abs(d_in)) column-wise in a DxN matrix followed by
 // flipping the sign if the |max| value for each column is negative.
 template
-__global__ void signFlipKernel(T* d_in, int D, int N)
+RAFT_KERNEL signFlipKernel(T* d_in, int D, int N)
 {
   typedef cub::BlockReduce, TPB> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh
index 48821df5b2..2fa741fd96 100644
--- a/cpp/include/raft/matrix/detail/matrix.cuh
+++ b/cpp/include/raft/matrix/detail/matrix.cuh
@@ -169,8 +169,7 @@ void printHost(const m_t* in, idx_t n_rows, idx_t n_cols)
  * (1-based)
  */
 template
-__global__ void slice(
-  const m_t* src_d, idx_t lda, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2)
+RAFT_KERNEL slice(const m_t* src_d, idx_t lda, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2)
 {
   idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;
   idx_t dm = x2 - x1, dn = y2 - y1;
@@ -211,7 +210,7 @@ void sliceMatrix(const m_t* in,
  * @param k: min(n_rows, n_cols)
  */
 template
-__global__ void getUpperTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, idx_t k)
+RAFT_KERNEL getUpperTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, idx_t k)
 {
   idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;
   idx_t m = n_rows, n = n_cols;
@@ -239,7 +238,7 @@ void copyUpperTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, c
  * @param k: dimensionality
  */
 template
-__global__ void copyVectorToMatrixDiagonal(const m_t* vec, m_t* matrix, idx_t lda, idx_t k)
+RAFT_KERNEL copyVectorToMatrixDiagonal(const m_t* vec, m_t* matrix, idx_t lda, idx_t k)
 {
   idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;
 
@@ -254,7 +253,7 @@ __global__ void copyVectorToMatrixDiagonal(const m_t* vec, m_t* matrix, idx_t ld
  * @param k: dimensionality
  */
 template
-__global__ void copyVectorFromMatrixDiagonal(m_t* vec, const m_t* matrix, idx_t lda, idx_t k)
+RAFT_KERNEL copyVectorFromMatrixDiagonal(m_t* vec, const m_t* matrix, idx_t lda, idx_t k)
 {
   idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;
 
@@ -290,7 +289,7 @@ void getDiagonalMatrix(
  * @param len: size of one side of the matrix
  */
 template
-__global__ void matrixDiagonalInverse(m_t* in, idx_t len)
+RAFT_KERNEL matrixDiagonalInverse(m_t* in, idx_t len)
 {
   idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;
   if (idx < len) { in[idx + idx * len] = 1.0 / in[idx + idx * len]; }
diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh
index edde924892..b3c07b9d3a 100644
--- a/cpp/include/raft/matrix/detail/select_radix.cuh
+++ b/cpp/include/raft/matrix/detail/select_radix.cuh
@@ -422,16 +422,16 @@ _RAFT_DEVICE void last_filter(const T* in_buf,
 }
 
 template
-__global__ void last_filter_kernel(const T* in,
-                                   const IdxT* in_idx,
-                                   const T* in_buf,
-                                   const IdxT* in_idx_buf,
-                                   T* out,
-                                   IdxT* out_idx,
-                                   IdxT len,
-                                   IdxT k,
-                                   Counter* counters,
-                                   const bool select_min)
+RAFT_KERNEL last_filter_kernel(const T* in,
+                               const IdxT* in_idx,
+                               const T* in_buf,
+                               const IdxT* in_idx_buf,
+                               T* out,
+                               IdxT* out_idx,
+                               IdxT len,
+                               IdxT k,
+                               Counter* counters,
+                               const bool select_min)
 {
   const size_t batch_id = blockIdx.y;  // size_t to avoid multiplication overflow
 
@@ -525,20 +525,20 @@ __global__ void last_filter_kernel(const T* in,
  * their indices.
  */
 template
-__global__ void radix_kernel(const T* in,
-                             const IdxT* in_idx,
-                             const T* in_buf,
-                             const IdxT* in_idx_buf,
-                             T* out_buf,
-                             IdxT* out_idx_buf,
-                             T* out,
-                             IdxT* out_idx,
-                             Counter* counters,
-                             IdxT* histograms,
-                             const IdxT len,
-                             const IdxT k,
-                             const bool select_min,
-                             const int pass)
+RAFT_KERNEL radix_kernel(const T* in,
+                         const IdxT* in_idx,
+                         const T* in_buf,
+                         const IdxT* in_idx_buf,
+                         T* out_buf,
+                         IdxT* out_idx_buf,
+                         T* out,
+                         IdxT* out_idx,
+                         Counter* counters,
+                         IdxT* histograms,
+                         const IdxT len,
+                         const IdxT k,
+                         const bool select_min,
+                         const int pass)
 {
   const size_t batch_id = blockIdx.y;
   auto counter          = counters + batch_id;
@@ -920,17 +920,17 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf,
 }
 
 template
-__global__ void radix_topk_one_block_kernel(const T* in,
-                                            const IdxT* in_idx,
-                                            const IdxT len,
-                                            const IdxT k,
-                                            T* out,
-                                            IdxT* out_idx,
-                                            const bool select_min,
-                                            T* buf1,
-                                            IdxT* idx_buf1,
-                                            T* buf2,
-                                            IdxT* idx_buf2)
+RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
+                                        const IdxT* in_idx,
+                                        const IdxT len,
+                                        const IdxT k,
+                                        T* out,
+                                        IdxT* out_idx,
+                                        const bool select_min,
+                                        T* buf1,
+                                        IdxT* idx_buf1,
+                                        T* buf2,
+                                        IdxT* idx_buf2)
 {
   constexpr int num_buckets = calc_num_buckets();
   __shared__ Counter counter;
diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh
index 2927604e7d..0ee87de4f7 100644
--- a/cpp/include/raft/matrix/detail/select_warpsort.cuh
+++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh
@@ -56,7 +56,7 @@ the top-k result.
 
     Example:
-      __global__ void kernel() {
+      RAFT_KERNEL kernel() {
         block_sort queue(...);
         for (IdxT i = threadIdx.x; i < len, i += blockDim.x) {
@@ -80,7 +80,7 @@ (see the usage of LaunchThreshold::len_factor_for_choosing).
 
    Example:
-      __global__ void kernel() {
+      RAFT_KERNEL kernel() {
        warp_sort_immediate<...> queue(...);
        int warp_id = threadIdx.x / WarpSize;
        int lane_id = threadIdx.x % WarpSize;
@@ -750,8 +750,8 @@ template