Showing 12 changed files with 1,126 additions and 8 deletions.
@@ -0,0 +1,41 @@
# =============================================================================
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

# This function finds dlpack and sets any additional necessary CMake variables.
function(find_and_configure_dlpack VERSION)

  include(${rapids-cmake-dir}/find/generate_module.cmake)
  rapids_find_generate_module(DLPACK HEADER_NAMES dlpack.h)

  rapids_cpm_find(
    dlpack ${VERSION}
    GIT_REPOSITORY https://github.com/dmlc/dlpack.git
    GIT_TAG v${VERSION}
    GIT_SHALLOW TRUE
    DOWNLOAD_ONLY TRUE
    OPTIONS "BUILD_MOCK OFF"
  )

  if(DEFINED dlpack_SOURCE_DIR)
    # otherwise find_package(DLPACK) will set this variable
    set(DLPACK_INCLUDE_DIR
        "${dlpack_SOURCE_DIR}/include"
        PARENT_SCOPE
    )
  endif()
endfunction()

set(CUVS_MIN_VERSION_dlpack 0.8)

find_and_configure_dlpack(${CUVS_MIN_VERSION_dlpack})
@@ -0,0 +1,74 @@
/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <stdint.h>

#include <cuda_runtime.h>

/**
 * @defgroup c_api C API Core Types and Functions
 * @{
 */

#ifdef __cplusplus
extern "C" {
#endif

/**
 * @brief An opaque C handle for the C++ type `raft::resources`
 */
typedef uintptr_t cuvsResources_t;

/**
 * @brief An enum denoting return values for function calls
 */
typedef enum { CUVS_ERROR, CUVS_SUCCESS } cuvsError_t;

/**
 * @brief Create an initialized opaque C handle for the C++ type `raft::resources`
 *
 * @param[out] res cuvsResources_t opaque C handle
 * @return cuvsError_t
 */
cuvsError_t cuvsResourcesCreate(cuvsResources_t* res);

/**
 * @brief Destroy and de-allocate an opaque C handle for the C++ type `raft::resources`
 *
 * @param[in] res cuvsResources_t opaque C handle
 * @return cuvsError_t
 */
cuvsError_t cuvsResourcesDestroy(cuvsResources_t res);

/**
 * @brief Set a cudaStream_t on a cuvsResources_t handle, so that CUDA kernels launched
 * by APIs accepting that handle are queued on the given stream
 *
 * @param[in] res cuvsResources_t opaque C handle
 * @param[in] stream cudaStream_t stream on which to queue CUDA kernels
 * @return cuvsError_t
 */
cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

/** @} */
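Taken together, these declarations define the lifecycle of the C handle: create the resources, optionally attach a CUDA stream, pass the handle to cuVS C APIs, and destroy it when done. The sketch below illustrates that flow; the include path and the main() wrapper are illustrative assumptions (the header's install location is not shown in this diff), and error handling is reduced to checking the returned cuvsError_t.

// Sketch of typical cuvsResources_t usage; include path and main() are
// illustrative assumptions, not part of this commit.
#include <cuvs/core/c_api.h>

#include <cuda_runtime.h>
#include <stdio.h>

int main(void)
{
  cuvsResources_t res;
  if (cuvsResourcesCreate(&res) != CUVS_SUCCESS) {
    fprintf(stderr, "failed to create cuVS resources\n");
    return 1;
  }

  // Optionally queue all work submitted through this handle on a user-owned stream.
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  if (cuvsStreamSet(res, stream) != CUVS_SUCCESS) {
    fprintf(stderr, "failed to set stream on cuVS resources\n");
  }

  // ... call cuVS C APIs that accept a cuvsResources_t here ...

  cuvsResourcesDestroy(res);
  cudaStreamDestroy(stream);
  return 0;
}

The handle is a plain uintptr_t, so callers never see the underlying raft::resources type and the header stays consumable from C.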
@@ -0,0 +1,105 @@
/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan_types.hpp>

#include <dlpack/dlpack.h>
#include <rmm/device_buffer.hpp>
#include <sys/types.h>

#include <array>
#include <cstdint>
#include <type_traits>

namespace cuvs::core::detail {

// Map an mdspan accessor type to the DLPack device on which its memory lives.
template <typename AccessorType>
DLDevice accessor_type_to_DLDevice()
{
  if constexpr (AccessorType::is_host_accessible and AccessorType::is_device_accessible) {
    return DLDevice{kDLCUDAManaged};
  } else if constexpr (AccessorType::is_device_accessible) {
    return DLDevice{kDLCUDA};
  } else if constexpr (AccessorType::is_host_accessible) {
    return DLDevice{kDLCPU};
  }
}

// Map a C++ element type to the corresponding DLPack data type descriptor.
template <typename T>
DLDataType data_type_to_DLDataType()
{
  uint8_t const bits{sizeof(T) * 8};
  uint16_t const lanes{1};
  if constexpr (std::is_floating_point_v<T>) {
    return DLDataType{kDLFloat, bits, lanes};
  } else if constexpr (std::is_signed_v<T>) {
    return DLDataType{kDLInt, bits, lanes};
  } else {
    return DLDataType{kDLUInt, bits, lanes};
  }
}

// Marked inline because these non-template definitions live in a header.
inline bool is_dlpack_device_compatible(DLTensor tensor)
{
  return tensor.device.device_type == kDLCUDAManaged || tensor.device.device_type == kDLCUDAHost ||
         tensor.device.device_type == kDLCUDA;
}

inline bool is_dlpack_host_compatible(DLTensor tensor)
{
  return tensor.device.device_type == kDLCUDAManaged || tensor.device.device_type == kDLCUDAHost ||
         tensor.device.device_type == kDLCPU;
}

// Construct an mdspan of type MdspanType that aliases the memory described by the DLTensor,
// validating dtype, layout, device and rank along the way. No data is copied.
template <typename MdspanType, typename = raft::is_mdspan_t<MdspanType>>
MdspanType from_dlpack(DLManagedTensor* managed_tensor)
{
  auto tensor = managed_tensor->dl_tensor;

  auto to_data_type = data_type_to_DLDataType<typename MdspanType::value_type>();
  RAFT_EXPECTS(to_data_type.code == tensor.dtype.code,
               "code mismatch between return mdspan and DLTensor");
  RAFT_EXPECTS(to_data_type.bits == tensor.dtype.bits,
               "bits mismatch between return mdspan and DLTensor");
  RAFT_EXPECTS(to_data_type.lanes == tensor.dtype.lanes,
               "lanes mismatch between return mdspan and DLTensor");
  RAFT_EXPECTS(tensor.dtype.lanes == 1, "More than 1 DLTensor lanes not supported");
  RAFT_EXPECTS(tensor.strides == nullptr, "Strided memory layout for DLTensor not supported");

  auto to_device = accessor_type_to_DLDevice<typename MdspanType::accessor_type>();
  if (to_device.device_type == kDLCUDA) {
    RAFT_EXPECTS(is_dlpack_device_compatible(tensor),
                 "device_type mismatch between return mdspan and DLTensor");
  } else if (to_device.device_type == kDLCPU) {
    RAFT_EXPECTS(is_dlpack_host_compatible(tensor),
                 "device_type mismatch between return mdspan and DLTensor");
  }

  RAFT_EXPECTS(MdspanType::extents_type::rank() == tensor.ndim,
               "ndim mismatch between return mdspan and DLTensor");

  std::array<int64_t, MdspanType::extents_type::rank()> shape{};
  for (int64_t i = 0; i < tensor.ndim; ++i) {
    shape[i] = tensor.shape[i];
  }
  auto exts = typename MdspanType::extents_type{shape};

  return MdspanType{reinterpret_cast<typename MdspanType::data_handle_type>(tensor.data), exts};
}

}  // namespace cuvs::core::detail
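To make the conversion path concrete, here is a rough usage sketch for from_dlpack, not part of this commit: it describes an existing device allocation of 32-bit floats as a DLManagedTensor and views it as a row-major raft::device_matrix_view. The function name, its arguments, and the choice of device id 0 are illustrative assumptions, and the interop header above is assumed to be included (its install path is not shown in this diff).

// Sketch (not part of this commit): viewing an existing device allocation
// through cuvs::core::detail::from_dlpack. Assumes the interop header above
// has already been included.
#include <dlpack/dlpack.h>
#include <raft/core/device_mdspan.hpp>

#include <cstdint>

void view_device_matrix(float* d_ptr, int64_t n_rows, int64_t n_cols)
{
  // Describe the existing, compact row-major device buffer as a DLPack tensor.
  int64_t shape[2] = {n_rows, n_cols};

  DLManagedTensor managed{};
  managed.dl_tensor.data        = d_ptr;
  managed.dl_tensor.device      = DLDevice{kDLCUDA, 0};         // device id 0 is illustrative
  managed.dl_tensor.ndim        = 2;
  managed.dl_tensor.dtype       = DLDataType{kDLFloat, 32, 1};  // float32, single lane
  managed.dl_tensor.shape       = shape;
  managed.dl_tensor.strides     = nullptr;                      // compact layout only
  managed.dl_tensor.byte_offset = 0;

  // from_dlpack validates dtype, device, rank and layout, then returns a
  // non-owning mdspan that aliases the same device memory.
  using matrix_view = raft::device_matrix_view<float, int64_t, raft::row_major>;
  auto view = cuvs::core::detail::from_dlpack<matrix_view>(&managed);
  (void)view;  // pass the view to RAFT/cuVS APIs that accept a device_matrix_view
}

Because from_dlpack only reinterprets the DLTensor's data pointer, the caller remains responsible for keeping the underlying allocation alive for as long as the returned view is used.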