diff --git a/cpp/include/cuvs/core/c_api.h b/cpp/include/cuvs/core/c_api.h index c8c8d3934..400d162ad 100644 --- a/cpp/include/cuvs/core/c_api.h +++ b/cpp/include/cuvs/core/c_api.h @@ -151,6 +151,22 @@ cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent, */ cuvsError_t cuvsRMMMemoryResourceReset(); +/** + * @brief Allocates pinned memory on the host using RMM + * @param[out] ptr Pointer to allocated host memory + * @param[in] bytes Size in bytes to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsRMMHostAlloc(void** ptr, size_t bytes); + +/** + * @brief Deallocates pinned memory on the host using RMM + * @param[in] ptr Pointer to allocated host memory to free + * @param[in] bytes Size in bytes to deallocate + * @return cuvsError_t + */ +cuvsError_t cuvsRMMHostFree(void* ptr, size_t bytes); + /** @} */ #ifdef __cplusplus diff --git a/cpp/src/core/c_api.cpp b/cpp/src/core/c_api.cpp index cfbeed2d5..4333bff0c 100644 --- a/cpp/src/core/c_api.cpp +++ b/cpp/src/core/c_api.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res) @@ -130,6 +131,21 @@ extern "C" cuvsError_t cuvsRMMMemoryResourceReset() }); } +thread_local std::unique_ptr pinned_mr; + +extern "C" cuvsError_t cuvsRMMHostAlloc(void** ptr, size_t bytes) +{ + return cuvs::core::translate_exceptions([=] { + if (pinned_mr == nullptr) { pinned_mr = std::make_unique(); } + *ptr = pinned_mr->allocate(bytes); + }); +} + +extern "C" cuvsError_t cuvsRMMHostFree(void* ptr, size_t bytes) +{ + return cuvs::core::translate_exceptions([=] { pinned_mr->deallocate(ptr, bytes); }); +} + thread_local std::string last_error_text = ""; extern "C" const char* cuvsGetLastErrorText() diff --git a/cpp/test/core/c_api.c b/cpp/test/core/c_api.c index a3dae6004..a51824d2b 100644 --- a/cpp/test/core/c_api.c +++ b/cpp/test/core/c_api.c @@ -73,6 +73,15 @@ int main() error = cuvsRMMMemoryResourceReset(); if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } + // Alloc memory on host (pinned) + void* ptr3; + cuvsError_t alloc_error_pinned = cuvsRMMHostAlloc(&ptr3, 1024); + if (alloc_error_pinned == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Free memory + cuvsError_t free_error_pinned = cuvsRMMHostFree(ptr3, 1024); + if (free_error_pinned == CUVS_ERROR) { exit(EXIT_FAILURE); } + // Destroy resources error = cuvsResourcesDestroy(res); if (error == CUVS_ERROR) { exit(EXIT_FAILURE); }