From d07aa811febe5dc143feee069bd90c0b1ef86e40 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 20 Feb 2024 16:22:00 +0100 Subject: [PATCH] Add `ucxx::MemoryHandle` and `ucxx::RemoteKey` C++ classes (#190) Add `ucxx::MemoryHandle` and `ucxx::RemoteKey` C++ classes, wrapping UCP memory handles (`ucp_mem_h`) and remote keys (`ucp_rkey_h`), respectively. These new classes are required to provide basic support for RMA (Remote Memory Access) operations being implemented in #166 . Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Lawrence Mitchell (https://github.com/wence-) - Bradley Dice (https://github.com/bdice) - Robert Maynard (https://github.com/robertmaynard) URL: https://github.com/rapidsai/ucxx/pull/190 --- cpp/CMakeLists.txt | 2 + cpp/include/ucxx/api.h | 2 + cpp/include/ucxx/constructors.h | 12 ++ cpp/include/ucxx/context.h | 39 ++++++ cpp/include/ucxx/memory_handle.h | 162 +++++++++++++++++++++++ cpp/include/ucxx/remote_key.h | 212 +++++++++++++++++++++++++++++++ cpp/include/ucxx/typedefs.h | 2 + cpp/src/context.cpp | 6 + cpp/src/memory_handle.cpp | 81 ++++++++++++ cpp/src/remote_key.cpp | 141 ++++++++++++++++++++ cpp/tests/CMakeLists.txt | 1 + cpp/tests/rma.cpp | 174 +++++++++++++++++++++++++ 12 files changed, 834 insertions(+) create mode 100644 cpp/include/ucxx/memory_handle.h create mode 100644 cpp/include/ucxx/remote_key.h create mode 100644 cpp/src/memory_handle.cpp create mode 100644 cpp/src/remote_key.cpp create mode 100644 cpp/tests/rma.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index e9634c21..4f2c19a5 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -130,6 +130,8 @@ add_library( src/internal/request_am.cpp src/listener.cpp src/log.cpp + src/memory_handle.cpp + src/remote_key.cpp src/request.cpp src/request_am.cpp src/request_data.cpp diff --git a/cpp/include/ucxx/api.h b/cpp/include/ucxx/api.h index 8f8452f9..bc942ec5 100644 --- a/cpp/include/ucxx/api.h +++ 
b/cpp/include/ucxx/api.h @@ -16,6 +16,8 @@ #include #include #include +#include +#include #include #include #include diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index ab8e7d17..cc601288 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -18,7 +18,9 @@ class Context; class Endpoint; class Future; class Listener; +class MemoryHandle; class Notifier; +class RemoteKey; class Request; class RequestAm; class RequestStream; @@ -55,6 +57,16 @@ std::shared_ptr createWorker(std::shared_ptr context, const bool enableDelayedSubmission, const bool enableFuture); +std::shared_ptr createMemoryHandle(std::shared_ptr context, + const size_t size, + void* buffer = nullptr); + +std::shared_ptr createRemoteKeyFromMemoryHandle( + std::shared_ptr memoryHandle); + +std::shared_ptr createRemoteKeyFromSerialized(std::shared_ptr endpoint, + SerializedRemoteKey serializedRemoteKey); + // Transfers std::shared_ptr createRequestAm( std::shared_ptr endpoint, diff --git a/cpp/include/ucxx/context.h b/cpp/include/ucxx/context.h index 9f1c1890..c10b0c89 100644 --- a/cpp/include/ucxx/context.h +++ b/cpp/include/ucxx/context.h @@ -17,6 +17,7 @@ namespace ucxx { +class MemoryHandle; class Worker; /** @@ -178,6 +179,44 @@ class Context : public Component { */ std::shared_ptr createWorker(const bool enableDelayedSubmission = false, const bool enableFuture = false); + + /** + * @brief Create a new `std::shared_ptr`. + * + * Create a new `std::shared_ptr` as a child of the current + * `ucxx::Context`. The `ucxx::Context` will retain ownership of the underlying + * `ucxx::MemoryHandle` and will not be destroyed until all `ucxx::MemoryHandle` + * objects are destroyed first. + * + * The allocation requires a `size` and a `buffer`. The actual size of the allocation may + * be larger than requested, and can later be found calling the `getSize()` method. 
The + * `buffer` provided may be either a `nullptr`, in which case UCP will allocate a new + * memory region for it, or an already existing allocation, in which case UCP will only + * map it for RMA and it's the caller's responsibility to keep `buffer` alive until this + * object is destroyed. + * + * @code{.cpp} + * // `context` is `std::shared_ptr` + * // Allocate a 128-byte buffer with UCP. + * auto memoryHandle = context->createMemoryHandle(128, nullptr); + * + * // Map an existing 128-byte buffer with UCP. + * size_t allocationSize = 128; + * auto buffer = new uint8_t[allocationSize]; + * auto memoryHandleFromBuffer = context->createMemoryHandle( + * allocationSize * sizeof(*buffer), reinterpret_cast(buffer) + * ); + * @endcode + * + * @throws ucxx::Error if either `ucp_mem_map` or `ucp_mem_query` fail. + * + * @param[in] size the minimum size of the memory allocation + * @param[in] buffer the pointer to an existing allocation or `nullptr` to allocate a + * new memory region. + * + * @returns The `shared_ptr` object + */ + std::shared_ptr createMemoryHandle(const size_t size, void* buffer); }; } // namespace ucxx diff --git a/cpp/include/ucxx/memory_handle.h b/cpp/include/ucxx/memory_handle.h new file mode 100644 index 00000000..d38682ca --- /dev/null +++ b/cpp/include/ucxx/memory_handle.h @@ -0,0 +1,162 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once +#include +#include +#include + +#include + +#include +#include + +namespace ucxx { + +class RemoteKey; + +/** + * @brief Component holding a UCP memory handle. + * + * The UCP layer provides RMA (Remote Memory Access) to memory handles that it controls + * in form of `ucp_mem_h` object, this class encapsulates that object and provides + * methods to simplify its handling. + */ +class MemoryHandle : public Component { + private: + ucp_mem_h _handle{}; ///< The UCP handle to the memory allocation. 
+ size_t _size{0}; ///< The actual allocation size. + uint64_t _baseAddress{0}; ///< The allocation's base address. + + /** + * @brief Private constructor of `ucxx::MemoryHandle`. + * + * This is the internal implementation of `ucxx::MemoryHandle` constructor, made private + * not to be called directly. This constructor is made private to ensure all UCXX objects + * are shared pointers and the correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::Context::createMemoryHandle` + * - `ucxx::createMemoryHandle()` + * + * @throws ucxx::Error if either `ucp_mem_map` or `ucp_mem_query` fail. + * + * @param[in] context parent context where to map memory. + * @param[in] size the minimum size of the memory allocation + * @param[in] buffer the pointer to an existing allocation or `nullptr` to allocate a + * new memory region. + */ + MemoryHandle(std::shared_ptr context, const size_t size, void* buffer); + + public: + MemoryHandle() = delete; + MemoryHandle(const MemoryHandle&) = delete; + MemoryHandle& operator=(MemoryHandle const&) = delete; + MemoryHandle(MemoryHandle&& o) = delete; + MemoryHandle& operator=(MemoryHandle&& o) = delete; + + /** + * @brief Constructor for `shared_ptr`. + * + * The constructor for a `shared_ptr` object, mapping a memory buffer + * with UCP to provide RMA (Remote Memory Access) to. + * + * The allocation requires a `size` and a `buffer`. The `buffer` provided may be either + * a `nullptr`, in which case UCP will allocate a new memory region for it, or an already + * existing allocation, in which case UCP will only map it for RMA and it's the caller's + * responsibility to keep `buffer` alive until this object is destroyed. 
When the UCP + * allocates `buffer` (i.e., when the value passed is `nullptr`), the actual size of the + * allocation may be larger than requested, and can later be found calling the `getSize()` + * method, if a preallocated buffer is passed `getSize()` will return the same value + * specified for `size`. + * + * @code{.cpp} + * // `context` is `std::shared_ptr` + * // Allocate a 128-byte buffer with UCP. + * auto memoryHandle = context->createMemoryHandle(128, nullptr); + * + * // Equivalent to line above + * // auto memoryHandle = ucxx::createMemoryHandle(context, 128, nullptr); + * + * // Map an existing 128-byte buffer with UCP. + * size_t allocationSize = 128; + * auto buffer = new uint8_t[allocationSize]; + * auto memoryHandleFromBuffer = context->createMemoryHandle( + * allocationSize * sizeof(*buffer), reinterpret_cast(buffer) + * ); + * + * // Equivalent to line above + * // auto memoryHandleFromBuffer = ucxx::createMemoryHandle( + * // context, allocationSize * sizeof(*buffer), reinterpret_cast(buffer) + * // ); + * @endcode + * + * @throws ucxx::Error if either `ucp_mem_map` or `ucp_mem_query` fail. + * + * @param[in] context parent context where to map memory. + * @param[in] size the minimum size of the memory allocation + * @param[in] buffer the pointer to an existing allocation or `nullptr` to allocate a + * new memory region. + * + * @returns The `shared_ptr` object + */ + friend std::shared_ptr createMemoryHandle(std::shared_ptr context, + const size_t size, + void* buffer); + + ~MemoryHandle(); + + /** + * @brief Get the underlying `ucp_mem_h` handle. + * + * Lifetime of the `ucp_mem_h` handle is managed by the `ucxx::MemoryHandle` object and + * its ownership is non-transferrable. Once the `ucxx::MemoryHandle` is destroyed the + * memory is unmapped and the handle is not valid anymore, it is the user's responsibility + * to ensure the owner's lifetime while using the handle. 
+ * + * @code{.cpp} + * // memoryHandle is `std::shared_ptr` + * ucp_mem_h ucpMemoryHandle = memoryHandle->getHandle(); + * @endcode + * + * @returns The underlying `ucp_mem_h` handle. + */ + ucp_mem_h getHandle(); + + /** + * @brief Get the size of the memory allocation. + * + * Get the size of the memory allocation, which is at least the number of bytes specified + * with the `size` argument passed to `createMemoryHandle()`. + * + * @code{.cpp} + * // memoryHandle is `std::shared_ptr` + * auto memorySize = memoryHandle->getSize(); + * @endcode + * + * @returns The size of the memory allocation. + */ + size_t getSize() const; + + /** + * @brief Get the base address of the memory allocation. + * + * Get the base address of the memory allocation, which is going to be used as the remote + * address to put or get memory via the `ucxx::Endpoint::memPut()` or + * `ucxx::Endpoint::memGet()` methods. + * + * @code{.cpp} + * // memoryHandle is `std::shared_ptr` + * auto memoryBaseAddress = memoryHandle->getBaseAddress(); + * @endcode + * + * @returns The base address of the memory allocation. + */ + uint64_t getBaseAddress(); + + std::shared_ptr createRemoteKey(); +}; + +} // namespace ucxx diff --git a/cpp/include/ucxx/remote_key.h b/cpp/include/ucxx/remote_key.h new file mode 100644 index 00000000..4a9f8a2e --- /dev/null +++ b/cpp/include/ucxx/remote_key.h @@ -0,0 +1,212 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +typedef size_t SerializedRemoteKeyHash; + +/** + * @brief Component holding a UCP rkey (remote key). + * + * To provide RMA (Remote Memory Access) to memory handles, UCP packs their information + * in the form of `ucp_rkey_h` (remote key, or rkey for short).
This class encapsulates + * that object and provides methods to simplify its handling, both locally and remotely + * including (de-)serialization for transfers over the wire and reconstruction of the + * object on the remote process. + */ +class RemoteKey : public Component { + private: + ucp_rkey_h _remoteKey{nullptr}; ///< The unpacked remote key. + void* _packedRemoteKey{ + nullptr}; ///< The packed ucp_rkey_h key, suitable for transfer to a remote process. + size_t _packedRemoteKeySize{0}; ///< The size in bytes of the remote key. + std::vector _packedRemoteKeyVector{}; ///< The deserialized packed remote key. + uint64_t _memoryBaseAddress{0}; ///< The allocation's base address. + size_t _memorySize{0}; ///< The actual allocation size. + + /** + * @brief Private constructor of `ucxx::RemoteKey`. + * + * This is the internal implementation of `ucxx::RemoteKey` constructor from a local + * `std::shared_ptr`, made private not to be called directly. This + * constructor is made private to ensure all UCXX objects are shared pointers and the + * correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::MemoryHandle::createRemoteKey()` + * - `ucxx::createRemoteKeyFromMemoryHandle()` + * + * @param[in] memoryHandle the memory handle mapped on the local process. + */ + explicit RemoteKey(std::shared_ptr memoryHandle); + + /** + * @brief Private constructor of `ucxx::RemoteKey`. + * + * This is the internal implementation of `ucxx::RemoteKey` constructor from a remote + * `std::shared_ptr`, made private not to be called directly. This + * constructor is made private to ensure all UCXX objects are shared pointers and the + * correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::createRemoteKeyFromSerialized()` + * + * @param[in] endpoint the `std::shared_ptr` parent component. 
+ * @param[in] serializedRemoteKey the remote key that was serialized by the owner of + * the memory handle and transferred over-the-wire for + * reconstruction and remote access. + */ + RemoteKey(std::shared_ptr endpoint, SerializedRemoteKey serializedRemoteKey); + + /** + * @brief Deserialize and reconstruct the remote key. + * + * Deserialize the remote key that was serialized with `ucxx::RemoteKey::serialize()` and + * possibly transferred over-the-wire and reconstruct the object to allow remote access. + * + * @code{.cpp} + * // Private: invoked by the constructor taking a `SerializedRemoteKey`. + * deserialize(serializedRemoteKey); + * @endcode + * + * @throws std::runtime_error if checksum of the serialized object fails. + * + * @param[in] serializedHeader the serialized remote key to reconstruct this object from. + */ + void deserialize(const SerializedRemoteKey& serializedHeader); + + public: + /** + * @brief Constructor for `std::shared_ptr` from local memory handle. + * + * The constructor for a `std::shared_ptr` object from a local + * `std::shared_ptr`, mapping a local memory buffer to be made + * accessible from a remote endpoint to perform RMA (Remote Memory Access) on the memory. + * + * @code{.cpp} + * // `memoryHandle` is `std::shared_ptr` + * auto remoteKey = memoryHandle->createRemoteKey(); + * + * // Equivalent to line above + * // auto remoteKey = ucxx::createRemoteKeyFromMemoryHandle(memoryHandle); + * @endcode + * + * @throws ucxx::Error if `ucp_rkey_pack` fails. + * + * @param[in] memoryHandle the memory handle mapped on the local process. + * + * @returns The `shared_ptr` object + */ + friend std::shared_ptr createRemoteKeyFromMemoryHandle( + std::shared_ptr memoryHandle); + + /** + * @brief Constructor for `std::shared_ptr` from remote. + * + * The constructor for a `std::shared_ptr` object from a serialized + * `std::shared_ptr`, mapping a remote memory buffer to be made + * accessible via a local endpoint to perform RMA (Remote Memory Access) on the memory.
+ * + * @code{.cpp} + * // `serializedRemoteKey` is `ucxx::SerializedRemoteKey`, created on a remote worker + * // after a call to `ucxx::RemoteKey::serialize()` and transferred over-the-wire. + * auto remoteKey = ucxx::createRemoteKeyFromSerialized(endpoint, serializedRemoteKey); + * + * // NOTE: this differs from `ucxx::createRemoteKeyFromMemoryHandle(memoryHandle)`, which + * // packs a key for a local memory handle instead of unpacking a serialized one. + * @endcode + * + * @throws ucxx::Error if `ucp_ep_rkey_unpack` fails. + * + * @param[in] endpoint the `std::shared_ptr` parent component. + * @param[in] serializedRemoteKey the remote key that was serialized by the owner of + * the memory handle and transferred over-the-wire for + * reconstruction and remote access. + * + * @returns The `shared_ptr` object + */ + friend std::shared_ptr createRemoteKeyFromSerialized( + std::shared_ptr endpoint, SerializedRemoteKey serializedRemoteKey); + + ~RemoteKey(); + + /** + * @brief Get the underlying `ucp_rkey_h` handle. + * + * Lifetime of the `ucp_rkey_h` handle is managed by the `ucxx::RemoteKey` object and + * its ownership is non-transferrable. Once the `ucxx::RemoteKey` is destroyed the handle + * becomes invalid and so does the address to the remote memory handle it points to, it is + * the user's responsibility to ensure the owner's lifetime while using the handle. + * + * @code{.cpp} + * // remoteKey is `std::shared_ptr` + * auto remoteKeyHandle = remoteKey->getHandle(); + * @endcode + * + * @returns The underlying `ucp_rkey_h` handle. + */ + ucp_rkey_h getHandle(); + + /** + * @brief Get the size of the memory allocation. + * + * Get the size of the memory allocation the remote key packs, which is at least the + * number of bytes specified with the `size` argument passed to `createMemoryHandle()`. + * + * @code{.cpp} + * // remoteKey is `std::shared_ptr` + * auto remoteMemorySize = remoteKey->getSize(); + * @endcode + * + * @returns The size of the memory allocation.
+ */ + size_t getSize() const; + + /** + * @brief Get the base address of the memory allocation. + * + * Get the base address of the memory allocation the remote key packs, which is going + * to be used as the remote address to put or get memory via the + * `ucxx::Endpoint::memPut()` or `ucxx::Endpoint::memGet()` methods. + * + * @code{.cpp} + * // remoteKey is `std::shared_ptr` + * auto remoteMemoryBaseAddress = remoteKey->getBaseAddress(); + * @endcode + * + * @returns The base address of the memory allocation. + */ + uint64_t getBaseAddress(); + + /** + * @brief Serialize the remote key. + * + * Serialize the remote key to allow over-the-wire transfer and subsequent + * reconstruction of the object in the remote process. + * + * @code{.cpp} + * // remoteKey is `std::shared_ptr` + * auto serializedRemoteKey = remoteKey->serialize(); + * @endcode + * + * @returns The serialized remote key. + */ + SerializedRemoteKey serialize() const; +}; + +} // namespace ucxx diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index c4ef29cd..958103b8 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -103,4 +103,6 @@ typedef std::shared_ptr RequestCallbackUserData; */ typedef std::function(size_t)> AmAllocatorType; +typedef const std::string SerializedRemoteKey; + } // namespace ucxx diff --git a/cpp/src/context.cpp b/cpp/src/context.cpp index 9303a0d7..08493fa3 100644 --- a/cpp/src/context.cpp +++ b/cpp/src/context.cpp @@ -99,4 +99,10 @@ std::shared_ptr Context::createWorker(const bool enableDelayedSubmission return worker; } +std::shared_ptr Context::createMemoryHandle(const size_t size, void* buffer) +{ + auto context = std::dynamic_pointer_cast(shared_from_this()); + return ucxx::createMemoryHandle(context, size, buffer); +} + } // namespace ucxx diff --git a/cpp/src/memory_handle.cpp b/cpp/src/memory_handle.cpp new file mode 100644 index 00000000..684e6133 --- /dev/null +++ b/cpp/src/memory_handle.cpp @@ -0,0 +1,81 @@ +/** + * 
SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace ucxx { + +MemoryHandle::MemoryHandle(std::shared_ptr context, const size_t size, void* buffer) +{ + setParent(context); + + ucp_mem_map_params_t params; + if (buffer == nullptr) { + params = {.field_mask = UCP_MEM_MAP_PARAM_FIELD_FLAGS | UCP_MEM_MAP_PARAM_FIELD_LENGTH, + .length = size, + .flags = UCP_MEM_MAP_NONBLOCK | UCP_MEM_MAP_ALLOCATE, + .memory_type = UCS_MEMORY_TYPE_HOST}; + } else { + params = {.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH, + .address = buffer, + .length = size, + .memory_type = UCS_MEMORY_TYPE_HOST}; + } + + utils::ucsErrorThrow(ucp_mem_map(context->getHandle(), &params, &_handle)); + + ucp_mem_attr_t attr = {.field_mask = UCP_MEM_ATTR_FIELD_ADDRESS | UCP_MEM_ATTR_FIELD_LENGTH}; + + utils::ucsErrorThrow(ucp_mem_query(_handle, &attr)); + + _baseAddress = (uint64_t)attr.address; + _size = attr.length; + + ucxx_trace("MemoryHandle created: %p, UCP handle: %p, base address: 0x%lx, size: %lu", + this, + _handle, + _baseAddress, + _size); +} + +MemoryHandle::~MemoryHandle() +{ + ucp_mem_unmap(std::dynamic_pointer_cast(getParent())->getHandle(), _handle); + ucxx_trace("ucxx::MemoryHandle destroyed: %p, UCP handle: %p, base address: 0x%lx, size: %lu", + this, + _handle, + _baseAddress, + _size); +} + +std::shared_ptr createMemoryHandle(std::shared_ptr context, + const size_t size, + void* buffer) +{ + return std::shared_ptr(new MemoryHandle(context, size, buffer)); +} + +ucp_mem_h MemoryHandle::getHandle() { return _handle; } + +size_t MemoryHandle::getSize() const { return _size; } + +uint64_t MemoryHandle::getBaseAddress() { return _baseAddress; } + +std::shared_ptr MemoryHandle::createRemoteKey() +{ + return createRemoteKeyFromMemoryHandle( +
std::dynamic_pointer_cast(shared_from_this())); +} + +} // namespace ucxx diff --git a/cpp/src/remote_key.cpp b/cpp/src/remote_key.cpp new file mode 100644 index 00000000..768f49fc --- /dev/null +++ b/cpp/src/remote_key.cpp @@ -0,0 +1,141 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include +#include + +#include + +#include +#include + +namespace ucxx { + +RemoteKey::RemoteKey(std::shared_ptr memoryHandle) + : _memoryBaseAddress(memoryHandle->getBaseAddress()), _memorySize(memoryHandle->getSize()) +{ + setParent(memoryHandle); + + utils::ucsErrorThrow( + ucp_rkey_pack(std::dynamic_pointer_cast(memoryHandle->getParent())->getHandle(), + memoryHandle->getHandle(), + &_packedRemoteKey, + &_packedRemoteKeySize)); + + ucxx_trace( + "ucxx::RemoteKey created (memory handle): %p, base address: 0x%lx, size: %lu, packed remote " + "key " + "size: %lu", + this, + _memoryBaseAddress, + _memorySize, + _packedRemoteKeySize); +} + +RemoteKey::RemoteKey(std::shared_ptr endpoint, SerializedRemoteKey serializedRemoteKey) +{ + setParent(endpoint); + + deserialize(serializedRemoteKey); + + utils::ucsErrorThrow(ucp_ep_rkey_unpack(endpoint->getHandle(), _packedRemoteKey, &_remoteKey)); + + ucxx_trace( + "ucxx::RemoteKey created (deserialize): %p, UCP handle: %p, base address: 0x%lx, size: %lu, " + "packed remote key size: %lu", + this, + _remoteKey, + _memoryBaseAddress, + _memorySize, + _packedRemoteKeySize); +} + +RemoteKey::~RemoteKey() +{ + // ucxx_trace("ucxx::Endpoint destroyed: %p, UCP handle: %p", this, _originalHandle); + if (std::dynamic_pointer_cast(getParent()) != nullptr) { + // Only packed remote key if this object was created from a `MemoryHandle`, i.e., the + // buffer is local. 
+ ucp_rkey_buffer_release(_packedRemoteKey); + ucxx_trace("ucxx::RemoteKey destroyed (memory handle): %p", this); + } + if (_remoteKey != nullptr) { + // Only destroy remote key if this was created from a `SerializedRemoteKey`, i.e., the + // buffer is remote. + ucp_rkey_destroy(_remoteKey); + ucxx_trace("ucxx::RemoteKey destroyed (deserialized): %p, UCP handle: %p", this, _remoteKey); + } +} + +std::shared_ptr createRemoteKeyFromMemoryHandle( + std::shared_ptr memoryHandle) +{ + return std::shared_ptr(new RemoteKey(memoryHandle)); +} + +std::shared_ptr createRemoteKeyFromSerialized(std::shared_ptr endpoint, + SerializedRemoteKey serializedRemoteKey) +{ + return std::shared_ptr(new RemoteKey(endpoint, serializedRemoteKey)); +} + +size_t RemoteKey::getSize() const { return _memorySize; } + +uint64_t RemoteKey::getBaseAddress() { return (uint64_t)_memoryBaseAddress; } + +ucp_rkey_h RemoteKey::getHandle() { return _remoteKey; } + +SerializedRemoteKey RemoteKey::serialize() const +{ + std::stringstream ss; + + ss.write(reinterpret_cast(&_packedRemoteKeySize), sizeof(_packedRemoteKeySize)); + ss.write(reinterpret_cast(_packedRemoteKey), _packedRemoteKeySize); + ss.write(reinterpret_cast(&_memoryBaseAddress), sizeof(_memoryBaseAddress)); + ss.write(reinterpret_cast(&_memorySize), sizeof(_memorySize)); + + auto serializedString = ss.str(); + + // Hash data to provide some degree of confidence on received data.
+ std::stringstream ssHash; + std::hash hasher; + SerializedRemoteKeyHash hash = hasher(serializedString); + ssHash.write(reinterpret_cast(&hash), sizeof(hash)); + + return ssHash.str() + serializedString; +} + +void RemoteKey::deserialize(const SerializedRemoteKey& serializedRemoteKey) +{ + auto serializedRemoteKeyHash = std::string( + serializedRemoteKey.begin(), serializedRemoteKey.begin() + sizeof(SerializedRemoteKeyHash)); + auto serializedRemoteKeyData = std::string( + serializedRemoteKey.begin() + sizeof(SerializedRemoteKeyHash), serializedRemoteKey.end()); + + // Check data hash and throw if there's no match. + std::stringstream ss{serializedRemoteKeyHash}; + SerializedRemoteKeyHash expectedHash; + ss.read(reinterpret_cast(&expectedHash), sizeof(expectedHash)); + std::hash hasher; + SerializedRemoteKeyHash actualHash = hasher(serializedRemoteKeyData); + if (actualHash != expectedHash) + throw std::runtime_error("Checksum error of serialized remote key"); + + ss = std::stringstream{std::string(serializedRemoteKey.begin() + sizeof(SerializedRemoteKeyHash), + serializedRemoteKey.end())}; + + ss.read(reinterpret_cast(&_packedRemoteKeySize), sizeof(_packedRemoteKeySize)); + + // Use a vector to store data so we don't need to bother releasing it later. 
+ _packedRemoteKeyVector = std::vector(_packedRemoteKeySize); + _packedRemoteKey = _packedRemoteKeyVector.data(); + + ss.read(reinterpret_cast(_packedRemoteKey), _packedRemoteKeySize); + ss.read(reinterpret_cast(&_memoryBaseAddress), sizeof(_memoryBaseAddress)); + ss.read(reinterpret_cast(&_memorySize), sizeof(_memorySize)); +} + +} // namespace ucxx diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 78de5ada..49fb65ce 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -51,6 +51,7 @@ ConfigureTest( header.cpp listener.cpp request.cpp + rma.cpp utils.cpp worker.cpp ) diff --git a/cpp/tests/rma.cpp b/cpp/tests/rma.cpp new file mode 100644 index 00000000..f70d1e84 --- /dev/null +++ b/cpp/tests/rma.cpp @@ -0,0 +1,174 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "include/utils.h" +#include "ucxx/buffer.h" +#include "ucxx/constructors.h" +#include "ucxx/utils/ucx.h" + +namespace { + +using ::testing::Combine; +using ::testing::Values; + +class RmaTest : public ::testing::TestWithParam> { + protected: + std::shared_ptr _context{nullptr}; + std::shared_ptr _worker{nullptr}; + std::shared_ptr _ep{nullptr}; + + ucs_memory_type_t _memoryType; + size_t _messageSize; + bool _preallocateBuffer; + size_t _rndvThresh{8192}; + void* _buffer{nullptr}; + + void SetUp() + { + std::tie(_memoryType, _messageSize, _preallocateBuffer) = GetParam(); + + _context = ucxx::createContext({{"RNDV_THRESH", std::to_string(_rndvThresh)}}, + ucxx::Context::defaultFeatureFlags); + _worker = _context->createWorker(); + + _ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + + _buffer = preallocate(); + } + + void TearDown() { release(); } + + void* preallocate() + { + if (_preallocateBuffer) { + if (_memoryType == 
UCS_MEMORY_TYPE_HOST) + return malloc(_messageSize); + else + throw std::runtime_error("Unsupported memory type"); + } + return nullptr; + } + + void release() + { + if (_preallocateBuffer && _buffer != nullptr) { + if (_memoryType == UCS_MEMORY_TYPE_HOST) free(_buffer); + } + } +}; + +class BasicUcxxRmaTest : public ::testing::TestWithParam> { + protected: + std::shared_ptr _context{nullptr}; + std::shared_ptr _worker{nullptr}; + std::shared_ptr _ep{nullptr}; + + size_t _messageSize; + + void SetUp() + { + std::tie(_messageSize) = GetParam(); + + _context = ucxx::createContext({}, ucxx::Context::defaultFeatureFlags); + _worker = _context->createWorker(); + _ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + } +}; + +TEST_P(RmaTest, MemoryHandle) +{ + auto memoryHandle = _context->createMemoryHandle(_messageSize, _buffer); + ASSERT_GE(memoryHandle->getSize(), _messageSize); + if (_messageSize == 0) + ASSERT_EQ(memoryHandle->getBaseAddress(), 0); + else + ASSERT_NE(memoryHandle->getBaseAddress(), 0); + + ASSERT_NE(memoryHandle->getHandle(), nullptr); +} + +TEST_P(RmaTest, MemoryHandleUcxxNamespaceConstructor) +{ + auto memoryHandle = ucxx::createMemoryHandle(_context, _messageSize, _buffer); + ASSERT_GE(memoryHandle->getSize(), _messageSize); + if (_messageSize == 0) + ASSERT_EQ(memoryHandle->getBaseAddress(), 0); + else + ASSERT_NE(memoryHandle->getBaseAddress(), 0); + + ASSERT_NE(memoryHandle->getHandle(), nullptr); +} + +TEST_P(RmaTest, RemoteKey) +{ + auto memoryHandle = _context->createMemoryHandle(_messageSize, _buffer); + + auto remoteKey = memoryHandle->createRemoteKey(); + + ASSERT_EQ(remoteKey->getSize(), memoryHandle->getSize()); + ASSERT_EQ(remoteKey->getBaseAddress(), memoryHandle->getBaseAddress()); + ASSERT_EQ(remoteKey->getHandle(), nullptr); +} + +TEST_P(RmaTest, RemoteKeyUcxxNamespaceConstructor) +{ + auto memoryHandle = ucxx::createMemoryHandle(_context, _messageSize, _buffer); + + auto remoteKey = 
ucxx::createRemoteKeyFromMemoryHandle(memoryHandle); + + ASSERT_EQ(remoteKey->getSize(), memoryHandle->getSize()); + ASSERT_EQ(remoteKey->getBaseAddress(), memoryHandle->getBaseAddress()); + ASSERT_EQ(remoteKey->getHandle(), nullptr); +} + +TEST_P(RmaTest, RemoteKeySerialization) +{ + auto memoryHandle = _context->createMemoryHandle(_messageSize, _buffer); + + auto remoteKey = memoryHandle->createRemoteKey(); + + auto serializedRemoteKey = remoteKey->serialize(); + + auto deserializedRemoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + ASSERT_EQ(remoteKey->getSize(), deserializedRemoteKey->getSize()); + ASSERT_EQ(remoteKey->getBaseAddress(), deserializedRemoteKey->getBaseAddress()); + ASSERT_NE(deserializedRemoteKey->getHandle(), nullptr); +} + +TEST_P(BasicUcxxRmaTest, RemoteKeyCorruptedSerializedData) +{ + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr); + + auto remoteKey = memoryHandle->createRemoteKey(); + + auto serializedRemoteKey = remoteKey->serialize(); + serializedRemoteKey[1] = ~serializedRemoteKey[1]; + + EXPECT_THROW(ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey), std::runtime_error); +} + +INSTANTIATE_TEST_SUITE_P(AttributeTests, + RmaTest, + Combine(Values(UCS_MEMORY_TYPE_HOST), + Values(0, 1, 4, 4096, 8192, 4194304), + Values(false, true))); + +INSTANTIATE_TEST_SUITE_P(FailureTests, BasicUcxxRmaTest, Combine(Values(0, 4194304))); + +} // namespace