From 18e61d8ab2659f3c5461abfe3119196f6199cf38 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Fri, 26 Jul 2024 09:59:49 +0800 Subject: [PATCH 1/6] add wholememory_hierarchy_type, new gather func needed --- cpp/include/wholememory/wholememory.h | 32 +++ cpp/src/wholememory/communicator.cpp | 7 + cpp/src/wholememory/communicator.hpp | 3 + cpp/src/wholememory/embedding.cpp | 3 + cpp/src/wholememory/memory_handle.cpp | 203 +++++++++++++++++- cpp/src/wholememory/memory_handle.hpp | 6 + cpp/src/wholememory/wholememory.cpp | 16 ++ cpp/src/wholememory_ops/gather_op.cpp | 13 ++ cpp/src/wholememory_ops/gather_op_impl.h | 11 + .../gather_op_impl_hierarchy.cu | 172 +++++++++++++++ .../binding/wholememory_binding.pyx | 2 + .../pylibwholegraph/test_utils/test_comm.py | 2 + .../pylibwholegraph/torch/common_options.py | 2 +- .../pylibwholegraph/torch/embedding.py | 11 + .../pylibwholegraph/torch/utils.py | 6 +- 15 files changed, 483 insertions(+), 6 deletions(-) create mode 100644 cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index a1678ee8b..e83510244 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -63,6 +63,7 @@ enum wholememory_memory_type_t { WHOLEMEMORY_MT_CONTINUOUS, /*!< Memory from all ranks are mapped in continuous address space */ WHOLEMEMORY_MT_CHUNKED, /*!< Memory from all ranks are mapped in chunked address space */ WHOLEMEMORY_MT_DISTRIBUTED, /*!< Memory from other ranks are not mapped. */ + WHOLEMEMORY_MT_HIERARCHY, /*!< Memory from other ranks are mapped in hierarchy address space */ }; /** @@ -206,6 +207,23 @@ wholememory_error_code_t wholememory_communicator_get_rank(int* rank, wholememor */ wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememory_comm_t comm); +/** + * Get the local rank size of current process in the WholeMemory Communicator + * @param local_size : returned local rank size + * @param comm : WholeMemory Communicator + * @return : wholememory_error_code_t + */ + +wholememory_error_code_t wholememory_communicator_get_local_size(int* local_size, + wholememory_comm_t comm); + +/** + * Get the clique info of WholeMemory Communicator + * @param clique_info : returned clique info + * @param comm : WholeMemory Communicator + * @return : wholememory_error_code_t + */ + wholememory_error_code_t wholememory_communicator_get_clique_info(clique_info_t* clique_info, wholememory_comm_t comm); @@ -311,6 +329,20 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr, size_t* local_offset, wholememory_handle_t wholememory_handle); +/** + * Get local node memory from WholeMemory Handle, all gpus of the rank has direct access to the + * memory. Note that this is only available for WHOLEMEMORY_MT_HIERARCHY memory type. 
+ * @param local_ptr : returned local node memory pointer + * @param local_size : returned local node memory size + * @param local_offset : returned local node memory offset from WholeMemory + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_get_local_node_memory(void** local_ptr, + size_t* local_size, + size_t* local_offset, + wholememory_handle_t wholememory_handle); + /** * Get local memory size from WholeMemory Handle of current rank * @param local_size : returned local memory size diff --git a/cpp/src/wholememory/communicator.cpp b/cpp/src/wholememory/communicator.cpp index dabb9ba1b..f76a4c7b1 100644 --- a/cpp/src/wholememory/communicator.cpp +++ b/cpp/src/wholememory/communicator.cpp @@ -897,6 +897,13 @@ wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t com return WHOLEMEMORY_SUCCESS; } +wholememory_error_code_t communicator_get_local_size(int* local_size, + wholememory_comm_t comm) noexcept +{ + *local_size = comm->intra_node_rank_num; + return WHOLEMEMORY_SUCCESS; +} + // wholememory_error_code_t communicator_get_clique_rank(int* clique_rank, // wholememory_comm_t comm) noexcept // { diff --git a/cpp/src/wholememory/communicator.hpp b/cpp/src/wholememory/communicator.hpp index b48d66b77..709965c55 100644 --- a/cpp/src/wholememory/communicator.hpp +++ b/cpp/src/wholememory/communicator.hpp @@ -291,6 +291,9 @@ wholememory_error_code_t communicator_get_rank(int* rank, wholememory_comm_t com wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t comm) noexcept; +wholememory_error_code_t communicator_get_local_size(int* local_size, + wholememory_comm_t comm) noexcept; + wholememory_error_code_t communicator_get_clique_info(clique_info_t* clique_info, wholememory_comm_t comm) noexcept; diff --git a/cpp/src/wholememory/embedding.cpp b/cpp/src/wholememory/embedding.cpp index 23e9ccb53..7d6aae869 100644 --- a/cpp/src/wholememory/embedding.cpp +++ b/cpp/src/wholememory/embedding.cpp @@ -964,6 +964,9 @@ wholememory_error_code_t wholememory_create_embedding( int embedding_world_size = 1; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&embedding_world_size, comm)); if (cache_policy != nullptr) { + if (memory_type == WHOLEMEMORY_MT_HIERARCHY) { + WHOLEMEMORY_ERROR("Cache is not supported now in hierarchy memory type."); + } if (cache_policy->cache_comm == comm) { if (cache_policy->cache_memory_location != WHOLEMEMORY_ML_DEVICE) { WHOLEMEMORY_ERROR( diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index c8f1644e3..9fc31ded1 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -106,7 +106,7 @@ class wholememory_impl { return gref; } virtual bool contains_pointer(const void* ptr) const = 0; - void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const + virtual void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const { if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_; if (local_size != nullptr) *local_size = get_local_size(); @@ -128,7 +128,7 @@ class wholememory_impl { *rank_memory_offset = 0; return false; } - [[nodiscard]] size_t get_partition_stride() const + [[nodiscard]] virtual size_t get_partition_stride() const { return rank_partition_strategy_.partition_mem_stride; } @@ -647,6 +647,11 @@ class continuous_device_wholememory_impl : public wholememory_impl { data_granularity, 
rank_entry_partition) { + printf( + "while in continuous device wholememory creation, the memory_type (%d) and memory_location " + "(%d).\n", + (int)memory_type, + (int)memory_location); WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS); } void create_memory() override @@ -1747,6 +1752,138 @@ struct wholememory_create_param { size_t min_granularity; }; +class hierarchy_wholememory_impl : public wholememory_impl { + public: + hierarchy_wholememory_impl(wholememory_handle_t wholememory_handle, + size_t total_size, + wholememory_comm_t global_comm, + wholememory_comm_t local_comm, + wholememory_memory_type_t memory_type, + wholememory_memory_location_t memory_location, + size_t data_granularity) + : wholememory_impl( + wholememory_handle, total_size, global_comm, memory_type, memory_location, data_granularity) + { + WHOLEMEMORY_CHECK(memory_type == WHOLEMEMORY_MT_HIERARCHY); + local_comm_ = local_comm; + } + void create_memory() override + { + std::unique_lock mlock(local_comm_->mu); + local_memory_handle_ = new wholememory_handle_(); + local_memory_handle_->handle_id = negotiate_handle_id_with_comm_locked(local_comm_); + determine_node_size(); + + WM_COMM_CHECK_ALL_SAME(local_comm_, WM_MEM_OP_CREATE); + wholememory_create_param wcp(node_partition_strategy_.local_mem_size, + WHOLEMEMORY_MT_CONTINUOUS, + location_, + data_granularity_); + WM_COMM_CHECK_ALL_SAME(local_comm_, wcp); + + // TODO chunkded memory type and nvshmem type are both not supported yet. + if (is_intranode_communicator(local_comm_) || !SupportEGM()) + if (location_ == WHOLEMEMORY_ML_HOST) { + local_memory_handle_->impl = + new global_mapped_host_wholememory_impl(local_memory_handle_, + node_partition_strategy_.local_mem_size, + local_comm_, + WHOLEMEMORY_MT_CONTINUOUS, + location_, + data_granularity_); + } else if (location_ == WHOLEMEMORY_ML_DEVICE) { + local_memory_handle_->impl = + new continuous_device_wholememory_impl(local_memory_handle_, + node_partition_strategy_.local_mem_size, + local_comm_, + WHOLEMEMORY_MT_CONTINUOUS, + location_, + data_granularity_); + } else { + WHOLEMEMORY_ERROR("unsupported memory location"); + } + else { +#if CUDA_VERSION >= 12030 + local_memory_handle_->impl = + new continuous_mnnvl_wholememory_impl(local_memory_handle_, + node_partition_strategy_.local_mem_size, + local_comm_, + WHOLEMEMORY_MT_CONTINUOUS, + location_, + data_granularity_); +#else + WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINOUS is only supported on CUDA version >= 12.3"); +#endif + } + local_memory_handle_->impl->create_memory(); + local_comm_->wholememory_map.insert( + std::pair(local_memory_handle_->handle_id, local_memory_handle_)); + local_node_memory_pointer_ = local_memory_handle_->impl->get_continuous_mapping_pointer(); + } + [[nodiscard]] wholememory_gref_t get_global_reference() const noexcept override + { + wholememory_gref_t gref{}; + gref.pointer = local_node_memory_pointer_; + gref.stride = 0; + return gref; + } + void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const override + { + get_local_memory_from_handle(local_ptr, local_size, local_offset, local_memory_handle_); + *local_offset += node_partition_strategy_.local_mem_offset; + return; + } + void get_local_node_memory(void** local_node_ptr, + size_t* local_node_size, + size_t* local_node_offset) + { + *local_node_ptr = local_node_memory_pointer_; + *local_node_size = node_partition_strategy_.local_mem_size; + *local_node_offset = node_partition_strategy_.local_mem_offset; + } + [[nodiscard]] size_t 
get_partition_stride() const override + { + return local_memory_handle_->impl->get_partition_stride(); + } + [[nodiscard]] wholememory_comm_t get_local_comm() const { return local_comm_; } + void destroy_memory() noexcept override { destroy_wholememory(local_memory_handle_); } + bool contains_pointer(const void* ptr) const override + { + uint64_t int_ptr = reinterpret_cast(ptr); + uint64_t int_start_ptr = reinterpret_cast(local_node_memory_pointer_); + return int_ptr >= int_start_ptr && + int_ptr < int_start_ptr + node_partition_strategy_.local_mem_size; + } + + protected: + void determine_node_size() + { + size_t node_num = comm_->world_size / local_comm_->world_size; + size_t node_id = comm_->world_rank / local_comm_->world_size; + size_t data_slot_count = total_size_ / data_granularity_; + size_t data_slot_per_rank = determine_entry_partition_plan(data_slot_count, comm_->world_size); + size_t data_slot_per_node = data_slot_per_rank * local_comm_->world_size; + size_t node_data_slot_start = std::min(node_id * data_slot_per_node, data_slot_count); + size_t node_data_slot_end = std::min((node_id + 1) * data_slot_per_node, data_slot_count); + size_t node_data_slot_count = node_data_slot_end - node_data_slot_start; + + node_partition_strategy_.local_mem_size = node_data_slot_count * data_granularity_; + node_partition_strategy_.local_mem_offset = node_data_slot_start * data_granularity_; + node_partition_strategy_.partition_mem_stride = data_slot_per_node * data_granularity_; + } + + wholememory_handle_t local_memory_handle_; + wholememory_comm_t local_comm_; + void* local_node_memory_pointer_; + struct partition_strategy { + // size of memory this rank is responsible for + size_t local_mem_size = 0; + // start location of the memory this rank is responsible for + size_t local_mem_offset = 0; + size_t partition_mem_stride = 0; + } node_partition_strategy_; +}; + wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_handle_ptr, size_t total_size, wholememory_comm_t comm, @@ -1802,7 +1939,15 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha data_granularity, rank_entry_partition); } - } else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS) { + } else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS || + (memory_type == WHOLEMEMORY_MT_HIERARCHY && is_intranode_communicator(comm)) || + (memory_type == WHOLEMEMORY_MT_HIERARCHY && is_intra_mnnvl_communicator(comm))) { + if (memory_type == WHOLEMEMORY_MT_HIERARCHY) { + WHOLEMEMORY_WARN( + "intra-node or intra-mnnvl HIERARCHY memory type is implemented as CONTINUOUS memory " + "type"); + memory_type = WHOLEMEMORY_MT_CONTINUOUS; + } if (is_intranode_communicator(comm) || !SupportEGM()) { if (memory_location == WHOLEMEMORY_ML_HOST) { whole_memory_handle->impl = new global_mapped_host_wholememory_impl(whole_memory_handle, @@ -1853,6 +1998,39 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha data_granularity, rank_entry_partition); } + } else if (memory_type == WHOLEMEMORY_MT_HIERARCHY) { + wholememory_comm_t local_comm; + if (SupportEGM() && is_intra_mnnvl_communicator(comm)) { +#if CUDA_VERSION >= 12030 + clique_info_t* clique_info = nullptr; + wholememory_communicator_get_clique_info(clique_info, comm); + WHOLEMEMORY_CHECK_NOTHROW(clique_info->is_in_clique); + wholememory_split_communicator( + &local_comm, comm, clique_info->clique_id, clique_info->clique_rank); + whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle, + total_size, + comm, + 
local_comm, + memory_type, + memory_location, + data_granularity); +#else + WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3"); +#endif + } else { + int world_rank = -1, local_size = -1; + wholememory_communicator_get_size(&world_rank, comm); + wholememory_communicator_get_local_size(&local_size, comm); + wholememory_split_communicator( + &local_comm, comm, world_rank / local_size, world_rank % local_size); + whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle, + total_size, + comm, + local_comm, + memory_type, + memory_location, + data_granularity); + } } else { WHOLEMEMORY_FATAL("Unsupported memory_type (%d) and memory_location (%d).", (int)memory_type, @@ -1967,6 +2145,25 @@ wholememory_error_code_t get_local_memory_from_handle( return WHOLEMEMORY_SUCCESS; } +wholememory_error_code_t get_local_node_memory_from_handle( + void** local_ptr, + size_t* local_size, + size_t* local_offset, + wholememory_handle_t wholememory_handle) noexcept +{ + if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) { + WHOLEMEMORY_ERROR("Only Hierarchy memory type support get_local_node_memory function."); + return WHOLEMEMORY_INVALID_INPUT; + } + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + hierarchy_wholememory_impl* hierarchy_impl = + dynamic_cast(wholememory_handle->impl); + hierarchy_impl->get_local_node_memory(local_ptr, local_size, local_offset); + return WHOLEMEMORY_SUCCESS; +} + wholememory_error_code_t get_rank_memory_from_handle( void** rank_memory_ptr, size_t* rank_memory_size, diff --git a/cpp/src/wholememory/memory_handle.hpp b/cpp/src/wholememory/memory_handle.hpp index c16e5bc03..9e3ac8334 100644 --- a/cpp/src/wholememory/memory_handle.hpp +++ b/cpp/src/wholememory/memory_handle.hpp @@ -65,6 +65,12 @@ wholememory_error_code_t get_local_memory_from_handle( size_t* local_offset, wholememory_handle_t wholememory_handle) noexcept; +wholememory_error_code_t get_local_node_memory_from_handle( + void** local_ptr, + size_t* local_size, + size_t* local_offset, + wholememory_handle_t wholememory_handle) noexcept; + wholememory_error_code_t get_rank_memory_from_handle( void** rank_memory_ptr, size_t* rank_memory_size, diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index 600906889..2807b1582 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -75,6 +75,13 @@ wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememor { return wholememory::communicator_get_size(size, comm); } + +wholememory_error_code_t wholememory_communicator_get_local_size(int* local_size, + wholememory_comm_t comm) +{ + return wholememory::communicator_get_local_size(local_size, comm); +} + bool wholememory_communicator_is_bind_to_nvshmem(wholememory_comm_t comm) { #ifdef WITH_NVSHMEM_SUPPORT @@ -166,6 +173,15 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr, local_ptr, local_size, local_offset, wholememory_handle); } +wholememory_error_code_t wholememory_get_local_node_memory(void** local_ptr, + size_t* local_size, + size_t* local_offset, + wholememory_handle_t wholememory_handle) +{ + return wholememory::get_local_node_memory_from_handle( + local_ptr, local_size, local_offset, wholememory_handle); +} + wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, size_t* rank_memory_size, size_t* rank_memory_offset, diff --git 
a/cpp/src/wholememory_ops/gather_op.cpp b/cpp/src/wholememory_ops/gather_op.cpp index 98d41d222..944407143 100644 --- a/cpp/src/wholememory_ops/gather_op.cpp +++ b/cpp/src/wholememory_ops/gather_op.cpp @@ -93,6 +93,19 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten gather_sms); } + if (has_handle && memory_type == WHOLEMEMORY_MT_HIERARCHY) { + return wholememory_ops::wholememory_gather_hierarchy( + wholememory_tensor_get_memory_handle(wholememory_tensor), + matrix_description, + indices, + indices_desc, + output, + output_desc, + p_env_fns, + static_cast(stream), + gather_sms); + } + WHOLEMEMORY_EXPECTS_NOTHROW(!has_handle || memory_type == WHOLEMEMORY_MT_CHUNKED || memory_type == WHOLEMEMORY_MT_CONTINUOUS, "Memory type not supported."); diff --git a/cpp/src/wholememory_ops/gather_op_impl.h b/cpp/src/wholememory_ops/gather_op_impl.h index 21896ff24..19f3c08b3 100644 --- a/cpp/src/wholememory_ops/gather_op_impl.h +++ b/cpp/src/wholememory_ops/gather_op_impl.h @@ -42,6 +42,17 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor cudaStream_t stream, int gather_sms); +wholememory_error_code_t wholememory_gather_hierarchy( + wholememory_handle_t wholememory_handle, + wholememory_matrix_description_t wholememory_desc, + void* indices, + wholememory_array_description_t indice_desc, + void* output, + wholememory_matrix_description_t output_desc, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream, + int gather_sms); + wholememory_error_code_t wholememory_gather_distributed( wholememory_handle_t wholememory_handle, wholememory_matrix_description_t wholememory_desc, diff --git a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu new file mode 100644 index 000000000..2e49bcd02 --- /dev/null +++ b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +#include +#include + +#include "logger.hpp" +#include "wholememory/communicator.hpp" +#include "wholememory/memory_handle.hpp" +#include "wholememory_ops/functions/bucket_ids_func.h" +#include "wholememory_ops/functions/exchange_embeddings_nccl_func.h" +#include "wholememory_ops/functions/exchange_ids_nccl_func.h" +#include "wholememory_ops/functions/gather_scatter_func.h" +#include "wholememory_ops/gather_op_impl.h" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" + +namespace wholememory_ops { + +wholememory_error_code_t wholememory_gather_hierarchy( + wholememory_handle_t wholememory_handle, + wholememory_matrix_description_t wholememory_desc, + void* indices, + wholememory_array_description_t indice_desc, + void* output, + wholememory_matrix_description_t output_desc, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream, + int gather_sms) +{ + try { + if (wholememory_desc.storage_offset < 0 || + wholememory_desc.storage_offset + wholememory_desc.sizes[1] > wholememory_desc.stride) { + return WHOLEMEMORY_INVALID_INPUT; + } + + wm_thrust_allocator thrust_allocator(p_env_fns); + + size_t embedding_size_per_rank; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); + + size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); + size_t embedding_entry_size = element_size * wholememory_desc.stride; + + WHOLEMEMORY_EXPECTS_NOTHROW( + embedding_size_per_rank % embedding_entry_size == 0, + "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", + embedding_size_per_rank, + element_size, + wholememory_desc.stride); + + size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; + + wholememory_comm_t wm_comm; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); + + int world_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_comm)); + + temp_memory_handle host_rank_id_count(p_env_fns), host_recv_rank_id_count(p_env_fns); + int64_t* host_rank_id_count_ptr = + static_cast(host_rank_id_count.host_malloc(world_size, WHOLEMEMORY_DT_INT64)); + int64_t* host_recv_rank_id_count_ptr = + static_cast(host_recv_rank_id_count.host_malloc(world_size, WHOLEMEMORY_DT_INT64)); + + temp_memory_handle dev_recv_indice_buffer(p_env_fns); + temp_memory_handle dev_raw_indice(p_env_fns); + int64_t* dev_raw_indice_ptr = + static_cast(dev_raw_indice.device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT64)); + + int64_t total_recv_count = 0; + WHOLEMEMORY_RETURN_ON_FAIL(bucket_and_exchange_ids_func(indices, + indice_desc, + host_recv_rank_id_count_ptr, + host_rank_id_count_ptr, + &dev_recv_indice_buffer, + dev_raw_indice_ptr, + embedding_entry_count_per_rank, + wm_comm, + &thrust_allocator, + p_env_fns, + stream)); + + // Local Gather + for (int i = 0; i < world_size; i++) { + total_recv_count += host_recv_rank_id_count_ptr[i]; + } + size_t local_mem_offset, local_mem_size; + temp_memory_handle dev_local_gather_buffer(p_env_fns); + temp_memory_handle dev_embedding_recv_buffer(p_env_fns); + void* dev_local_gather_buffer_ptr = dev_local_gather_buffer.device_malloc( + wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); + void* dev_embedding_recv_buffer_ptr = dev_embedding_recv_buffer.device_malloc( + wholememory_desc.sizes[1] * indice_desc.size, output_desc.dtype); + void* local_fake_ptr = nullptr; + 
WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_local_memory( + &local_fake_ptr, &local_mem_size, &local_mem_offset, wholememory_handle)); + local_fake_ptr = static_cast(local_fake_ptr) - local_mem_offset; + wholememory_gref_t local_fake_gref = + wholememory_create_continuous_global_reference(local_fake_ptr); + int64_t local_buffer_size[2] = {total_recv_count, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t local_gather_buffer_desc = wholememory_create_matrix_desc( + local_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + auto dev_recv_indice_desc = + wholememory_create_array_desc(total_recv_count, 0, indice_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(local_fake_gref, + wholememory_desc, + dev_recv_indice_buffer.pointer(), + dev_recv_indice_desc, + dev_local_gather_buffer_ptr, + local_gather_buffer_desc, + stream, + gather_sms)); + // AllToAllV for embeddings + size_t embedding_size = + wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_local_gather_buffer_ptr, + host_recv_rank_id_count_ptr, + host_rank_id_count_ptr, + dev_embedding_recv_buffer_ptr, + embedding_size, + wm_comm, + stream)); + // Local reorder + int64_t total_need_indice_count = 0; + for (int i = 0; i < world_size; i++) { + total_need_indice_count += host_rank_id_count_ptr[i]; + } + wholememory_gref_t output_gref = wholememory_create_continuous_global_reference(output); + wholememory_matrix_description_t local_recv_buffer_desc = + wholememory_create_matrix_desc(output_desc.sizes, output_desc.sizes[1], 0, output_desc.dtype); + local_recv_buffer_desc.sizes[0] = total_need_indice_count; + auto raw_indice_desc = + wholememory_create_array_desc(total_need_indice_count, 0, WHOLEMEMORY_DT_INT64); + WHOLEMEMORY_RETURN_ON_FAIL(scatter_func(dev_embedding_recv_buffer_ptr, + local_recv_buffer_desc, + dev_raw_indice_ptr, + raw_indice_desc, + output_gref, + output_desc, + stream)); + WM_CUDA_CHECK(cudaGetLastError()); + // WM_CUDA_CHECK(cudaStreamSynchronize(stream)); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("CUDA logic Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } catch (wholememory::logic_error& wle) { + WHOLEMEMORY_ERROR("LOGIC Error %s\n", wle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) 
{ + return WHOLEMEMORY_UNKNOW_ERROR; + } + + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index 61039d83c..a16a09257 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -61,6 +61,7 @@ cdef extern from "wholememory/wholememory.h": WHOLEMEMORY_MT_CONTINUOUS "WHOLEMEMORY_MT_CONTINUOUS" WHOLEMEMORY_MT_CHUNKED "WHOLEMEMORY_MT_CHUNKED" WHOLEMEMORY_MT_DISTRIBUTED "WHOLEMEMORY_MT_DISTRIBUTED" + WHOLEMEMORY_MT_HIERARCHY "WHOLEMEMORY_MT_HIERARCHY" ctypedef enum wholememory_memory_location_t: WHOLEMEMORY_ML_NONE "WHOLEMEMORY_ML_NONE" @@ -226,6 +227,7 @@ cpdef enum WholeMemoryMemoryType: MtContinuous = WHOLEMEMORY_MT_CONTINUOUS MtChunked = WHOLEMEMORY_MT_CHUNKED MtDistributed = WHOLEMEMORY_MT_DISTRIBUTED + MtHierarchy = WHOLEMEMORY_MT_HIERARCHY cpdef enum WholeMemoryMemoryLocation: MlNone = WHOLEMEMORY_ML_NONE diff --git a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py index f9f87f721..273440cf7 100644 --- a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py +++ b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py @@ -185,6 +185,8 @@ def int_to_wholememory_type(value: int): return wmb.WholeMemoryMemoryType.MtChunked if value == 2: return wmb.WholeMemoryMemoryType.MtDistributed + if value == 3: + return wmb.WholeMemoryMemoryType.MtHierarchy else: raise ValueError("invalid int_to_wholememory_type value") diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py index ebfe7dfb6..1d6f371df 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py @@ -33,7 +33,7 @@ def add_training_options(argparser: ArgumentParser): "--embedding-memory-type", dest="embedding_memory_type", default="chunked", - help="Embedding memory type, should be: continuous, chunked or distributed", + help="Embedding memory type, should be: continuous, chunked, distributed, hierarchy", ) argparser.add_argument( "--cache-type", diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py index 825c8cbaa..d242ca2fd 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py @@ -145,6 +145,7 @@ def create_builtin_cache_policy( embedding_memory_type != "continuous" and embedding_memory_type != "chunked" and embedding_memory_type != "distributed" + and embedding_memory_type != "hierarchy" ): raise ValueError(f"embedding_memory_type={embedding_memory_type} is not valid") @@ -425,6 +426,16 @@ def create_embedding( if embedding_entry_partition is not None and round_robin_size != 0: print("round_robin_size is ignored because embedding_entry_partition is specified") round_robin_size = 0 + if memory_type == 'hierarchy': # todo: modified + comm_backend = comm.distributed_backend + if comm_backend == 'nvshmem': + raise AssertionError + ("Hierarchy embedding is not supported yet when using NVSHMEM.") + if cache_policy is not None: + raise AssertionError + ("Hierarchy embedding is not supported yet when using cache.") + comm_backend = 'nccl' + wm_embedding 
= WholeMemoryEmbedding( wmb.create_embedding( tensor_desc, diff --git a/python/pylibwholegraph/pylibwholegraph/torch/utils.py b/python/pylibwholegraph/pylibwholegraph/torch/utils.py index c03c2f061..b63e7f235 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/utils.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -92,9 +92,11 @@ def str_to_wmb_wholememory_memory_type(str_wmb_type: str): return wmb.WholeMemoryMemoryType.MtChunked elif str_wmb_type == "distributed": return wmb.WholeMemoryMemoryType.MtDistributed + elif str_wmb_type == "hierarchy": + return wmb.WholeMemoryMemoryType.MtHierarchy else: raise ValueError( - "WholeMemory type %s not supported, should be (continuous, chunked, distributed)" + "WholeMemory type %s not supported, should be (continuous, chunked, distributed, hierarchy)" % (str_wmb_type,) ) From 1d61a49ba04719126bf74533d9d3aebccc7d49fc Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Mon, 5 Aug 2024 15:48:10 +0800 Subject: [PATCH 2/6] support obtaining local_comm and cross_comm for hierarchy memory type --- cpp/include/wholememory/wholememory.h | 19 +++++++ cpp/src/wholememory/memory_handle.cpp | 49 +++++++++++++++++++ cpp/src/wholememory/memory_handle.hpp | 6 +++ cpp/src/wholememory/wholememory.cpp | 12 +++++ .../binding/wholememory_binding.pyx | 16 ++++++ 5 files changed, 102 insertions(+) diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index e83510244..7cf8f6908 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -283,6 +283,25 @@ wholememory_error_code_t wholememory_free(wholememory_handle_t wholememory_handl wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t* comm, wholememory_handle_t wholememory_handle); +/** + * Get underlying Wholememory Local Communicator for "Hierarchy" memory type from WholeMemory Handle + * @param comm : returned Local WholeMemory Communicator + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_get_local_communicator( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle); + +/** + * Get underlying Wholememory Cross Communicator for "Hierarchy" memory type from WholeMemory Handle + * One comminicator includes all rank with a same local id from different nodes + * @param comm : returned Local WholeMemory Communicator + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_get_cross_communicator( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle); + /** * Get WholeMemory Type * @param wholememory_handle : WholeMemory Handle diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index 9fc31ded1..d9ededa87 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -1766,6 +1766,23 @@ class hierarchy_wholememory_impl : public wholememory_impl { { WHOLEMEMORY_CHECK(memory_type == WHOLEMEMORY_MT_HIERARCHY); local_comm_ = local_comm; + if (SupportEGM() && is_intra_mnnvl_communicator(global_comm)) { +#if CUDA_VERSION >= 12030 + clique_info_t* clique_info = nullptr; + 
wholememory_communicator_get_clique_info(clique_info, global_comm); + WHOLEMEMORY_CHECK_NOTHROW(clique_info->is_in_clique); + wholememory_split_communicator( + &cross_comm_, global_comm, clique_info->clique_rank, clique_info->clique_id); +#else + WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3"); +#endif + } else { + int world_rank = -1, local_size = -1; + wholememory_communicator_get_size(&world_rank, global_comm); + wholememory_communicator_get_local_size(&local_size, global_comm); + wholememory_split_communicator( + &cross_comm_, global_comm, world_rank % local_size, world_rank / local_size); + } } void create_memory() override { @@ -1846,6 +1863,7 @@ class hierarchy_wholememory_impl : public wholememory_impl { return local_memory_handle_->impl->get_partition_stride(); } [[nodiscard]] wholememory_comm_t get_local_comm() const { return local_comm_; } + [[nodiscard]] wholememory_comm_t get_cross_comm() const { return cross_comm_; } void destroy_memory() noexcept override { destroy_wholememory(local_memory_handle_); } bool contains_pointer(const void* ptr) const override { @@ -1874,6 +1892,7 @@ class hierarchy_wholememory_impl : public wholememory_impl { wholememory_handle_t local_memory_handle_; wholememory_comm_t local_comm_; + wholememory_comm_t cross_comm_; void* local_node_memory_pointer_; struct partition_strategy { // size of memory this rank is responsible for @@ -2106,6 +2125,36 @@ wholememory_error_code_t get_communicator_from_handle( return WHOLEMEMORY_SUCCESS; } +wholememory_error_code_t get_local_communicator_from_handle( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept +{ + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) { + return WHOLEMEMORY_NOT_SUPPORTED; + } + hierarchy_wholememory_impl* hierarchy_impl = + dynamic_cast(wholememory_handle->impl); + *comm = hierarchy_impl->get_local_comm(); + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t get_cross_communicator_from_handle( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept +{ + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) { + return WHOLEMEMORY_NOT_SUPPORTED; + } + hierarchy_wholememory_impl* hierarchy_impl = + dynamic_cast(wholememory_handle->impl); + *comm = hierarchy_impl->get_cross_comm(); + return WHOLEMEMORY_SUCCESS; +} + wholememory_memory_type_t get_memory_type(wholememory_handle_t wholememory_handle) noexcept { return wholememory_handle->impl->get_type(); diff --git a/cpp/src/wholememory/memory_handle.hpp b/cpp/src/wholememory/memory_handle.hpp index 9e3ac8334..a5ef21165 100644 --- a/cpp/src/wholememory/memory_handle.hpp +++ b/cpp/src/wholememory/memory_handle.hpp @@ -51,6 +51,12 @@ wholememory_error_code_t destroy_wholememory(wholememory_handle_t wholememory_ha wholememory_error_code_t get_communicator_from_handle( wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept; +wholememory_error_code_t get_local_communicator_from_handle( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept; + +wholememory_error_code_t get_cross_communicator_from_handle( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept; + wholememory_memory_type_t 
get_memory_type(wholememory_handle_t wholememory_handle) noexcept; wholememory_memory_location_t get_memory_location(wholememory_handle_t wholememory_handle) noexcept; diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index 2807b1582..47da7c735 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -137,6 +137,18 @@ wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t* comm, return wholememory::get_communicator_from_handle(comm, wholememory_handle); } +wholememory_error_code_t wholememory_get_local_communicator(wholememory_comm_t* comm, + wholememory_handle_t wholememory_handle) +{ + return wholememory::get_local_communicator_from_handle(comm, wholememory_handle); +} + +wholememory_error_code_t wholememory_get_cross_communicator(wholememory_comm_t* comm, + wholememory_handle_t wholememory_handle) +{ + return wholememory::get_cross_communicator_from_handle(comm, wholememory_handle); +} + wholememory_memory_type_t wholememory_get_memory_type(wholememory_handle_t wholememory_handle) { return wholememory::get_memory_type(wholememory_handle); diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index a16a09257..5d8de9ce7 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -131,6 +131,12 @@ cdef extern from "wholememory/wholememory.h": cdef wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t * comm, wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_get_local_communicator(wholememory_comm_t * comm, + wholememory_handle_t wholememory_handle) + + cdef wholememory_error_code_t wholememory_get_cross_communicator(wholememory_comm_t * comm, + wholememory_handle_t wholememory_handle) + cdef wholememory_memory_type_t wholememory_get_memory_type(wholememory_handle_t wholememory_handle) cdef wholememory_memory_location_t wholememory_get_memory_location(wholememory_handle_t wholememory_handle) @@ -1344,6 +1350,16 @@ cdef class PyWholeMemoryHandle: check_wholememory_error_code(wholememory_get_communicator(&py_comm.comm_id, self.wholememory_handle)) return py_comm + def get_local_communicator(self): + py_comm = PyWholeMemoryComm() + check_wholememory_error_code(wholememory_get_local_communicator(&py_comm.comm_id, self.wholememory_handle)) + return py_comm + + def get_cross_communicator(self): + py_comm = PyWholeMemoryComm() + check_wholememory_error_code(wholememory_get_cross_communicator(&py_comm.comm_id, self.wholememory_handle)) + return py_comm + def get_memory_type(self): return WholeMemoryMemoryType(wholememory_get_memory_type(self.wholememory_handle)) From 76802891166289c28b49e9ab60ab8dda534d6d30 Mon Sep 17 00:00:00 2001 From: Zhuofan Li Date: Mon, 19 Aug 2024 03:23:11 +0000 Subject: [PATCH 3/6] implement hierarchy gather function --- .../bucket_ids_for_hierarchy_func.cu | 386 ++++++++++++++++++ .../functions/bucket_ids_for_hierarchy_func.h | 49 +++ .../sort_unique_ids_for_hierarchy_func.cu | 145 +++++++ .../sort_unique_ids_for_hierarchy_func.h | 35 ++ .../functions/sort_unique_indices_func.cu | 118 ++++++ .../functions/sort_unique_indices_func.h | 37 ++ .../gather_op_impl_hierarchy.cu | 322 +++++++++++---- 7 files changed, 1009 insertions(+), 83 deletions(-) create mode 100644 
cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu create mode 100644 cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h create mode 100644 cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu create mode 100644 cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h create mode 100644 cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu create mode 100644 cpp/src/wholememory_ops/functions/sort_unique_indices_func.h diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu new file mode 100644 index 000000000..6ad83a49a --- /dev/null +++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu @@ -0,0 +1,386 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include +#include +#include +#include + +#include + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "logger.hpp" +#include "wholememory/integer_utils.hpp" +#include "wholememory_ops/register.hpp" +#include "wholememory_ops/temp_memory_handle.hpp" +#include + +namespace wholememory_ops { + +template +__global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices, + size_t indice_count, + int64_t* dev_rank_id_count_ptr, + size_t embedding_entry_count_per_rank, + int local_size, + int bucket_size) +{ + extern __shared__ int rank_count_shared[]; + for (int idx = threadIdx.x; idx < bucket_size; idx += blockDim.x) { + rank_count_shared[idx] = 0; + } + __syncthreads(); + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count; + idx += blockDim.x * gridDim.x) { + IndexT node_idx = indices[idx]; + if (node_idx < 0) continue; + int rank = node_idx / embedding_entry_count_per_rank; + int bucket = 0; + if (CROSS_OR_LOCAL == 0) // bucket cross ranks + bucket = rank % local_size; + else // bucket local ranks + bucket = rank / local_size; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + atomicAdd_block(&rank_count_shared[bucket], 1); +#else + atomicAdd(&rank_count_shared[bucket], 1); +#endif + } + __syncthreads(); + for (int idx = threadIdx.x; idx < bucket_size; idx += blockDim.x) { + atomicAdd(reinterpret_cast(dev_rank_id_count_ptr) + idx, + static_cast(rank_count_shared[idx])); + } +} + +template +void bucket_ids_for_hierarchy_temp_func(const void* indices, + wholememory_array_description_t indice_desc, + int64_t* dev_rank_id_count_ptr, + size_t embedding_entry_count_per_rank, + int local_size, + int cross_size, + int bucket_cross_or_local, + int sm_count, + cudaStream_t stream) +{ + static constexpr int BLOCK_SIZE = 128; + int block_count = wholememory::div_rounding_up_unsafe(indice_desc.size, BLOCK_SIZE); + block_count = std::min(block_count, sm_count * 4); + const IndexT* indices_ptr = static_cast(indices); + indices_ptr += indice_desc.storage_offset; + + if (bucket_cross_or_local == 0) { + int bucket_size = local_size; + 
cudaMemsetAsync(dev_rank_id_count_ptr, 0, sizeof(int64_t) * bucket_size, stream); + bucket_ids_for_hierarchy_kernel + <<>>( + indices_ptr, + indice_desc.size, + dev_rank_id_count_ptr, + embedding_entry_count_per_rank, + local_size, + bucket_size); + } else { + int bucket_size = cross_size; + cudaMemsetAsync(dev_rank_id_count_ptr, 0, sizeof(int64_t) * bucket_size, stream); + bucket_ids_for_hierarchy_kernel + <<>>( + indices_ptr, + indice_desc.size, + dev_rank_id_count_ptr, + embedding_entry_count_per_rank, + local_size, + bucket_size); + } +} + +REGISTER_DISPATCH_ONE_TYPE(BucketIdsForHierarchy, bucket_ids_for_hierarchy_temp_func, SINT3264) + +template +__global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, + size_t indice_count, + IndexT* dev_bucket_indices, + IndexT* dev_indice_map, + const int64_t* dev_rank_id_offset_ptr, + size_t embedding_entry_count_per_rank, + int local_size, + int64_t* dev_bucket_atomic_add_ptr) +{ + int nbucket = local_size; + constexpr size_t shared_mem_size = 24576; + __shared__ char shared_mem[shared_mem_size]; + int* block_bucket_count_shared = reinterpret_cast(shared_mem); + int* block_bucket_atomic_add_shared = reinterpret_cast(shared_mem) + nbucket; + IndexT* block_bucket_offset_shared = + reinterpret_cast(shared_mem + 2 * sizeof(int) * nbucket); + IndexT* global_bucket_offset_shared = block_bucket_offset_shared + nbucket; + size_t buffer_size = + (shared_mem_size - nbucket * 2 * (sizeof(IndexT) + sizeof(int))) / sizeof(IndexT) / 2; + buffer_size = (buffer_size / blockDim.x) * blockDim.x; + assert(buffer_size > 0); + + IndexT* buffer_load = global_bucket_offset_shared + nbucket; + IndexT* buffer_store = buffer_load + buffer_size; + + int warp_idx = threadIdx.x / warpSize; + int lane_idx = threadIdx.x % warpSize; + int nwarp = blockDim.x / warpSize; + for (IndexT load_offset = buffer_size * blockIdx.x; load_offset < indice_count; + load_offset += gridDim.x * buffer_size) { + for (int i = threadIdx.x; i < nbucket; i += blockDim.x) { + block_bucket_count_shared[i] = 0; + block_bucket_atomic_add_shared[i] = 0; + } + __syncthreads(); + for (IndexT i = threadIdx.x; i < buffer_size; i += blockDim.x) { + IndexT load_idx = i + load_offset; + if (load_idx >= indice_count) break; + IndexT indice = indices[load_idx]; + + buffer_load[i] = indice; + int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + atomicAdd_block(&block_bucket_count_shared[bucket_idx], 1); +#else + atomicAdd(&block_bucket_count_shared[bucket_idx], 1); +#endif + } + __syncthreads(); + if (threadIdx.x == blockDim.x - 1) { + IndexT bucket_offset_tmp = 0; + for (int bi = 0; bi < nbucket; bi++) { + block_bucket_offset_shared[bi] = bucket_offset_tmp; + bucket_offset_tmp += block_bucket_count_shared[bi]; + } + } + if (threadIdx.x < nbucket) { + int bucket_idx = threadIdx.x; + global_bucket_offset_shared[bucket_idx] = + atomicAdd(reinterpret_cast(dev_bucket_atomic_add_ptr) + bucket_idx, + block_bucket_count_shared[bucket_idx]); + } + __syncthreads(); + for (IndexT i = threadIdx.x; i < buffer_size; i += blockDim.x) { + IndexT indice = buffer_load[i]; + IndexT load_idx = i + load_offset; + if (load_idx >= indice_count) break; + int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + int block_bucket_inc = atomicAdd_block(&block_bucket_atomic_add_shared[bucket_idx], 1); +#else + int block_bucket_inc = 
atomicAdd(&block_bucket_atomic_add_shared[bucket_idx], 1); +#endif + buffer_store[block_bucket_offset_shared[bucket_idx] + block_bucket_inc] = indice; + dev_indice_map[load_idx] = dev_rank_id_offset_ptr[bucket_idx] + + global_bucket_offset_shared[bucket_idx] + block_bucket_inc; + } + __syncthreads(); + for (int bucket_idx = warp_idx; bucket_idx < nbucket; bucket_idx += nwarp) { + int bucket_length = block_bucket_count_shared[bucket_idx]; + IndexT global_bucket_offset = + dev_rank_id_offset_ptr[bucket_idx] + global_bucket_offset_shared[bucket_idx]; + for (int idx = lane_idx; idx < bucket_length; idx += warpSize) { + dev_bucket_indices[global_bucket_offset + idx] = + buffer_store[block_bucket_offset_shared[bucket_idx] + idx]; + } + } + __syncthreads(); + } +} + +template +void reorder_ids_for_hierarchy_temp_func(const void* indices, + wholememory_array_description_t indice_desc, + void* dev_bucket_indices, + void* dev_indice_map, + const int64_t* dev_rank_id_count_ptr, + size_t embedding_entry_count_per_rank, + int local_size, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + int sm_count, + cudaStream_t stream) +{ + WHOLEMEMORY_CHECK(indice_desc.storage_offset == 0); + WHOLEMEMORY_CHECK(indice_desc.dtype == WHOLEMEMORY_DT_INT || + indice_desc.dtype == WHOLEMEMORY_DT_INT64); + + temp_memory_handle dev_rank_id_offset_handle(p_env_fns); + int64_t* dev_rank_id_offset_ptr = static_cast( + dev_rank_id_offset_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64)); + void* cub_temp_storage = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceScan::ExclusiveSum(cub_temp_storage, + temp_storage_bytes, + dev_rank_id_count_ptr, + dev_rank_id_offset_ptr, + local_size, + stream); + cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes); + cub::DeviceScan::ExclusiveSum(cub_temp_storage, + temp_storage_bytes, + dev_rank_id_count_ptr, + dev_rank_id_offset_ptr, + local_size, + stream); + p_thrust_allocator->deallocate(reinterpret_cast(cub_temp_storage), temp_storage_bytes); + + temp_memory_handle dev_bucket_atomic_add_handle(p_env_fns); + int64_t* dev_bucket_atomic_add_ptr = static_cast( + dev_bucket_atomic_add_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64)); + cudaMemsetAsync((void*)dev_bucket_atomic_add_ptr, 0, sizeof(int64_t) * local_size, stream); + static constexpr int BLOCK_SIZE = 128; + int block_count = wholememory::div_rounding_up_unsafe(indice_desc.size, BLOCK_SIZE); + block_count = std::min(block_count, sm_count * 4); + + reorder_ids_for_hierarchy_kernel<<>>( + static_cast(indices), + indice_desc.size, + static_cast(dev_bucket_indices), + static_cast(dev_indice_map), + dev_rank_id_offset_ptr, + embedding_entry_count_per_rank, + local_size, + dev_bucket_atomic_add_ptr); +} + +REGISTER_DISPATCH_ONE_TYPE(ReorderIdsForHierarchy, reorder_ids_for_hierarchy_temp_func, SINT3264) + +wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + void* dev_bucket_indices, + void* dev_indice_map, + int64_t* host_bucket_id_count, + size_t embedding_entry_count_per_rank, + wholememory_comm_t wm_global_comm, + wholememory_comm_t wm_local_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } + int world_size, local_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_global_comm)); + 
WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); + WHOLEMEMORY_CHECK_NOTHROW(world_size % local_size == 0); + + constexpr int K_DEFAULT_SM_COUNT = 108; + auto prop = get_device_prop(-1); + int sm_count = (prop != nullptr) ? prop->multiProcessorCount : K_DEFAULT_SM_COUNT; + temp_memory_handle dev_rank_id_count_handle(p_env_fns); + int64_t* dev_rank_id_count_ptr = + static_cast(dev_rank_id_count_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64)); + cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * local_size, stream); + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + BucketIdsForHierarchy, + indices, + indice_desc, + dev_rank_id_count_ptr, + embedding_entry_count_per_rank, + local_size, + 0, // ignore + 0, + sm_count, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("bucket_ids_for_hierarchy_func CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } + WM_CUDA_CHECK_NO_THROW(cudaMemcpyAsync(host_bucket_id_count, + dev_rank_id_count_ptr, + local_size * sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + ReorderIdsForHierarchy, + indices, + indice_desc, + dev_bucket_indices, + dev_indice_map, + dev_rank_id_count_ptr, + embedding_entry_count_per_rank, + local_size, + p_thrust_allocator, + p_env_fns, + sm_count, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("reorder_ids_for_hierarchy CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } catch (wholememory::logic_error& wle) { + WHOLEMEMORY_ERROR("reorder_ids_for_hierarchy LOGIC Error %s\n", wle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_UNKNOW_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t bucket_local_ids_func(void* indices, + wholememory_array_description_t indice_desc, + int64_t* host_bucket_id_count, + size_t embedding_entry_count_per_rank, + wholememory_comm_t wm_local_comm, + wholememory_comm_t wm_cross_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } + int cross_size, local_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); + + constexpr int K_DEFAULT_SM_COUNT = 108; + auto prop = get_device_prop(-1); + int sm_count = (prop != nullptr) ? 
prop->multiProcessorCount : K_DEFAULT_SM_COUNT; + temp_memory_handle dev_rank_id_count_handle(p_env_fns); + int64_t* dev_rank_id_count_ptr = + static_cast(dev_rank_id_count_handle.device_malloc(cross_size, WHOLEMEMORY_DT_INT64)); + cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * cross_size, stream); + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + BucketIdsForHierarchy, + indices, + indice_desc, + dev_rank_id_count_ptr, + embedding_entry_count_per_rank, + local_size, + cross_size, + 1, + sm_count, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("bucket_ids_for_hierarchy CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } + WM_CUDA_CHECK_NO_THROW(cudaMemcpyAsync(host_bucket_id_count, + dev_rank_id_count_ptr, + cross_size * sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); + WM_CUDA_CHECK(cudaGetLastError()); + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h new file mode 100644 index 000000000..a86a9945e --- /dev/null +++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include "wholememory_ops/temp_memory_handle.hpp" + +namespace wholememory_ops { + +wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + void* dev_bucket_indices, + void* dev_indice_map, + int64_t* host_bucket_id_count, + size_t embedding_entry_count_per_rank, + wholememory_comm_t wm_global_comm, + wholememory_comm_t wm_local_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +wholememory_error_code_t bucket_local_ids_func(void* indices, + wholememory_array_description_t indice_desc, + int64_t* host_bucket_id_count, + size_t embedding_entry_count_per_rank, + wholememory_comm_t wm_local_comm, + wholememory_comm_t wm_cross_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu new file mode 100644 index 000000000..caa9667c4 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sort_unique_ids_for_hierarchy_func.h" +#include "sort_unique_indices_func.h" + +#include +#include + +#include +#include +#include +#include + +#include + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "logger.hpp" +#include "wholememory/communicator.hpp" +#include "wholememory/integer_utils.hpp" +#include "wholememory_ops/register.hpp" +#include "wholememory_ops/temp_memory_handle.hpp" +#include + +namespace wholememory_ops { + +template +__global__ void SortUniqueIndiceMapKernel(IndexT* indice_map, + size_t indice_count, + const IndexT* sort_raw_indices, + const int* unique_count_ptr, + const IndexT* unique_offset_ptr, + size_t num_unique) +{ + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count; + idx += blockDim.x * gridDim.x) { + if (idx >= num_unique) break; + IndexT offset = unique_offset_ptr[idx]; + int count = unique_count_ptr[idx]; + for (IndexT i = offset; i < offset + count; i++) { + indice_map[sort_raw_indices[i]] = idx; + } + } +} + +template +void SortUniqueIndicesMapTempFunc(void* indice_map, + wholememory_array_description_t indice_desc, + const void* sort_raw_indices, + const int* unique_count_ptr, + size_t num_unique, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + static constexpr int BLOCK_SIZE = 128; + int block_count = wholememory::div_rounding_up_unsafe(num_unique, BLOCK_SIZE); + + temp_memory_handle dev_unique_offset_handle(p_env_fns); + IndexT* unique_offset_ptr = + static_cast(dev_unique_offset_handle.device_malloc(num_unique, indice_desc.dtype)); + IndexT* indice_map_ptr = static_cast(indice_map); + const IndexT* sort_raw_indices_ptr = static_cast(sort_raw_indices); + + void* cub_temp_storage = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceScan::ExclusiveSum( + cub_temp_storage, temp_storage_bytes, unique_count_ptr, unique_offset_ptr, num_unique, stream); + cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes); + cub::DeviceScan::ExclusiveSum( + cub_temp_storage, temp_storage_bytes, unique_count_ptr, unique_offset_ptr, num_unique, stream); + SortUniqueIndiceMapKernel<<>>(indice_map_ptr, + indice_desc.size, + sort_raw_indices_ptr, + unique_count_ptr, + unique_offset_ptr, + num_unique); + p_thrust_allocator->deallocate(reinterpret_cast(cub_temp_storage), temp_storage_bytes); +} + +REGISTER_DISPATCH_ONE_TYPE(SortUniqueIndicesMapTempFunc, SortUniqueIndicesMapTempFunc, SINT3264) + +wholememory_error_code_t sort_unique_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + temp_memory_handle* output_indices_handle, + wholememory_array_description_t* output_indices_desc, + temp_memory_handle* dev_indice_map_handle, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + if (indice_desc.size == 0) { + *output_indices_desc = wholememory_create_array_desc(0, 0, indice_desc.dtype); + return WHOLEMEMORY_SUCCESS; + } + int num_runs = 0; + temp_memory_handle unique_count_handle(p_env_fns); + temp_memory_handle 
dev_sort_raw_indices_handle(p_env_fns); + void* dev_sort_raw_indices_ptr = + dev_sort_raw_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype); + sort_unique_indices_func(indices, + indice_desc, + dev_sort_raw_indices_ptr, + &num_runs, + output_indices_handle, + &unique_count_handle, + p_thrust_allocator, + p_env_fns, + stream); + *output_indices_desc = wholememory_create_array_desc(num_runs, 0, indice_desc.dtype); + void* dev_indice_map_ptr = + dev_indice_map_handle->device_malloc(indice_desc.size, indice_desc.dtype); + WM_CUDA_CHECK(cudaGetLastError()); + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + SortUniqueIndicesMapTempFunc, + dev_indice_map_ptr, + indice_desc, + dev_sort_raw_indices_ptr, + static_cast(unique_count_handle.pointer()), + num_runs, + p_thrust_allocator, + p_env_fns, + stream); + } catch (...) { + WHOLEMEMORY_FAIL_NOTHROW("map indices failed"); + } + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h new file mode 100644 index 000000000..8491e58f7 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "wholememory_ops/temp_memory_handle.hpp" +#include +#include +#include + +namespace wholememory_ops { + +wholememory_error_code_t sort_unique_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + temp_memory_handle* output_indices_handle, + wholememory_array_description_t* output_indices_desc, + temp_memory_handle* dev_indice_map_handle, // indice_desc + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu new file mode 100644 index 000000000..a3d3fc647 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "sort_indices_func.h" +#include "sort_unique_indices_func.h" + +#include +#include +#include + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "logger.hpp" +#include "wholememory_ops/register.hpp" + +namespace wholememory_ops { + +template +void SortUniqueIndicesTempFunc(const void* indices, + wholememory_array_description_t indice_desc, + void* sort_raw_indices, + int* num_runs, + temp_memory_handle* unique_indices_handle, + temp_memory_handle* unique_count_handle, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + if (indice_desc.size == 0) return; + wm_thrust_allocator& allocator = *p_thrust_allocator; + WHOLEMEMORY_CHECK_NOTHROW(indice_desc.storage_offset == 0); + temp_memory_handle sorted_indices_handle(p_env_fns); + sorted_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype); + IndexT* sorted_indices = static_cast(sorted_indices_handle.pointer()); + + sort_indices_func( + indices, indice_desc, sorted_indices, sort_raw_indices, p_thrust_allocator, p_env_fns, stream); + + unique_indices_handle->device_malloc(indice_desc.size, indice_desc.dtype); + unique_count_handle->device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT); + IndexT* unique_indices = static_cast(unique_indices_handle->pointer()); + int* unique_counts = static_cast(unique_count_handle->pointer()); + temp_memory_handle number_runs_handle(p_env_fns); + number_runs_handle.device_malloc(1, WHOLEMEMORY_DT_INT); + int* number_runs = static_cast(number_runs_handle.pointer()); + void* cub_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + cub::DeviceRunLengthEncode::Encode(cub_temp_storage, + temp_storage_bytes, + sorted_indices, + unique_indices, + unique_counts, + number_runs, + indice_desc.size, + stream); + cub_temp_storage = allocator.allocate(temp_storage_bytes); + cub::DeviceRunLengthEncode::Encode(cub_temp_storage, + temp_storage_bytes, + sorted_indices, + unique_indices, + unique_counts, + number_runs, + indice_desc.size, + stream); + WM_CUDA_CHECK_NO_THROW( + cudaMemcpyAsync(num_runs, number_runs, sizeof(int), cudaMemcpyDeviceToHost, stream)); +} + +REGISTER_DISPATCH_ONE_TYPE(SortUniqueIndicesTempFunc, SortUniqueIndicesTempFunc, SINT3264) + +wholememory_error_code_t sort_unique_indices_func(const void* indices, + wholememory_array_description_t indice_desc, + void* sort_raw_indices, + int* num_runs, + temp_memory_handle* unique_indices_handle, + temp_memory_handle* unique_count_handle, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + SortUniqueIndicesTempFunc, + indices, + indice_desc, + sort_raw_indices, + num_runs, + unique_indices_handle, + unique_count_handle, + p_thrust_allocator, + p_env_fns, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("sort_unique_indices_func CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } catch (wholememory::logic_error& wle) { + WHOLEMEMORY_ERROR("sort_unique_indices_func LOGIC Error %s\n", wle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) 
{ + return WHOLEMEMORY_UNKNOW_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_unique_indices_func.h b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.h new file mode 100644 index 000000000..2ff697c90 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include + +namespace wholememory_ops { + +wholememory_error_code_t sort_unique_indices_func(const void* indices, + wholememory_array_description_t indice_desc, + void* sort_raw_indices, + int* num_runs, + temp_memory_handle* unique_indices_handle, + temp_memory_handle* unique_count_handle, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu index 2e49bcd02..808ebe768 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu @@ -21,16 +21,110 @@ #include "logger.hpp" #include "wholememory/communicator.hpp" #include "wholememory/memory_handle.hpp" -#include "wholememory_ops/functions/bucket_ids_func.h" +#include "wholememory_ops/functions/bucket_ids_for_hierarchy_func.h" #include "wholememory_ops/functions/exchange_embeddings_nccl_func.h" -#include "wholememory_ops/functions/exchange_ids_nccl_func.h" #include "wholememory_ops/functions/gather_scatter_func.h" +#include "wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h" #include "wholememory_ops/gather_op_impl.h" #include "wholememory_ops/temp_memory_handle.hpp" #include "wholememory_ops/thrust_allocator.hpp" namespace wholememory_ops { +static wholememory_error_code_t wholememory_cross_gather( + wholememory_handle_t wholememory_handle, + wholememory_matrix_description_t wholememory_desc, + void* indices, + wholememory_array_description_t indice_desc, + void* output, + wholememory_matrix_description_t output_desc, + size_t embedding_entry_count_per_rank, + wholememory_comm_t wm_local_comm, + wholememory_comm_t wm_cross_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream, + int gather_sms) +{ + int cross_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); + // bucket ids + std::vector host_bucket_id_count(cross_size, 0); + std::vector host_bucket_id_offset(cross_size); + std::vector host_recv_id_count(cross_size, 0); + std::vector host_recv_id_offset(cross_size); + bucket_local_ids_func(indices, + indice_desc, + host_bucket_id_count.data(), + embedding_entry_count_per_rank, + wm_local_comm, + wm_cross_comm, + p_thrust_allocator, + p_env_fns, + stream); + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); + // exchange node count 
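The "// exchange node count" comment above marks the first half of the alltoallv bookkeeping this patch uses repeatedly: each rank first tells every peer how many ids it will send, and both sides then convert those counts into exclusive-prefix-sum offsets that serve as the send and receive displacements for the alltoallv that follows. A minimal host-side sketch of that count-to-offset conversion, with hypothetical names and no dependence on WholeMemory types (shown only to make the bookkeeping explicit, not part of this change set):

    #include <cstdint>
    #include <vector>

    // Exclusive prefix sum over per-peer counts: offsets[i] is where peer i's ids
    // start in the packed send (or receive) buffer handed to alltoallv.
    std::vector<int64_t> counts_to_offsets(const std::vector<int64_t>& counts)
    {
      std::vector<int64_t> offsets(counts.size(), 0);
      for (size_t i = 1; i < counts.size(); ++i) {
        offsets[i] = offsets[i - 1] + counts[i - 1];
      }
      return offsets;
    }
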
+ wm_cross_comm->host_alltoall( + host_bucket_id_count.data(), host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64); + host_bucket_id_offset[0] = 0; + for (int i = 1; i < cross_size; i++) + host_bucket_id_offset[i] = host_bucket_id_offset[i - 1] + host_bucket_id_count[i - 1]; + wm_cross_comm->sync_stream(); + // exchange indices + int64_t total_recv_count = 0; + for (int i = 0; i < cross_size; i++) { + host_recv_id_offset[i] = total_recv_count; + total_recv_count += host_recv_id_count[i]; + } + temp_memory_handle dev_recv_bucket_indices_handle(p_env_fns); + void* dev_recv_bucket_indices_ptr = + dev_recv_bucket_indices_handle.device_malloc(total_recv_count, indice_desc.dtype); + wm_cross_comm->alltoallv(indices, + dev_recv_bucket_indices_ptr, + reinterpret_cast(host_bucket_id_count.data()), + reinterpret_cast(host_bucket_id_offset.data()), + reinterpret_cast(host_recv_id_count.data()), + reinterpret_cast(host_recv_id_offset.data()), + indice_desc.dtype, + stream); + wm_cross_comm->sync_stream(stream); + // local gather + temp_memory_handle dev_local_gather_buffer_handle(p_env_fns); + void* dev_local_gather_buffer_ptr = dev_local_gather_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); + int64_t local_gather_buffer_size[2] = {total_recv_count, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t local_gather_buffer_desc = wholememory_create_matrix_desc( + local_gather_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + void* local_fake_ptr = nullptr; + size_t local_mem_offset, local_mem_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_local_memory( + &local_fake_ptr, &local_mem_size, &local_mem_offset, wholememory_handle)); + local_fake_ptr = static_cast(local_fake_ptr) - local_mem_offset; + wholememory_gref_t local_fake_gref = + wholememory_create_continuous_global_reference(local_fake_ptr); + auto local_gather_indice_desc = + wholememory_create_array_desc(total_recv_count, 0, indice_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(local_fake_gref, + wholememory_desc, + dev_recv_bucket_indices_ptr, + local_gather_indice_desc, + dev_local_gather_buffer_ptr, + local_gather_buffer_desc, + stream, + gather_sms)); + // exchange embeddings + size_t output_embedding_size = + wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_local_gather_buffer_ptr, + host_recv_id_count.data(), + host_bucket_id_count.data(), + output, + output_embedding_size, + wm_cross_comm, + stream)); + return WHOLEMEMORY_SUCCESS; +} + wholememory_error_code_t wholememory_gather_hierarchy( wholememory_handle_t wholememory_handle, wholememory_matrix_description_t wholememory_desc, @@ -53,109 +147,171 @@ wholememory_error_code_t wholememory_gather_hierarchy( size_t embedding_size_per_rank; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); size_t embedding_entry_size = element_size * wholememory_desc.stride; - WHOLEMEMORY_EXPECTS_NOTHROW( embedding_size_per_rank % embedding_entry_size == 0, "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", embedding_size_per_rank, element_size, wholememory_desc.stride); - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_comm; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, 
wholememory_handle)); + wholememory_comm_t wm_global_comm; + int world_size, world_rank; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_global_comm, wholememory_handle)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_global_comm)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&world_rank, wm_global_comm)); + + wholememory_comm_t wm_local_comm; + int local_size, local_rank; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_local_communicator(&wm_local_comm, wholememory_handle)); + // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( + // &wm_local_comm, wm_global_comm, world_rank / local_size, world_rank % local_size)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&local_rank, wm_local_comm)); - int world_size; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_comm)); + wholememory_comm_t wm_cross_comm; + int cross_size; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_cross_communicator(&wm_cross_comm, wholememory_handle)); + // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( + // &wm_cross_comm, wm_global_comm, world_rank % local_size, world_rank / local_size)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); + WHOLEMEMORY_CHECK_NOTHROW(world_size == local_size * cross_size); - temp_memory_handle host_rank_id_count(p_env_fns), host_recv_rank_id_count(p_env_fns); - int64_t* host_rank_id_count_ptr = - static_cast(host_rank_id_count.host_malloc(world_size, WHOLEMEMORY_DT_INT64)); - int64_t* host_recv_rank_id_count_ptr = - static_cast(host_recv_rank_id_count.host_malloc(world_size, WHOLEMEMORY_DT_INT64)); + temp_memory_handle dev_bucket_indices_handle(p_env_fns); + void* dev_bucket_indices_ptr = + dev_bucket_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype); + temp_memory_handle dev_bucket_ids_map_handle(p_env_fns); + void* dev_bucket_ids_map_ptr = + dev_bucket_ids_map_handle.device_malloc(indice_desc.size, indice_desc.dtype); - temp_memory_handle dev_recv_indice_buffer(p_env_fns); - temp_memory_handle dev_raw_indice(p_env_fns); - int64_t* dev_raw_indice_ptr = - static_cast(dev_raw_indice.device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT64)); + std::vector host_bucket_id_count(local_size, 0); + std::vector host_bucket_id_offset(local_size); + std::vector host_recv_id_count(local_size, 0); + std::vector host_recv_id_offset(local_size); + // bucket indices + WHOLEMEMORY_RETURN_ON_FAIL( + bucket_and_reorder_ids_for_hierarchy_func(indices, + indice_desc, + dev_bucket_indices_ptr, + dev_bucket_ids_map_ptr, + host_bucket_id_count.data(), + embedding_entry_count_per_rank, + wm_global_comm, + wm_local_comm, + &thrust_allocator, + p_env_fns, + stream)); + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); + // exchange node count + wm_local_comm->host_alltoall( + host_bucket_id_count.data(), host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64); + host_bucket_id_offset[0] = 0; + for (int i = 1; i < local_size; i++) + host_bucket_id_offset[i] = host_bucket_id_offset[i - 1] + host_bucket_id_count[i - 1]; + wm_local_comm->sync_stream(); + // exchange indices int64_t total_recv_count = 0; - WHOLEMEMORY_RETURN_ON_FAIL(bucket_and_exchange_ids_func(indices, - indice_desc, - host_recv_rank_id_count_ptr, - host_rank_id_count_ptr, - &dev_recv_indice_buffer, - dev_raw_indice_ptr, - embedding_entry_count_per_rank, - wm_comm, - 
&thrust_allocator, - p_env_fns, - stream)); - - // Local Gather - for (int i = 0; i < world_size; i++) { - total_recv_count += host_recv_rank_id_count_ptr[i]; + for (int i = 0; i < local_size; i++) { + host_recv_id_offset[i] = total_recv_count; + total_recv_count += host_recv_id_count[i]; } - size_t local_mem_offset, local_mem_size; - temp_memory_handle dev_local_gather_buffer(p_env_fns); - temp_memory_handle dev_embedding_recv_buffer(p_env_fns); - void* dev_local_gather_buffer_ptr = dev_local_gather_buffer.device_malloc( - wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); - void* dev_embedding_recv_buffer_ptr = dev_embedding_recv_buffer.device_malloc( - wholememory_desc.sizes[1] * indice_desc.size, output_desc.dtype); - void* local_fake_ptr = nullptr; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_local_memory( - &local_fake_ptr, &local_mem_size, &local_mem_offset, wholememory_handle)); - local_fake_ptr = static_cast(local_fake_ptr) - local_mem_offset; - wholememory_gref_t local_fake_gref = - wholememory_create_continuous_global_reference(local_fake_ptr); - int64_t local_buffer_size[2] = {total_recv_count, wholememory_desc.sizes[1]}; - wholememory_matrix_description_t local_gather_buffer_desc = wholememory_create_matrix_desc( - local_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); - auto dev_recv_indice_desc = + temp_memory_handle dev_recv_bucket_indices_handle(p_env_fns); + void* dev_recv_bucket_indices_ptr = + dev_recv_bucket_indices_handle.device_malloc(total_recv_count, indice_desc.dtype); + auto recv_bucket_indices_desc = wholememory_create_array_desc(total_recv_count, 0, indice_desc.dtype); - WHOLEMEMORY_RETURN_ON_FAIL(gather_func(local_fake_gref, - wholememory_desc, - dev_recv_indice_buffer.pointer(), - dev_recv_indice_desc, - dev_local_gather_buffer_ptr, - local_gather_buffer_desc, + wm_local_comm->alltoallv(dev_bucket_indices_ptr, + dev_recv_bucket_indices_ptr, + reinterpret_cast(host_bucket_id_count.data()), + reinterpret_cast(host_bucket_id_offset.data()), + reinterpret_cast(host_recv_id_count.data()), + reinterpret_cast(host_recv_id_offset.data()), + indice_desc.dtype, + stream); + wm_local_comm->sync_stream(stream); + WM_CUDA_CHECK(cudaGetLastError()); + // sort unique recv indices + temp_memory_handle sort_unique_indices_handle(p_env_fns); + wholememory_array_description_t sort_unique_indice_desc; + temp_memory_handle dev_sort_unique_ids_map_handle(p_env_fns); + sort_unique_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr, + recv_bucket_indices_desc, + &sort_unique_indices_handle, + &sort_unique_indice_desc, + &dev_sort_unique_ids_map_handle, + &thrust_allocator, + p_env_fns, + stream); + // cross gather + temp_memory_handle dev_cross_gather_buffer_handle(p_env_fns); + void* dev_cross_gather_buffer_ptr = dev_cross_gather_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * sort_unique_indice_desc.size, output_desc.dtype); + int64_t cross_gather_buffer_size[2] = {sort_unique_indice_desc.size, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t cross_gather_buffer_desc = wholememory_create_matrix_desc( + cross_gather_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + wholememory_cross_gather(wholememory_handle, + wholememory_desc, + sort_unique_indices_handle.pointer(), + sort_unique_indice_desc, + dev_cross_gather_buffer_ptr, + cross_gather_buffer_desc, + embedding_entry_count_per_rank, + wm_local_comm, + wm_cross_comm, + &thrust_allocator, + p_env_fns, + stream, + gather_sms); + // sort-unique reorder + 
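The reorder step that follows undoes the sort-unique compaction: the cross gather produced one embedding row per unique id, and the sort-unique id map records, for every original (possibly duplicated) recv index, where its id landed in the unique list, so the gather over a fake continuous global reference below expands the unique rows back to one row per recv index. A host-side sketch of that expansion under simplifying assumptions (row-major float rows, hypothetical names; the real code performs the same mapping on the GPU):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Expand rows gathered for unique ids back to one row per original index:
    // map[i] is the position of index i's id in the unique list.
    void expand_unique_rows(const std::vector<float>& unique_rows,  // num_unique x dim
                            const std::vector<int64_t>& map,        // one entry per original index
                            int64_t dim,
                            std::vector<float>& out)                // map.size() x dim
    {
      out.resize(map.size() * dim);
      for (size_t i = 0; i < map.size(); ++i) {
        std::copy_n(unique_rows.data() + map[i] * dim, dim, out.data() + i * dim);
      }
    }
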
temp_memory_handle dev_embedding_map_buffer_handle(p_env_fns); + void* dev_embedding_map_buffer_ptr = dev_embedding_map_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); + int64_t embedding_map_buffer_size[2] = {total_recv_count, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t embedding_map_buffer_desc = wholememory_create_matrix_desc( + embedding_map_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + wholememory_gref_t cross_gather_fake_gref = + wholememory_create_continuous_global_reference(dev_cross_gather_buffer_ptr); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(cross_gather_fake_gref, + cross_gather_buffer_desc, + dev_sort_unique_ids_map_handle.pointer(), + recv_bucket_indices_desc, + dev_embedding_map_buffer_ptr, + embedding_map_buffer_desc, stream, gather_sms)); - // AllToAllV for embeddings - size_t embedding_size = + // exchange embeddings + size_t output_embedding_size = wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype); - WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_local_gather_buffer_ptr, - host_recv_rank_id_count_ptr, - host_rank_id_count_ptr, - dev_embedding_recv_buffer_ptr, - embedding_size, - wm_comm, + temp_memory_handle dev_recv_embedding_buffer_handle(p_env_fns); + void* dev_recv_embedding_buffer_ptr = dev_recv_embedding_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * indice_desc.size, output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_embedding_map_buffer_ptr, + host_recv_id_count.data(), + host_bucket_id_count.data(), + dev_recv_embedding_buffer_ptr, + output_embedding_size, + wm_local_comm, stream)); - // Local reorder - int64_t total_need_indice_count = 0; - for (int i = 0; i < world_size; i++) { - total_need_indice_count += host_rank_id_count_ptr[i]; - } - wholememory_gref_t output_gref = wholememory_create_continuous_global_reference(output); - wholememory_matrix_description_t local_recv_buffer_desc = - wholememory_create_matrix_desc(output_desc.sizes, output_desc.sizes[1], 0, output_desc.dtype); - local_recv_buffer_desc.sizes[0] = total_need_indice_count; - auto raw_indice_desc = - wholememory_create_array_desc(total_need_indice_count, 0, WHOLEMEMORY_DT_INT64); - WHOLEMEMORY_RETURN_ON_FAIL(scatter_func(dev_embedding_recv_buffer_ptr, - local_recv_buffer_desc, - dev_raw_indice_ptr, - raw_indice_desc, - output_gref, - output_desc, - stream)); + // bucket reorder + wholememory_gref_t recv_embedding_buffer_fake_gref = + wholememory_create_continuous_global_reference(dev_recv_embedding_buffer_ptr); + int64_t recv_embedding_buffer_size[2] = {indice_desc.size, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t recv_embedding_buffer_desc = wholememory_create_matrix_desc( + recv_embedding_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(recv_embedding_buffer_fake_gref, + recv_embedding_buffer_desc, + dev_bucket_ids_map_ptr, + indice_desc, + output, + output_desc, + stream, + gather_sms)); WM_CUDA_CHECK(cudaGetLastError()); - // WM_CUDA_CHECK(cudaStreamSynchronize(stream)); } catch (wholememory::cuda_error& wce) { WHOLEMEMORY_ERROR("CUDA logic Error %s\n", wce.what()); return WHOLEMEMORY_CUDA_ERROR; From 5aed03c7ccd13f035d8dc991b890666afad3e176 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Wed, 21 Aug 2024 17:30:43 +0800 Subject: [PATCH 4/6] quick fix in creating sub-communicator for hierarchy memory type --- cpp/src/wholememory/memory_handle.cpp | 
4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index d9ededa87..48b687683 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -1778,7 +1778,7 @@ class hierarchy_wholememory_impl : public wholememory_impl { #endif } else { int world_rank = -1, local_size = -1; - wholememory_communicator_get_size(&world_rank, global_comm); + wholememory_communicator_get_rank(&world_rank, global_comm); wholememory_communicator_get_local_size(&local_size, global_comm); wholememory_split_communicator( &cross_comm_, global_comm, world_rank % local_size, world_rank / local_size); @@ -2038,7 +2038,7 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha #endif } else { int world_rank = -1, local_size = -1; - wholememory_communicator_get_size(&world_rank, comm); + wholememory_communicator_get_rank(&world_rank, comm); wholememory_communicator_get_local_size(&local_size, comm); wholememory_split_communicator( &local_comm, comm, world_rank / local_size, world_rank % local_size); From 2d4fceb4ce783236cbc8a3f429e2f7468838531b Mon Sep 17 00:00:00 2001 From: Zhuofan Li Date: Thu, 22 Aug 2024 10:05:41 +0000 Subject: [PATCH 5/6] set a knob for sort-unique --- .../bucket_ids_for_hierarchy_func.cu | 109 +++++++++++++----- .../functions/bucket_ids_for_hierarchy_func.h | 1 + .../gather_op_impl_hierarchy.cu | 98 ++++++++++------ 3 files changed, 140 insertions(+), 68 deletions(-) diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu index 6ad83a49a..7c50d2b2f 100644 --- a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu +++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu @@ -33,16 +33,16 @@ namespace wholememory_ops { -template +template __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices, size_t indice_count, int64_t* dev_rank_id_count_ptr, size_t embedding_entry_count_per_rank, int local_size, - int bucket_size) + int nbucket) { extern __shared__ int rank_count_shared[]; - for (int idx = threadIdx.x; idx < bucket_size; idx += blockDim.x) { + for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) { rank_count_shared[idx] = 0; } __syncthreads(); @@ -52,9 +52,9 @@ __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices, if (node_idx < 0) continue; int rank = node_idx / embedding_entry_count_per_rank; int bucket = 0; - if (CROSS_OR_LOCAL == 0) // bucket cross ranks + if (BUCKET_CROSS_OR_LOCAL == 0) bucket = rank % local_size; - else // bucket local ranks + else bucket = rank / local_size; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 atomicAdd_block(&rank_count_shared[bucket], 1); @@ -63,7 +63,7 @@ __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices, #endif } __syncthreads(); - for (int idx = threadIdx.x; idx < bucket_size; idx += blockDim.x) { + for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) { atomicAdd(reinterpret_cast(dev_rank_id_count_ptr) + idx, static_cast(rank_count_shared[idx])); } @@ -113,7 +113,7 @@ void bucket_ids_for_hierarchy_temp_func(const void* indices, REGISTER_DISPATCH_ONE_TYPE(BucketIdsForHierarchy, bucket_ids_for_hierarchy_temp_func, SINT3264) -template +template __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, size_t indice_count, IndexT* dev_bucket_indices, @@ -121,9 +121,9 @@ __global__ void 
reorder_ids_for_hierarchy_kernel(const IndexT* indices, const int64_t* dev_rank_id_offset_ptr, size_t embedding_entry_count_per_rank, int local_size, + int nbucket, int64_t* dev_bucket_atomic_add_ptr) { - int nbucket = local_size; constexpr size_t shared_mem_size = 24576; __shared__ char shared_mem[shared_mem_size]; int* block_bucket_count_shared = reinterpret_cast(shared_mem); @@ -155,7 +155,13 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, IndexT indice = indices[load_idx]; buffer_load[i] = indice; - int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size; + int bucket_idx = 0; + int rank = indice / embedding_entry_count_per_rank; + if (BUCKET_CROSS_OR_LOCAL == 0) { + bucket_idx = rank % local_size; + } else { + bucket_idx = rank / local_size; + } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 atomicAdd_block(&block_bucket_count_shared[bucket_idx], 1); #else @@ -181,7 +187,13 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, IndexT indice = buffer_load[i]; IndexT load_idx = i + load_offset; if (load_idx >= indice_count) break; - int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size; + int bucket_idx = 0; + int rank = indice / embedding_entry_count_per_rank; + if (BUCKET_CROSS_OR_LOCAL == 0) { + bucket_idx = rank % local_size; + } else { + bucket_idx = rank / local_size; + } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 int block_bucket_inc = atomicAdd_block(&block_bucket_atomic_add_shared[bucket_idx], 1); #else @@ -213,6 +225,8 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, const int64_t* dev_rank_id_count_ptr, size_t embedding_entry_count_per_rank, int local_size, + int cross_size, + int bucket_cross_or_local, wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, int sm_count, @@ -221,44 +235,63 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, WHOLEMEMORY_CHECK(indice_desc.storage_offset == 0); WHOLEMEMORY_CHECK(indice_desc.dtype == WHOLEMEMORY_DT_INT || indice_desc.dtype == WHOLEMEMORY_DT_INT64); - + int nbucket = 0; + if (bucket_cross_or_local == 0) { + nbucket = local_size; + } else { + nbucket = cross_size; + } temp_memory_handle dev_rank_id_offset_handle(p_env_fns); - int64_t* dev_rank_id_offset_ptr = static_cast( - dev_rank_id_offset_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64)); + int64_t* dev_rank_id_offset_ptr = + static_cast(dev_rank_id_offset_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64)); void* cub_temp_storage = NULL; size_t temp_storage_bytes = 0; cub::DeviceScan::ExclusiveSum(cub_temp_storage, temp_storage_bytes, dev_rank_id_count_ptr, dev_rank_id_offset_ptr, - local_size, + nbucket, stream); cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes); cub::DeviceScan::ExclusiveSum(cub_temp_storage, temp_storage_bytes, dev_rank_id_count_ptr, dev_rank_id_offset_ptr, - local_size, + nbucket, stream); p_thrust_allocator->deallocate(reinterpret_cast(cub_temp_storage), temp_storage_bytes); temp_memory_handle dev_bucket_atomic_add_handle(p_env_fns); int64_t* dev_bucket_atomic_add_ptr = static_cast( - dev_bucket_atomic_add_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64)); - cudaMemsetAsync((void*)dev_bucket_atomic_add_ptr, 0, sizeof(int64_t) * local_size, stream); + dev_bucket_atomic_add_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64)); + cudaMemsetAsync((void*)dev_bucket_atomic_add_ptr, 0, sizeof(int64_t) * nbucket, stream); static constexpr int BLOCK_SIZE = 128; int block_count = 
wholememory::div_rounding_up_unsafe(indice_desc.size, BLOCK_SIZE); block_count = std::min(block_count, sm_count * 4); - reorder_ids_for_hierarchy_kernel<<>>( - static_cast(indices), - indice_desc.size, - static_cast(dev_bucket_indices), - static_cast(dev_indice_map), - dev_rank_id_offset_ptr, - embedding_entry_count_per_rank, - local_size, - dev_bucket_atomic_add_ptr); + if (bucket_cross_or_local == 0) + reorder_ids_for_hierarchy_kernel + <<>>(static_cast(indices), + indice_desc.size, + static_cast(dev_bucket_indices), + static_cast(dev_indice_map), + dev_rank_id_offset_ptr, + embedding_entry_count_per_rank, + local_size, + nbucket, + dev_bucket_atomic_add_ptr); + else + reorder_ids_for_hierarchy_kernel + <<>>(static_cast(indices), + indice_desc.size, + static_cast(dev_bucket_indices), + static_cast(dev_indice_map), + dev_rank_id_offset_ptr, + embedding_entry_count_per_rank, + local_size, + nbucket, + dev_bucket_atomic_add_ptr); + ; } REGISTER_DISPATCH_ONE_TYPE(ReorderIdsForHierarchy, reorder_ids_for_hierarchy_temp_func, SINT3264) @@ -272,23 +305,33 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( size_t embedding_entry_count_per_rank, wholememory_comm_t wm_global_comm, wholememory_comm_t wm_local_comm, + int bucket_cross_or_local, wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, cudaStream_t stream) { if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } - int world_size, local_size; + int world_size, local_size, cross_size; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_global_comm)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); WHOLEMEMORY_CHECK_NOTHROW(world_size % local_size == 0); + cross_size = world_size / local_size; + WHOLEMEMORY_EXPECTS_NOTHROW(bucket_cross_or_local == 0 || bucket_cross_or_local == 1, + "param bucket_cross_or_local must be 0 or 1, 0: cross, 1: local"); + int nbucket = 0; + if (bucket_cross_or_local == 0) { // bucket by cross id + nbucket = local_size; + } else { // bucket by local id + nbucket = cross_size; + } constexpr int K_DEFAULT_SM_COUNT = 108; auto prop = get_device_prop(-1); int sm_count = (prop != nullptr) ? 
prop->multiProcessorCount : K_DEFAULT_SM_COUNT; temp_memory_handle dev_rank_id_count_handle(p_env_fns); int64_t* dev_rank_id_count_ptr = - static_cast(dev_rank_id_count_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64)); - cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * local_size, stream); + static_cast(dev_rank_id_count_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64)); + cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * nbucket, stream); try { DISPATCH_ONE_TYPE(indice_desc.dtype, BucketIdsForHierarchy, @@ -297,8 +340,8 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( dev_rank_id_count_ptr, embedding_entry_count_per_rank, local_size, - 0, // ignore - 0, + cross_size, + bucket_cross_or_local, sm_count, stream); } catch (wholememory::cuda_error& wce) { @@ -307,7 +350,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( } WM_CUDA_CHECK_NO_THROW(cudaMemcpyAsync(host_bucket_id_count, dev_rank_id_count_ptr, - local_size * sizeof(int64_t), + nbucket * sizeof(int64_t), cudaMemcpyDeviceToHost, stream)); try { @@ -320,6 +363,8 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( dev_rank_id_count_ptr, embedding_entry_count_per_rank, local_size, + cross_size, + bucket_cross_or_local, p_thrust_allocator, p_env_fns, sm_count, diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h index a86a9945e..d6d061c5e 100644 --- a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h +++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h @@ -32,6 +32,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( size_t embedding_entry_count_per_rank, wholememory_comm_t wm_global_comm, wholememory_comm_t wm_local_comm, + int bucket_cross_or_local, // 0: cross, 1: local wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, cudaStream_t stream); diff --git a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu index 808ebe768..3e29d29fc 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu @@ -38,6 +38,7 @@ static wholememory_error_code_t wholememory_cross_gather( wholememory_array_description_t indice_desc, void* output, wholememory_matrix_description_t output_desc, + int64_t* host_bucket_id_count_ptr, size_t embedding_entry_count_per_rank, wholememory_comm_t wm_local_comm, wholememory_comm_t wm_cross_comm, @@ -48,27 +49,15 @@ static wholememory_error_code_t wholememory_cross_gather( { int cross_size; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); - // bucket ids - std::vector host_bucket_id_count(cross_size, 0); std::vector host_bucket_id_offset(cross_size); std::vector host_recv_id_count(cross_size, 0); std::vector host_recv_id_offset(cross_size); - bucket_local_ids_func(indices, - indice_desc, - host_bucket_id_count.data(), - embedding_entry_count_per_rank, - wm_local_comm, - wm_cross_comm, - p_thrust_allocator, - p_env_fns, - stream); - WM_CUDA_CHECK(cudaStreamSynchronize(stream)); // exchange node count wm_cross_comm->host_alltoall( - host_bucket_id_count.data(), host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64); + host_bucket_id_count_ptr, host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64); host_bucket_id_offset[0] = 0; for (int i = 1; i < cross_size; i++) - host_bucket_id_offset[i] = 
host_bucket_id_offset[i - 1] + host_bucket_id_count[i - 1]; + host_bucket_id_offset[i] = host_bucket_id_offset[i - 1] + host_bucket_id_count_ptr[i - 1]; wm_cross_comm->sync_stream(); // exchange indices int64_t total_recv_count = 0; @@ -81,7 +70,7 @@ static wholememory_error_code_t wholememory_cross_gather( dev_recv_bucket_indices_handle.device_malloc(total_recv_count, indice_desc.dtype); wm_cross_comm->alltoallv(indices, dev_recv_bucket_indices_ptr, - reinterpret_cast(host_bucket_id_count.data()), + reinterpret_cast(host_bucket_id_count_ptr), reinterpret_cast(host_bucket_id_offset.data()), reinterpret_cast(host_recv_id_count.data()), reinterpret_cast(host_recv_id_offset.data()), @@ -117,7 +106,7 @@ static wholememory_error_code_t wholememory_cross_gather( wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype); WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_local_gather_buffer_ptr, host_recv_id_count.data(), - host_bucket_id_count.data(), + host_bucket_id_count_ptr, output, output_embedding_size, wm_cross_comm, @@ -142,6 +131,8 @@ wholememory_error_code_t wholememory_gather_hierarchy( return WHOLEMEMORY_INVALID_INPUT; } + bool sort_unique_indices = true; + wm_thrust_allocator thrust_allocator(p_env_fns); size_t embedding_size_per_rank; @@ -168,7 +159,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_local_communicator(&wm_local_comm, wholememory_handle)); // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( - // &wm_local_comm, wm_global_comm, world_rank / local_size, world_rank % local_size)); + // &wm_local_comm, wm_global_comm, world_rank / local_size, world_rank % local_size)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&local_rank, wm_local_comm)); @@ -177,7 +168,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_cross_communicator(&wm_cross_comm, wholememory_handle)); // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( - // &wm_cross_comm, wm_global_comm, world_rank % local_size, world_rank / local_size)); + // &wm_cross_comm, wm_global_comm, world_rank % local_size, world_rank / local_size)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); WHOLEMEMORY_CHECK_NOTHROW(world_size == local_size * cross_size); @@ -203,6 +194,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( embedding_entry_count_per_rank, wm_global_comm, wm_local_comm, + 0, &thrust_allocator, p_env_fns, stream)); @@ -235,31 +227,65 @@ wholememory_error_code_t wholememory_gather_hierarchy( stream); wm_local_comm->sync_stream(stream); WM_CUDA_CHECK(cudaGetLastError()); - // sort unique recv indices - temp_memory_handle sort_unique_indices_handle(p_env_fns); - wholememory_array_description_t sort_unique_indice_desc; - temp_memory_handle dev_sort_unique_ids_map_handle(p_env_fns); - sort_unique_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr, - recv_bucket_indices_desc, - &sort_unique_indices_handle, - &sort_unique_indice_desc, - &dev_sort_unique_ids_map_handle, - &thrust_allocator, - p_env_fns, - stream); + // sort unique / bucket recv indices + temp_memory_handle cross_gather_indices_handle(p_env_fns); + wholememory_array_description_t cross_gather_indices_desc; + temp_memory_handle dev_cross_gather_id_map_handle(p_env_fns); + std::vector host_cross_bucket_id_count(cross_size, 0); + if 
(sort_unique_indices) { + sort_unique_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr, + recv_bucket_indices_desc, + &cross_gather_indices_handle, + &cross_gather_indices_desc, + &dev_cross_gather_id_map_handle, + &thrust_allocator, + p_env_fns, + stream); + bucket_local_ids_func(cross_gather_indices_handle.pointer(), + cross_gather_indices_desc, + host_cross_bucket_id_count.data(), + embedding_entry_count_per_rank, + wm_local_comm, + wm_cross_comm, + &thrust_allocator, + p_env_fns, + stream); + } else { + void* cross_gather_indices_ptr = cross_gather_indices_handle.device_malloc( + recv_bucket_indices_desc.size, recv_bucket_indices_desc.dtype); + void* dev_cross_gather_id_map_ptr = dev_cross_gather_id_map_handle.device_malloc( + recv_bucket_indices_desc.size, recv_bucket_indices_desc.dtype); + cross_gather_indices_desc = recv_bucket_indices_desc; + WHOLEMEMORY_RETURN_ON_FAIL( + bucket_and_reorder_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr, + recv_bucket_indices_desc, + cross_gather_indices_ptr, + dev_cross_gather_id_map_ptr, + host_cross_bucket_id_count.data(), + embedding_entry_count_per_rank, + wm_global_comm, + wm_local_comm, + 1, + &thrust_allocator, + p_env_fns, + stream)); + } + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); // cross gather temp_memory_handle dev_cross_gather_buffer_handle(p_env_fns); void* dev_cross_gather_buffer_ptr = dev_cross_gather_buffer_handle.device_malloc( - wholememory_desc.sizes[1] * sort_unique_indice_desc.size, output_desc.dtype); - int64_t cross_gather_buffer_size[2] = {sort_unique_indice_desc.size, wholememory_desc.sizes[1]}; + wholememory_desc.sizes[1] * cross_gather_indices_desc.size, output_desc.dtype); + int64_t cross_gather_buffer_size[2] = {cross_gather_indices_desc.size, + wholememory_desc.sizes[1]}; wholememory_matrix_description_t cross_gather_buffer_desc = wholememory_create_matrix_desc( cross_gather_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); wholememory_cross_gather(wholememory_handle, wholememory_desc, - sort_unique_indices_handle.pointer(), - sort_unique_indice_desc, + cross_gather_indices_handle.pointer(), + cross_gather_indices_desc, dev_cross_gather_buffer_ptr, cross_gather_buffer_desc, + host_cross_bucket_id_count.data(), embedding_entry_count_per_rank, wm_local_comm, wm_cross_comm, @@ -267,7 +293,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( p_env_fns, stream, gather_sms); - // sort-unique reorder + // cross gather reorder temp_memory_handle dev_embedding_map_buffer_handle(p_env_fns); void* dev_embedding_map_buffer_ptr = dev_embedding_map_buffer_handle.device_malloc( wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); @@ -278,7 +304,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( wholememory_create_continuous_global_reference(dev_cross_gather_buffer_ptr); WHOLEMEMORY_RETURN_ON_FAIL(gather_func(cross_gather_fake_gref, cross_gather_buffer_desc, - dev_sort_unique_ids_map_handle.pointer(), + dev_cross_gather_id_map_handle.pointer(), recv_bucket_indices_desc, dev_embedding_map_buffer_ptr, embedding_map_buffer_desc, From cba73a2722f6bbb07a01df43d27ea8ccd61a5656 Mon Sep 17 00:00:00 2001 From: zhuofanl Date: Sun, 8 Sep 2024 10:08:15 +0000 Subject: [PATCH 6/6] support different rank sizes for heirarchy memory type --- cpp/include/wholememory/wholememory.h | 16 +- cpp/src/wholememory/memory_handle.cpp | 236 +++--------------- cpp/src/wholememory/wholememory.cpp | 9 - .../bucket_ids_for_hierarchy_func.cu | 113 ++++++--- .../functions/bucket_ids_for_hierarchy_func.h | 4 +- 
.../gather_op_impl_hierarchy.cu | 54 ++-- .../wholememory_gather_tests.cu | 35 +++ .../pylibwholegraph/torch/embedding.py | 2 +- 8 files changed, 186 insertions(+), 283 deletions(-) diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index 7cf8f6908..58aa611ce 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -295,7 +295,7 @@ wholememory_error_code_t wholememory_get_local_communicator( /** * Get underlying Wholememory Cross Communicator for "Hierarchy" memory type from WholeMemory Handle * One comminicator includes all rank with a same local id from different nodes - * @param comm : returned Local WholeMemory Communicator + * @param comm : returned Cross WholeMemory Communicator * @param wholememory_handle : WholeMemory Handle * @return : wholememory_error_code_t */ @@ -348,20 +348,6 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr, size_t* local_offset, wholememory_handle_t wholememory_handle); -/** - * Get local node memory from WholeMemory Handle, all gpus of the rank has direct access to the - * memory. Note that this is only available for WHOLEMEMORY_MT_HIERARCHY memory type. - * @param local_ptr : returned local node memory pointer - * @param local_size : returned local node memory size - * @param local_offset : returned local node memory offset from WholeMemory - * @param wholememory_handle : WholeMemory Handle - * @return : wholememory_error_code_t - */ -wholememory_error_code_t wholememory_get_local_node_memory(void** local_ptr, - size_t* local_size, - size_t* local_offset, - wholememory_handle_t wholememory_handle); - /** * Get local memory size from WholeMemory Handle of current rank * @param local_size : returned local memory size diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index 48b687683..f5cbd622e 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -326,7 +326,7 @@ class distributed_wholememory_impl : public wholememory_impl { data_granularity, rank_entry_partition) { - WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED); + WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED || type_ == WHOLEMEMORY_MT_HIERARCHY); } void create_memory() override { @@ -647,11 +647,12 @@ class continuous_device_wholememory_impl : public wholememory_impl { data_granularity, rank_entry_partition) { - printf( - "while in continuous device wholememory creation, the memory_type (%d) and memory_location " - "(%d).\n", - (int)memory_type, - (int)memory_location); + // printf( + // "while in continuous device wholememory creation, the memory_type (%d) and memory_location + // " + // "(%d).\n", + // (int)memory_type, + // (int)memory_location); WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS); } void create_memory() override @@ -1752,7 +1753,7 @@ struct wholememory_create_param { size_t min_granularity; }; -class hierarchy_wholememory_impl : public wholememory_impl { +class hierarchy_wholememory_impl : public distributed_wholememory_impl { public: hierarchy_wholememory_impl(wholememory_handle_t wholememory_handle, size_t total_size, @@ -1760,147 +1761,33 @@ class hierarchy_wholememory_impl : public wholememory_impl { wholememory_comm_t local_comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, global_comm, memory_type, memory_location, data_granularity) + size_t 
data_granularity, + size_t* rank_entry_partition) + : distributed_wholememory_impl(wholememory_handle, + total_size, + global_comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(memory_type == WHOLEMEMORY_MT_HIERARCHY); - local_comm_ = local_comm; - if (SupportEGM() && is_intra_mnnvl_communicator(global_comm)) { -#if CUDA_VERSION >= 12030 - clique_info_t* clique_info = nullptr; - wholememory_communicator_get_clique_info(clique_info, global_comm); - WHOLEMEMORY_CHECK_NOTHROW(clique_info->is_in_clique); - wholememory_split_communicator( - &cross_comm_, global_comm, clique_info->clique_rank, clique_info->clique_id); -#else - WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3"); -#endif - } else { - int world_rank = -1, local_size = -1; - wholememory_communicator_get_rank(&world_rank, global_comm); - wholememory_communicator_get_local_size(&local_size, global_comm); - wholememory_split_communicator( - &cross_comm_, global_comm, world_rank % local_size, world_rank / local_size); - } + local_comm_ = local_comm; + int world_rank = -1, world_size = -1, local_size = -1; + wholememory_communicator_get_rank(&world_rank, global_comm); + wholememory_communicator_get_size(&world_size, global_comm); + wholememory_communicator_get_size(&local_size, local_comm); + WHOLEMEMORY_CHECK(world_size % local_size == 0); + wholememory_split_communicator( + &cross_comm_, global_comm, world_rank % local_size, world_rank / local_size); } - void create_memory() override - { - std::unique_lock mlock(local_comm_->mu); - local_memory_handle_ = new wholememory_handle_(); - local_memory_handle_->handle_id = negotiate_handle_id_with_comm_locked(local_comm_); - determine_node_size(); - WM_COMM_CHECK_ALL_SAME(local_comm_, WM_MEM_OP_CREATE); - wholememory_create_param wcp(node_partition_strategy_.local_mem_size, - WHOLEMEMORY_MT_CONTINUOUS, - location_, - data_granularity_); - WM_COMM_CHECK_ALL_SAME(local_comm_, wcp); - - // TODO chunkded memory type and nvshmem type are both not supported yet. 
- if (is_intranode_communicator(local_comm_) || !SupportEGM()) - if (location_ == WHOLEMEMORY_ML_HOST) { - local_memory_handle_->impl = - new global_mapped_host_wholememory_impl(local_memory_handle_, - node_partition_strategy_.local_mem_size, - local_comm_, - WHOLEMEMORY_MT_CONTINUOUS, - location_, - data_granularity_); - } else if (location_ == WHOLEMEMORY_ML_DEVICE) { - local_memory_handle_->impl = - new continuous_device_wholememory_impl(local_memory_handle_, - node_partition_strategy_.local_mem_size, - local_comm_, - WHOLEMEMORY_MT_CONTINUOUS, - location_, - data_granularity_); - } else { - WHOLEMEMORY_ERROR("unsupported memory location"); - } - else { -#if CUDA_VERSION >= 12030 - local_memory_handle_->impl = - new continuous_mnnvl_wholememory_impl(local_memory_handle_, - node_partition_strategy_.local_mem_size, - local_comm_, - WHOLEMEMORY_MT_CONTINUOUS, - location_, - data_granularity_); -#else - WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINOUS is only supported on CUDA version >= 12.3"); -#endif - } - local_memory_handle_->impl->create_memory(); - local_comm_->wholememory_map.insert( - std::pair(local_memory_handle_->handle_id, local_memory_handle_)); - local_node_memory_pointer_ = local_memory_handle_->impl->get_continuous_mapping_pointer(); - } - [[nodiscard]] wholememory_gref_t get_global_reference() const noexcept override - { - wholememory_gref_t gref{}; - gref.pointer = local_node_memory_pointer_; - gref.stride = 0; - return gref; - } - void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const override - { - get_local_memory_from_handle(local_ptr, local_size, local_offset, local_memory_handle_); - *local_offset += node_partition_strategy_.local_mem_offset; - return; - } - void get_local_node_memory(void** local_node_ptr, - size_t* local_node_size, - size_t* local_node_offset) - { - *local_node_ptr = local_node_memory_pointer_; - *local_node_size = node_partition_strategy_.local_mem_size; - *local_node_offset = node_partition_strategy_.local_mem_offset; - } - [[nodiscard]] size_t get_partition_stride() const override - { - return local_memory_handle_->impl->get_partition_stride(); - } [[nodiscard]] wholememory_comm_t get_local_comm() const { return local_comm_; } [[nodiscard]] wholememory_comm_t get_cross_comm() const { return cross_comm_; } - void destroy_memory() noexcept override { destroy_wholememory(local_memory_handle_); } - bool contains_pointer(const void* ptr) const override - { - uint64_t int_ptr = reinterpret_cast(ptr); - uint64_t int_start_ptr = reinterpret_cast(local_node_memory_pointer_); - return int_ptr >= int_start_ptr && - int_ptr < int_start_ptr + node_partition_strategy_.local_mem_size; - } protected: - void determine_node_size() - { - size_t node_num = comm_->world_size / local_comm_->world_size; - size_t node_id = comm_->world_rank / local_comm_->world_size; - size_t data_slot_count = total_size_ / data_granularity_; - size_t data_slot_per_rank = determine_entry_partition_plan(data_slot_count, comm_->world_size); - size_t data_slot_per_node = data_slot_per_rank * local_comm_->world_size; - size_t node_data_slot_start = std::min(node_id * data_slot_per_node, data_slot_count); - size_t node_data_slot_end = std::min((node_id + 1) * data_slot_per_node, data_slot_count); - size_t node_data_slot_count = node_data_slot_end - node_data_slot_start; - - node_partition_strategy_.local_mem_size = node_data_slot_count * data_granularity_; - node_partition_strategy_.local_mem_offset = node_data_slot_start * data_granularity_; - 
node_partition_strategy_.partition_mem_stride = data_slot_per_node * data_granularity_; - } - - wholememory_handle_t local_memory_handle_; wholememory_comm_t local_comm_; wholememory_comm_t cross_comm_; - void* local_node_memory_pointer_; - struct partition_strategy { - // size of memory this rank is responsible for - size_t local_mem_size = 0; - // start location of the memory this rank is responsible for - size_t local_mem_offset = 0; - size_t partition_mem_stride = 0; - } node_partition_strategy_; }; wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_handle_ptr, @@ -1958,15 +1845,7 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha data_granularity, rank_entry_partition); } - } else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS || - (memory_type == WHOLEMEMORY_MT_HIERARCHY && is_intranode_communicator(comm)) || - (memory_type == WHOLEMEMORY_MT_HIERARCHY && is_intra_mnnvl_communicator(comm))) { - if (memory_type == WHOLEMEMORY_MT_HIERARCHY) { - WHOLEMEMORY_WARN( - "intra-node or intra-mnnvl HIERARCHY memory type is implemented as CONTINUOUS memory " - "type"); - memory_type = WHOLEMEMORY_MT_CONTINUOUS; - } + } else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS) { if (is_intranode_communicator(comm) || !SupportEGM()) { if (memory_location == WHOLEMEMORY_ML_HOST) { whole_memory_handle->impl = new global_mapped_host_wholememory_impl(whole_memory_handle, @@ -2019,37 +1898,19 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha } } else if (memory_type == WHOLEMEMORY_MT_HIERARCHY) { wholememory_comm_t local_comm; - if (SupportEGM() && is_intra_mnnvl_communicator(comm)) { -#if CUDA_VERSION >= 12030 - clique_info_t* clique_info = nullptr; - wholememory_communicator_get_clique_info(clique_info, comm); - WHOLEMEMORY_CHECK_NOTHROW(clique_info->is_in_clique); - wholememory_split_communicator( - &local_comm, comm, clique_info->clique_id, clique_info->clique_rank); - whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle, - total_size, - comm, - local_comm, - memory_type, - memory_location, - data_granularity); -#else - WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3"); -#endif - } else { - int world_rank = -1, local_size = -1; - wholememory_communicator_get_rank(&world_rank, comm); - wholememory_communicator_get_local_size(&local_size, comm); - wholememory_split_communicator( - &local_comm, comm, world_rank / local_size, world_rank % local_size); - whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle, - total_size, - comm, - local_comm, - memory_type, - memory_location, - data_granularity); - } + int world_rank = -1, local_size = -1; + wholememory_communicator_get_rank(&world_rank, comm); + wholememory_communicator_get_local_size(&local_size, comm); + wholememory_split_communicator( + &local_comm, comm, world_rank / local_size, world_rank % local_size); + whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle, + total_size, + comm, + local_comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } else { WHOLEMEMORY_FATAL("Unsupported memory_type (%d) and memory_location (%d).", (int)memory_type, @@ -2194,25 +2055,6 @@ wholememory_error_code_t get_local_memory_from_handle( return WHOLEMEMORY_SUCCESS; } -wholememory_error_code_t get_local_node_memory_from_handle( - void** local_ptr, - size_t* local_size, - size_t* local_offset, - wholememory_handle_t wholememory_handle) 
noexcept
-{
-  if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) {
-    WHOLEMEMORY_ERROR("Only Hierarchy memory type support get_local_node_memory function.");
-    return WHOLEMEMORY_INVALID_INPUT;
-  }
-  if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) {
-    return WHOLEMEMORY_INVALID_INPUT;
-  }
-  hierarchy_wholememory_impl* hierarchy_impl =
-    dynamic_cast<hierarchy_wholememory_impl*>(wholememory_handle->impl);
-  hierarchy_impl->get_local_node_memory(local_ptr, local_size, local_offset);
-  return WHOLEMEMORY_SUCCESS;
-}
-
 wholememory_error_code_t get_rank_memory_from_handle(
   void** rank_memory_ptr,
   size_t* rank_memory_size,
diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp
index 47da7c735..6f85dec24 100644
--- a/cpp/src/wholememory/wholememory.cpp
+++ b/cpp/src/wholememory/wholememory.cpp
@@ -185,15 +185,6 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr,
     local_ptr, local_size, local_offset, wholememory_handle);
 }
 
-wholememory_error_code_t wholememory_get_local_node_memory(void** local_ptr,
-                                                           size_t* local_size,
-                                                           size_t* local_offset,
-                                                           wholememory_handle_t wholememory_handle)
-{
-  return wholememory::get_local_node_memory_from_handle(
-    local_ptr, local_size, local_offset, wholememory_handle);
-}
-
 wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr,
                                                      size_t* rank_memory_size,
                                                      size_t* rank_memory_offset,
diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu
index 7c50d2b2f..ff901d8b5 100644
--- a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu
+++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu
@@ -33,24 +33,50 @@ namespace wholememory_ops {
 
+template <typename IndexT>
+__device__ __forceinline__ int dest_rank(IndexT entry_idx,
+                                         const size_t* embedding_entry_offsets,
+                                         int world_size)
+{
+  size_t total_entry_count        = embedding_entry_offsets[world_size];
+  size_t estimated_entry_per_rank = total_entry_count / world_size;
+  int estimated_rank = max(world_size - 1, int(entry_idx / estimated_entry_per_rank));
+  if (embedding_entry_offsets[estimated_rank] > entry_idx) {
+    for (int i = estimated_rank - 1; i >= 0; i--) {
+      if (embedding_entry_offsets[i] <= entry_idx) { return i; }
+    }
+  } else {
+    for (int i = estimated_rank + 1; i <= world_size; i++) {
+      if (embedding_entry_offsets[i] > entry_idx) { return i - 1; }
+    }
+  }
+  return 0;
+}
+
 template <typename IndexT, int BUCKET_CROSS_OR_LOCAL>
 __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices,
                                                 size_t indice_count,
                                                 int64_t* dev_rank_id_count_ptr,
-                                                size_t embedding_entry_count_per_rank,
+                                                const size_t* embedding_entry_offsets,
                                                 int local_size,
+                                                int world_size,
                                                 int nbucket)
 {
-  extern __shared__ int rank_count_shared[];
+  extern __shared__ char shared_mem[];
+  size_t* embedding_entry_offsets_shared = reinterpret_cast<size_t*>(shared_mem);
+  int* rank_count_shared = reinterpret_cast<int*>(shared_mem + sizeof(size_t) * (world_size + 1));
   for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) {
     rank_count_shared[idx] = 0;
   }
+  for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) {
+    embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx];
+  }
   __syncthreads();
   for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count;
        idx += blockDim.x * gridDim.x) {
     IndexT node_idx = indices[idx];
     if (node_idx < 0) continue;
-    int rank = node_idx / embedding_entry_count_per_rank;
+    int rank = dest_rank(node_idx, embedding_entry_offsets_shared,
world_size); int bucket = 0; if (BUCKET_CROSS_OR_LOCAL == 0) bucket = rank % local_size; @@ -73,7 +99,7 @@ template void bucket_ids_for_hierarchy_temp_func(const void* indices, wholememory_array_description_t indice_desc, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + const size_t* dev_embedding_entry_offsets, int local_size, int cross_size, int bucket_cross_or_local, @@ -85,29 +111,35 @@ void bucket_ids_for_hierarchy_temp_func(const void* indices, block_count = std::min(block_count, sm_count * 4); const IndexT* indices_ptr = static_cast(indices); indices_ptr += indice_desc.storage_offset; - + int world_size = local_size * cross_size; if (bucket_cross_or_local == 0) { int bucket_size = local_size; cudaMemsetAsync(dev_rank_id_count_ptr, 0, sizeof(int64_t) * bucket_size, stream); bucket_ids_for_hierarchy_kernel - <<>>( - indices_ptr, - indice_desc.size, - dev_rank_id_count_ptr, - embedding_entry_count_per_rank, - local_size, - bucket_size); + <<>>(indices_ptr, + indice_desc.size, + dev_rank_id_count_ptr, + dev_embedding_entry_offsets, + local_size, + world_size, + bucket_size); } else { int bucket_size = cross_size; cudaMemsetAsync(dev_rank_id_count_ptr, 0, sizeof(int64_t) * bucket_size, stream); bucket_ids_for_hierarchy_kernel - <<>>( - indices_ptr, - indice_desc.size, - dev_rank_id_count_ptr, - embedding_entry_count_per_rank, - local_size, - bucket_size); + <<>>(indices_ptr, + indice_desc.size, + dev_rank_id_count_ptr, + dev_embedding_entry_offsets, + local_size, + world_size, + bucket_size); } } @@ -119,23 +151,31 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, IndexT* dev_bucket_indices, IndexT* dev_indice_map, const int64_t* dev_rank_id_offset_ptr, - size_t embedding_entry_count_per_rank, + const size_t* embedding_entry_offsets, int local_size, + int world_size, int nbucket, int64_t* dev_bucket_atomic_add_ptr) { constexpr size_t shared_mem_size = 24576; __shared__ char shared_mem[shared_mem_size]; - int* block_bucket_count_shared = reinterpret_cast(shared_mem); - int* block_bucket_atomic_add_shared = reinterpret_cast(shared_mem) + nbucket; + size_t* embedding_entry_offsets_shared = reinterpret_cast(shared_mem); + char* shared_mem_for_bucket = shared_mem + sizeof(size_t) * (world_size + 1); + int* block_bucket_count_shared = reinterpret_cast(shared_mem_for_bucket); + int* block_bucket_atomic_add_shared = reinterpret_cast(shared_mem_for_bucket) + nbucket; IndexT* block_bucket_offset_shared = - reinterpret_cast(shared_mem + 2 * sizeof(int) * nbucket); + reinterpret_cast(shared_mem_for_bucket + 2 * sizeof(int) * nbucket); IndexT* global_bucket_offset_shared = block_bucket_offset_shared + nbucket; - size_t buffer_size = - (shared_mem_size - nbucket * 2 * (sizeof(IndexT) + sizeof(int))) / sizeof(IndexT) / 2; + size_t buffer_size = (shared_mem_size - sizeof(size_t) * (world_size + 1) - + nbucket * 2 * (sizeof(IndexT) + sizeof(int))) / + sizeof(IndexT) / 2; buffer_size = (buffer_size / blockDim.x) * blockDim.x; assert(buffer_size > 0); + for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) { + embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx]; + } + __syncthreads(); IndexT* buffer_load = global_bucket_offset_shared + nbucket; IndexT* buffer_store = buffer_load + buffer_size; @@ -156,7 +196,7 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, buffer_load[i] = indice; int bucket_idx = 0; - int rank = indice / embedding_entry_count_per_rank; + int rank = dest_rank(indice, 
embedding_entry_offsets_shared, world_size); if (BUCKET_CROSS_OR_LOCAL == 0) { bucket_idx = rank % local_size; } else { @@ -188,7 +228,7 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, IndexT load_idx = i + load_offset; if (load_idx >= indice_count) break; int bucket_idx = 0; - int rank = indice / embedding_entry_count_per_rank; + int rank = dest_rank(indice, embedding_entry_offsets_shared, world_size); if (BUCKET_CROSS_OR_LOCAL == 0) { bucket_idx = rank % local_size; } else { @@ -223,7 +263,7 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, void* dev_bucket_indices, void* dev_indice_map, const int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + const size_t* dev_embedding_entry_offsets, int local_size, int cross_size, int bucket_cross_or_local, @@ -241,6 +281,7 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, } else { nbucket = cross_size; } + int world_size = local_size * cross_size; temp_memory_handle dev_rank_id_offset_handle(p_env_fns); int64_t* dev_rank_id_offset_ptr = static_cast(dev_rank_id_offset_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64)); @@ -276,8 +317,9 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, static_cast(dev_bucket_indices), static_cast(dev_indice_map), dev_rank_id_offset_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, + world_size, nbucket, dev_bucket_atomic_add_ptr); else @@ -287,8 +329,9 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, static_cast(dev_bucket_indices), static_cast(dev_indice_map), dev_rank_id_offset_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, + world_size, nbucket, dev_bucket_atomic_add_ptr); ; @@ -302,7 +345,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( void* dev_bucket_indices, void* dev_indice_map, int64_t* host_bucket_id_count, - size_t embedding_entry_count_per_rank, + size_t* dev_embedding_entry_offsets, wholememory_comm_t wm_global_comm, wholememory_comm_t wm_local_comm, int bucket_cross_or_local, @@ -338,7 +381,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( indices, indice_desc, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, cross_size, bucket_cross_or_local, @@ -361,7 +404,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( dev_bucket_indices, dev_indice_map, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, cross_size, bucket_cross_or_local, @@ -384,7 +427,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( wholememory_error_code_t bucket_local_ids_func(void* indices, wholememory_array_description_t indice_desc, int64_t* host_bucket_id_count, - size_t embedding_entry_count_per_rank, + size_t* dev_embedding_entry_offsets, wholememory_comm_t wm_local_comm, wholememory_comm_t wm_cross_comm, wm_thrust_allocator* p_thrust_allocator, @@ -409,7 +452,7 @@ wholememory_error_code_t bucket_local_ids_func(void* indices, indices, indice_desc, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, cross_size, 1, diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h index d6d061c5e..60665c9db 100644 --- a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h +++ 
b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h @@ -29,7 +29,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( void* dev_bucket_indices, void* dev_indice_map, int64_t* host_bucket_id_count, - size_t embedding_entry_count_per_rank, + size_t* dev_embedding_entry_offsets, wholememory_comm_t wm_global_comm, wholememory_comm_t wm_local_comm, int bucket_cross_or_local, // 0: cross, 1: local @@ -40,7 +40,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( wholememory_error_code_t bucket_local_ids_func(void* indices, wholememory_array_description_t indice_desc, int64_t* host_bucket_id_count, - size_t embedding_entry_count_per_rank, + size_t* dev_embedding_entry_offsets, wholememory_comm_t wm_local_comm, wholememory_comm_t wm_cross_comm, wm_thrust_allocator* p_thrust_allocator, diff --git a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu index 3e29d29fc..543bdfd76 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,6 @@ static wholememory_error_code_t wholememory_cross_gather( void* output, wholememory_matrix_description_t output_desc, int64_t* host_bucket_id_count_ptr, - size_t embedding_entry_count_per_rank, wholememory_comm_t wm_local_comm, wholememory_comm_t wm_cross_comm, wm_thrust_allocator* p_thrust_allocator, @@ -130,24 +129,10 @@ wholememory_error_code_t wholememory_gather_hierarchy( wholememory_desc.storage_offset + wholememory_desc.sizes[1] > wholememory_desc.stride) { return WHOLEMEMORY_INVALID_INPUT; } - bool sort_unique_indices = true; wm_thrust_allocator thrust_allocator(p_env_fns); - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); - size_t embedding_entry_size = element_size * wholememory_desc.stride; - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_global_comm; int world_size, world_rank; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_global_comm, wholememory_handle)); @@ -158,8 +143,6 @@ wholememory_error_code_t wholememory_gather_hierarchy( int local_size, local_rank; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_local_communicator(&wm_local_comm, wholememory_handle)); - // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( - // &wm_local_comm, wm_global_comm, world_rank / local_size, world_rank % local_size)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&local_rank, wm_local_comm)); @@ -167,11 +150,35 @@ wholememory_error_code_t wholememory_gather_hierarchy( int cross_size; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_cross_communicator(&wm_cross_comm, wholememory_handle)); - // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( - 
// &wm_cross_comm, wm_global_comm, world_rank % local_size, world_rank / local_size)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); WHOLEMEMORY_CHECK_NOTHROW(world_size == local_size * cross_size); + size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); + size_t embedding_entry_size = element_size * wholememory_desc.stride; + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + std::vector host_embedding_entry_offsets(world_size + 1); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_rank_partition_offsets( + host_embedding_entry_offsets.data(), wholememory_handle)); + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets[i] /= embedding_entry_size; + } + + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets.data(), + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice, + stream)); + temp_memory_handle dev_bucket_indices_handle(p_env_fns); void* dev_bucket_indices_ptr = dev_bucket_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype); @@ -191,7 +198,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( dev_bucket_indices_ptr, dev_bucket_ids_map_ptr, host_bucket_id_count.data(), - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_global_comm, wm_local_comm, 0, @@ -244,7 +251,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( bucket_local_ids_func(cross_gather_indices_handle.pointer(), cross_gather_indices_desc, host_cross_bucket_id_count.data(), - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_local_comm, wm_cross_comm, &thrust_allocator, @@ -262,7 +269,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( cross_gather_indices_ptr, dev_cross_gather_id_map_ptr, host_cross_bucket_id_count.data(), - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_global_comm, wm_local_comm, 1, @@ -286,7 +293,6 @@ wholememory_error_code_t wholememory_gather_hierarchy( dev_cross_gather_buffer_ptr, cross_gather_buffer_desc, host_cross_bucket_id_count.data(), - embedding_entry_count_per_rank, wm_local_comm, wm_cross_comm, &thrust_allocator, diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu index f86c4b93f..506e21ca0 100644 --- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu @@ -300,9 +300,11 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_indices_count(0), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_indices_count(0), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_indices_count(0), + 
WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_indices_count(0), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_memory_location(WHOLEMEMORY_ML_HOST), @@ -312,12 +314,20 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_memory_location(WHOLEMEMORY_ML_HOST), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).use_random_partition(), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).use_random_partition(), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).use_random_partition(), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST) .use_random_partition(), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .use_random_partition(), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_memory_location(WHOLEMEMORY_ML_HOST) @@ -353,18 +363,27 @@ INSTANTIATE_TEST_SUITE_P( .set_embedding_dim(11) .set_embedding_stride(12) .set_indices_count(100005), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_embedding_dim(11) + .set_embedding_stride(12) + .set_indices_count(100005), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(128), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(128), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(128), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(128), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(127), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(127), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(127), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(127), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(129), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(129), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(129), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(129), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(513), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(513), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(513), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(513), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_type(WHOLEMEMORY_DT_HALF), @@ -383,6 +402,9 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_output_type(WHOLEMEMORY_DT_HALF), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_output_type(WHOLEMEMORY_DT_HALF), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_type(WHOLEMEMORY_DT_HALF) @@ -395,6 +417,10 @@ INSTANTIATE_TEST_SUITE_P( 
.set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_type(WHOLEMEMORY_DT_HALF) .set_output_type(WHOLEMEMORY_DT_HALF), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_embedding_type(WHOLEMEMORY_DT_HALF) + .set_output_type(WHOLEMEMORY_DT_HALF), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_indices_type(WHOLEMEMORY_DT_INT64), @@ -404,6 +430,9 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_indices_type(WHOLEMEMORY_DT_INT64), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_indices_type(WHOLEMEMORY_DT_INT64), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_stride(33), @@ -411,9 +440,11 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_stride(33), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_output_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_output_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_output_stride(33), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_output_stride(33), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_type(WHOLEMEMORY_DT_HALF) @@ -426,6 +457,10 @@ INSTANTIATE_TEST_SUITE_P( .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_type(WHOLEMEMORY_DT_HALF) .set_embedding_stride(33), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_embedding_type(WHOLEMEMORY_DT_HALF) + .set_embedding_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) #ifdef WITH_NVSHMEM_SUPPORT , diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py index d242ca2fd..cb30bc932 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py @@ -426,7 +426,7 @@ def create_embedding( if embedding_entry_partition is not None and round_robin_size != 0: print("round_robin_size is ignored because embedding_entry_partition is specified") round_robin_size = 0 - if memory_type == 'hierarchy': # todo: modified + if memory_type == 'hierarchy': # todo: modified comm_backend = comm.distributed_backend if comm_backend == 'nvshmem': raise AssertionError
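
A note on the bucketing change in bucket_ids_for_hierarchy_func.cu: dest_rank() replaces the old uniform node_idx / embedding_entry_count_per_rank mapping with a lookup into an inclusive prefix-offset table embedding_entry_offsets (world_size + 1 entries, with offsets[world_size] holding the total entry count), so unequal per-rank partitions route indices to the right owner. The host-side sketch below illustrates the same estimate-then-linear-scan lookup; owner_rank and the explicit clamp of the initial guess are illustrative choices, not the patch's device code.

// Host-side sketch of the offset-table rank lookup used by the hierarchy bucketing path.
// Names are hypothetical; this only demonstrates the estimate-then-scan idea.
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Returns the rank r such that offsets[r] <= entry_idx < offsets[r + 1].
// offsets has world_size + 1 entries; offsets[world_size] is the total entry count.
int owner_rank(size_t entry_idx, const std::vector<size_t>& offsets)
{
  int world_size = static_cast<int>(offsets.size()) - 1;
  size_t total   = offsets[world_size];
  assert(entry_idx < total);
  // Guess the owner assuming a roughly even split, then correct with a linear scan.
  size_t per_rank_estimate = total / world_size;
  int guess = static_cast<int>(entry_idx / per_rank_estimate);
  if (guess > world_size - 1) guess = world_size - 1;  // clamp into the valid rank range
  if (offsets[guess] > entry_idx) {
    for (int i = guess - 1; i >= 0; i--) {
      if (offsets[i] <= entry_idx) return i;
    }
  } else {
    for (int i = guess + 1; i <= world_size; i++) {
      if (offsets[i] > entry_idx) return i - 1;
    }
  }
  return 0;
}

int main()
{
  // Uneven partition of 10 entries over 3 ranks: rank0 owns [0,4), rank1 [4,7), rank2 [7,10).
  std::vector<size_t> offsets{0, 4, 7, 10};
  size_t queries[] = {0, 3, 4, 6, 7, 9};
  for (size_t idx : queries) {
    std::printf("entry %zu -> rank %d\n", idx, owner_rank(idx, offsets));
  }
  return 0;
}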
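The hierarchy bucketing kernels carve a single shared-memory region into the (world_size + 1)-entry offset table followed by per-bucket counters, and the reorder kernel additionally reserves per-bucket offsets plus a load/store double buffer inside its fixed 24576-byte block. The sketch below reproduces that byte accounting on the host. The kernel launch configuration is not visible in the hunks above, so the counting-kernel size here is an assumption about what the dynamic shared-memory argument would need to cover, not the patch's launch code.

// Host-side sketch of the shared-memory accounting implied by the hierarchy bucketing
// kernels. The formulas mirror the layouts in bucket_ids_for_hierarchy_kernel and
// reorder_ids_for_hierarchy_kernel; the concrete sizes in main are examples only.
#include <cstddef>
#include <cstdio>

// Counting kernel: size_t offsets[world_size + 1] followed by int counters[nbucket].
size_t bucket_kernel_dynamic_shared_bytes(int world_size, int nbucket)
{
  return sizeof(size_t) * (world_size + 1) + sizeof(int) * nbucket;
}

// Reorder kernel: from the fixed 24576-byte block, subtract the offset table and the
// per-bucket bookkeeping (2 * nbucket ints and 2 * nbucket IndexT offsets), then split
// the remainder into a load/store double buffer of IndexT, rounded down to blockDim.x.
template <typename IndexT>
size_t reorder_kernel_buffer_size(int world_size, int nbucket, int block_dim)
{
  constexpr size_t shared_mem_size = 24576;
  size_t buffer_size = (shared_mem_size - sizeof(size_t) * (world_size + 1) -
                        nbucket * 2 * (sizeof(IndexT) + sizeof(int))) /
                       sizeof(IndexT) / 2;
  return (buffer_size / block_dim) * block_dim;
}

int main()
{
  std::printf("counting kernel: %zu bytes of dynamic shared memory\n",
              bucket_kernel_dynamic_shared_bytes(16, 8));
  std::printf("reorder kernel: %zu int64 indices per buffer\n",
              reorder_kernel_buffer_size<long long>(16, 8, 256));
  return 0;
}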
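On the host side, wholememory_gather_hierarchy builds the device offset table from the per-rank byte offsets reported by wholememory_get_rank_partition_offsets, after checking that every offset is a multiple of the embedding entry size (element size times stride). A standalone sketch of that conversion follows; the byte offsets in main are made up for illustration.

// Sketch of converting per-rank byte offsets into entry offsets, as the hierarchy gather
// path does before copying the table to the device with cudaMemcpyAsync.
#include <cstddef>
#include <cstdio>
#include <stdexcept>
#include <vector>

std::vector<size_t> to_entry_offsets(const std::vector<size_t>& byte_offsets,
                                     size_t element_size,
                                     size_t stride)
{
  size_t entry_size = element_size * stride;  // bytes per embedding entry (row)
  std::vector<size_t> entry_offsets(byte_offsets.size());
  for (size_t i = 0; i < byte_offsets.size(); i++) {
    if (byte_offsets[i] % entry_size != 0) {
      throw std::runtime_error("rank offset is not a multiple of the entry size");
    }
    entry_offsets[i] = byte_offsets[i] / entry_size;
  }
  return entry_offsets;
}

int main()
{
  // 3 ranks, float (4-byte) embeddings with stride 128 -> 512 bytes per entry.
  std::vector<size_t> byte_offsets{0, 512 * 40, 512 * 75, 512 * 100};
  auto entry_offsets = to_entry_offsets(byte_offsets, 4, 128);
  for (size_t off : entry_offsets) std::printf("%zu ", off);
  std::printf("\n");  // prints: 0 40 75 100
  return 0;
}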
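Finally, the hierarchy path splits the global communicator with wholememory_split_communicator(&local_comm, comm, world_rank / local_size, world_rank % local_size) and asserts world_size == local_size * cross_size, so every global rank factors into a (local, cross) pair. The small sketch below shows that decomposition and its inverse; the struct and helper names are illustrative, not part of the WholeMemory API.

// Sketch of the rank decomposition behind WHOLEMEMORY_MT_HIERARCHY: a global rank maps
// to a local (intra-node) rank and a cross (inter-node) rank, and back.
#include <cassert>
#include <cstdio>

struct HierarchyRank {
  int local_rank;  // rank within the node        (world_rank % local_size)
  int cross_rank;  // position across nodes       (world_rank / local_size)
};

HierarchyRank decompose(int world_rank, int local_size)
{
  return {world_rank % local_size, world_rank / local_size};
}

int compose(HierarchyRank r, int local_size) { return r.cross_rank * local_size + r.local_rank; }

int main()
{
  int local_size = 8, world_size = 32;
  for (int wr = 0; wr < world_size; wr++) {
    assert(compose(decompose(wr, local_size), local_size) == wr);
  }
  HierarchyRank r = decompose(13, local_size);
  std::printf("rank 13 -> local %d, cross %d\n", r.local_rank, r.cross_rank);
  return 0;
}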