Provide a raft::copy overload for mdspan-to-mdspan copies (#1818)
# Purpose
This PR provides a utility for copying between generic mdspans. This includes copies between host and device, between mdspans of different layouts, and between mdspans of different (convertible) data types.

## API
`raft::copy(raft_resources, dest_mdspan, src_mdspan);`
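As a rough standalone illustration of what such a generic copy entails (this is a sketch of the semantics, not RAFT's implementation), a host-side copy between a row-major `float` view and a column-major `double` view amounts to a layout- and type-converting elementwise copy:

```cpp
#include <array>
#include <cstddef>

// Illustrative only: copy a Rows x Cols row-major float matrix into a
// column-major double matrix, converting the element type along the way.
template <std::size_t Rows, std::size_t Cols>
void copy_row_to_col_major(std::array<float, Rows * Cols> const& src,
                           std::array<double, Rows * Cols>& dst)
{
  for (std::size_t r = 0; r < Rows; ++r) {
    for (std::size_t c = 0; c < Cols; ++c) {
      // Same logical element (r, c), different linear offsets:
      // row-major uses r * Cols + c, column-major uses c * Rows + r.
      dst[c * Rows + r] = static_cast<double>(src[r * Cols + c]);
    }
  }
}
```

`raft::copy` additionally handles the host/device placement of each side; this sketch covers only the layout/type-conversion aspect.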

# Limitations

- Currently does not support copies between mdspans on two different GPUs
- Currently not performant for generic host-to-host copies (would be much easier to optimize with submdspan for padded layouts)
- Submdspan with padded layouts would also make it easier to improve perf of some device-to-device copies, though perf should already be quite good for most device-to-device copies.

# Design

- Includes optional `RAFT_DISABLE_CUDA` build definition in order to use this utility in CUDA-free builds (important for use in the FIL backend for Triton)
- Includes a new `raft::stream_view` object which is a thin wrapper around `rmm::stream_view`. Its purpose is solely to provide a symbol that will be defined in CUDA-free builds and which will throw exceptions or log error messages if someone tries to use a CUDA stream in a CUDA-free build. This avoids a whole bunch of ifdefs that would otherwise infect the whole codebase.
- Uses (roughly in order of preference) `cudaMemcpyAsync`, `std::copy`, cuBLAS, a custom device kernel, or custom host-to-host transfer logic for the underlying copy
- Provides two different headers: `raft/core/copy.hpp` and `raft/core/copy.cuh`. This is to accommodate the custom kernel necessary for handling completely generic device-to-device copies. See below for more details.
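The `raft::stream_view` bullet above can be sketched in standalone form (names here are illustrative, not RAFT's actual API): the same symbol exists in both build modes, so callers need no `#ifdef`s, but any attempt to use a stream in a CUDA-free build fails loudly at runtime.

```cpp
#include <stdexcept>

// Simulate a CUDA-free build for this self-contained sketch.
#define DEMO_DISABLE_CUDA

#ifndef DEMO_DISABLE_CUDA
// In a CUDA build this would thinly wrap rmm::cuda_stream_view.
struct demo_stream_view {
  void synchronize() { /* cudaStreamSynchronize(...) in a real build */ }
};
#else
// In a CUDA-free build the symbol still exists, but using it throws.
struct demo_stream_view {
  void synchronize()
  {
    throw std::runtime_error("CUDA stream used in a CUDA-free build");
  }
};
#endif

// Helper for exercising the CUDA-free path.
bool synchronize_throws()
{
  demo_stream_view sv;
  try {
    sv.synchronize();
  } catch (std::runtime_error const&) {
    return true;
  }
  return false;
}
```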

## Details on the header split
For many instantiations, even those which involve the device, we do not require nvcc compilation. If, however, we determine at compilation time that we must use a custom kernel for the copy, then we must invoke nvcc. We do not wish to indicate that a public header file is a C++ header when it is a CUDA header or vice versa, so we split the definitions into separate `hpp` and `cuh` files, with all template instantiations requiring the custom kernel enable-if'd out of the hpp file.

Thus, the cuh header can be used for _any_ mdspan-to-mdspan copy, but the hpp file will not compile for those specific instantiations that require a custom kernel. The recommended workflow is that if a `cpp` file requires an mdspan-to-mdspan copy, first try the `hpp` header. If that fails, the `cpp` file must be converted to a `cu` file, and the `cuh` header should be used. For source files that are already being compiled with nvcc (i.e. `.cu` files), the `cuh` header might as well be used and will not result in any additional compile time penalty.
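The enable-if split described above can be illustrated with a simplified, self-contained analogue (the trait here is a stand-in; RAFT's actual `mdspan_copyable_*` traits are derived from the mdspan types). The key point is that the two overloads are mutually exclusive, so the hpp-style header can safely omit the kernel-requiring one:

```cpp
#include <type_traits>

// Stand-in for RAFT's copyability traits: whether a given copy needs the
// custom device kernel is reduced to a single compile-time flag.
template <bool NeedsKernel>
struct copy_traits {
  static constexpr bool needs_kernel = NeedsKernel;
};

// Overload available from the hpp-style header: enabled only when no custom
// kernel is required, so a plain C++ compiler can always build it.
template <typename Traits>
std::enable_if_t<!Traits::needs_kernel, int> demo_copy(Traits)
{
  return 0;  // would dispatch to cudaMemcpyAsync / std::copy / cuBLAS
}

// Overload that would live only in the cuh-style header: enabled for exactly
// the instantiations the hpp header SFINAE-omits.
template <typename Traits>
std::enable_if_t<Traits::needs_kernel, int> demo_copy(Traits)
{
  return 1;  // would launch the generic device-to-device kernel
}
```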

# Remaining tasks to leave WIP status

- [x] Add benchmarks for copies
- [x] Ensure that new function is correctly added to docs

# Follow-up items

- Optimize host-to-host transfers using a cache-oblivious approach with SIMD-accelerated transposes for contiguous memory
- Test cache-oblivious device-to-device transfers and compare performance
- Provide transparent support for copies between devices.
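The first follow-up item can be sketched as follows (without the SIMD micro-kernel, and purely as an illustration of the cache-oblivious idea): recursively split the larger dimension until a block is small enough to stay cache-resident, then transpose it directly.

```cpp
#include <cstddef>

// Cache-oblivious transpose sketch: src is a rows x cols row-major matrix,
// dst is the cols x rows row-major transpose. Recursion keeps both the read
// and write working sets small without knowing the cache size.
void transpose_rec(float const* src, float* dst,
                   std::size_t r0, std::size_t r1,
                   std::size_t c0, std::size_t c1,
                   std::size_t rows, std::size_t cols)
{
  if (r1 - r0 <= 16 && c1 - c0 <= 16) {
    // Base case: block fits comfortably in cache; copy directly. A SIMD
    // micro-kernel would replace this loop nest in an optimized version.
    for (std::size_t r = r0; r < r1; ++r) {
      for (std::size_t c = c0; c < c1; ++c) {
        dst[c * rows + r] = src[r * cols + c];
      }
    }
  } else if (r1 - r0 >= c1 - c0) {
    // Split the row range in half.
    std::size_t rm = r0 + (r1 - r0) / 2;
    transpose_rec(src, dst, r0, rm, c0, c1, rows, cols);
    transpose_rec(src, dst, rm, r1, c0, c1, rows, cols);
  } else {
    // Split the column range in half.
    std::size_t cm = c0 + (c1 - c0) / 2;
    transpose_rec(src, dst, r0, r1, c0, cm, rows, cols);
    transpose_rec(src, dst, r0, r1, cm, c1, rows, cols);
  }
}
```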

## Relationship to mdbuffer
This utility encapsulates a substantial chunk of the core logic required for the mdbuffer implementation. It is being split into its own PR both because it is useful on its own and because the mdbuffer work has been delayed by higher priority tasks.

Close #1779

Authors:
  - William Hicks (https://github.com/wphicks)
  - Tarang Jain (https://github.com/tarang-jain)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #1818
wphicks authored Oct 6, 2023
1 parent 1eff78b commit c735ecb
Showing 14 changed files with 2,129 additions and 4 deletions.
6 changes: 5 additions & 1 deletion cpp/bench/prims/CMakeLists.txt
@@ -32,6 +32,7 @@ function(ConfigureBench)
PRIVATE raft::raft
raft_internal
$<$<BOOL:${ConfigureBench_LIB}>:raft::compiled>
${RAFT_CTK_MATH_DEPENDENCIES}
benchmark::benchmark
Threads::Threads
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
@@ -73,11 +74,14 @@ function(ConfigureBench)
endfunction()

if(BUILD_PRIMS_BENCH)
ConfigureBench(
NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/core/copy.cu bench/prims/main.cpp
)

ConfigureBench(
NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu
bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
)
ConfigureBench(NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/main.cpp)

ConfigureBench(
NAME TUNE_DISTANCE PATH bench/prims/distance/tune_pairwise/kernel.cu
401 changes: 401 additions & 0 deletions cpp/bench/prims/core/copy.cu

Large diffs are not rendered by default.

74 changes: 74 additions & 0 deletions cpp/include/raft/core/copy.cuh
@@ -0,0 +1,74 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <raft/core/detail/copy.hpp>
namespace raft {
/**
* @brief Copy data from one mdspan to another with the same extents
*
* This function copies data from one mdspan to another, regardless of whether
* or not the mdspans have the same layout, memory type (host/device/managed)
* or data type. So long as it is possible to convert the data type from source
* to destination, and the extents are equal, this function should be able to
* perform the copy. Any necessary device operations will be stream-ordered via the CUDA stream
* provided by the `raft::resources` argument.
*
* This header includes a custom kernel used for copying data between
* completely arbitrary mdspans on device. To compile this function in a
* non-CUDA translation unit, `raft/core/copy.hpp` may be used instead. The
* pure C++ header will correctly compile even without a CUDA compiler.
* Depending on the specialization, this CUDA header may invoke the kernel and
* therefore require a CUDA compiler.
*
* Limitations: Currently this function does not support copying directly
* between two arbitrary mdspans on different CUDA devices. It is assumed that the caller sets the
* correct CUDA device. Furthermore, host-to-host copies that require a transformation of the
* underlying memory layout are currently not performant, although they are supported.
*
* Note that when copying to an mdspan with a non-unique layout (i.e. the same
* underlying memory is addressed by different element indexes), the source
* data must contain non-unique values for every non-unique destination
* element. If this is not the case, the behavior is undefined. Some copies
* to non-unique layouts which are well-defined will nevertheless fail with an
* exception to avoid race conditions in the underlying copy.
*
* @tparam DstType An mdspan type for the destination container.
* @tparam SrcType An mdspan type for the source container
* @param res raft::resources used to provide a stream for copies involving the
* device.
* @param dst The destination mdspan.
* @param src The source mdspan.
*/
template <typename DstType, typename SrcType>
detail::mdspan_copyable_with_kernel_t<DstType, SrcType> copy(resources const& res,
DstType&& dst,
SrcType&& src)
{
detail::copy(res, std::forward<DstType>(dst), std::forward<SrcType>(src));
}

#ifndef RAFT_NON_CUDA_COPY_IMPLEMENTED
#define RAFT_NON_CUDA_COPY_IMPLEMENTED
template <typename DstType, typename SrcType>
detail::mdspan_copyable_not_with_kernel_t<DstType, SrcType> copy(resources const& res,
DstType&& dst,
SrcType&& src)
{
detail::copy(res, std::forward<DstType>(dst), std::forward<SrcType>(src));
}
#endif
} // namespace raft
69 changes: 69 additions & 0 deletions cpp/include/raft/core/copy.hpp
@@ -0,0 +1,69 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <raft/core/detail/copy.hpp>
namespace raft {

#ifndef RAFT_NON_CUDA_COPY_IMPLEMENTED
#define RAFT_NON_CUDA_COPY_IMPLEMENTED
/**
* @brief Copy data from one mdspan to another with the same extents
*
* This function copies data from one mdspan to another, regardless of whether
* or not the mdspans have the same layout, memory type (host/device/managed)
* or data type. So long as it is possible to convert the data type from source
* to destination, and the extents are equal, this function should be able to
* perform the copy.
*
* This header does _not_ include the custom kernel used for copying data
* between completely arbitrary mdspans on device. For arbitrary copies of this
* kind, `#include <raft/core/copy.cuh>` instead. Specializations of this
* function that require the custom kernel will be SFINAE-omitted when this
* header is used instead of `copy.cuh`. This header _does_ support
* device-to-device copies that can be performed with cuBLAS or a
* straightforward cudaMemcpy. Any necessary device operations will be stream-ordered via the CUDA
* stream provided by the `raft::resources` argument.
*
* Limitations: Currently this function does not support copying directly
* between two arbitrary mdspans on different CUDA devices. It is assumed that the caller sets the
* correct CUDA device. Furthermore, host-to-host copies that require a transformation of the
* underlying memory layout are currently not performant, although they are supported.
*
* Note that when copying to an mdspan with a non-unique layout (i.e. the same
* underlying memory is addressed by different element indexes), the source
* data must contain non-unique values for every non-unique destination
* element. If this is not the case, the behavior is undefined. Some copies
* to non-unique layouts which are well-defined will nevertheless fail with an
* exception to avoid race conditions in the underlying copy.
*
* @tparam DstType An mdspan type for the destination container.
* @tparam SrcType An mdspan type for the source container
* @param res raft::resources used to provide a stream for copies involving the
* device.
* @param dst The destination mdspan.
* @param src The source mdspan.
*/
template <typename DstType, typename SrcType>
detail::mdspan_copyable_not_with_kernel_t<DstType, SrcType> copy(resources const& res,
DstType&& dst,
SrcType&& src)
{
detail::copy(res, std::forward<DstType>(dst), std::forward<SrcType>(src));
}
#endif

} // namespace raft
23 changes: 23 additions & 0 deletions cpp/include/raft/core/cuda_support.hpp
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace raft {
#ifndef RAFT_DISABLE_CUDA
auto constexpr static const CUDA_ENABLED = true;
#else
auto constexpr static const CUDA_ENABLED = false;
#endif
} // namespace raft
