diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index 7cf8f6908..58aa611ce 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -295,7 +295,7 @@ wholememory_error_code_t wholememory_get_local_communicator( /** * Get underlying Wholememory Cross Communicator for "Hierarchy" memory type from WholeMemory Handle * One comminicator includes all rank with a same local id from different nodes - * @param comm : returned Local WholeMemory Communicator + * @param comm : returned Cross WholeMemory Communicator * @param wholememory_handle : WholeMemory Handle * @return : wholememory_error_code_t */ @@ -348,20 +348,6 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr, size_t* local_offset, wholememory_handle_t wholememory_handle); -/** - * Get local node memory from WholeMemory Handle, all gpus of the rank has direct access to the - * memory. Note that this is only available for WHOLEMEMORY_MT_HIERARCHY memory type. 
- * @param local_ptr : returned local node memory pointer - * @param local_size : returned local node memory size - * @param local_offset : returned local node memory offset from WholeMemory - * @param wholememory_handle : WholeMemory Handle - * @return : wholememory_error_code_t - */ -wholememory_error_code_t wholememory_get_local_node_memory(void** local_ptr, - size_t* local_size, - size_t* local_offset, - wholememory_handle_t wholememory_handle); - /** * Get local memory size from WholeMemory Handle of current rank * @param local_size : returned local memory size diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index 48b687683..f5cbd622e 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -326,7 +326,7 @@ class distributed_wholememory_impl : public wholememory_impl { data_granularity, rank_entry_partition) { - WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED); + WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED || type_ == WHOLEMEMORY_MT_HIERARCHY); } void create_memory() override { @@ -647,11 +647,12 @@ class continuous_device_wholememory_impl : public wholememory_impl { data_granularity, rank_entry_partition) { - printf( - "while in continuous device wholememory creation, the memory_type (%d) and memory_location " - "(%d).\n", - (int)memory_type, - (int)memory_location); + // printf( + // "while in continuous device wholememory creation, the memory_type (%d) and memory_location + // " + // "(%d).\n", + // (int)memory_type, + // (int)memory_location); WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS); } void create_memory() override @@ -1752,7 +1753,7 @@ struct wholememory_create_param { size_t min_granularity; }; -class hierarchy_wholememory_impl : public wholememory_impl { +class hierarchy_wholememory_impl : public distributed_wholememory_impl { public: hierarchy_wholememory_impl(wholememory_handle_t wholememory_handle, size_t total_size, @@ -1760,147 +1761,33 
@@ class hierarchy_wholememory_impl : public wholememory_impl { wholememory_comm_t local_comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, global_comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : distributed_wholememory_impl(wholememory_handle, + total_size, + global_comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(memory_type == WHOLEMEMORY_MT_HIERARCHY); - local_comm_ = local_comm; - if (SupportEGM() && is_intra_mnnvl_communicator(global_comm)) { -#if CUDA_VERSION >= 12030 - clique_info_t* clique_info = nullptr; - wholememory_communicator_get_clique_info(clique_info, global_comm); - WHOLEMEMORY_CHECK_NOTHROW(clique_info->is_in_clique); - wholememory_split_communicator( - &cross_comm_, global_comm, clique_info->clique_rank, clique_info->clique_id); -#else - WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3"); -#endif - } else { - int world_rank = -1, local_size = -1; - wholememory_communicator_get_rank(&world_rank, global_comm); - wholememory_communicator_get_local_size(&local_size, global_comm); - wholememory_split_communicator( - &cross_comm_, global_comm, world_rank % local_size, world_rank / local_size); - } + local_comm_ = local_comm; + int world_rank = -1, world_size = -1, local_size = -1; + wholememory_communicator_get_rank(&world_rank, global_comm); + wholememory_communicator_get_size(&world_size, global_comm); + wholememory_communicator_get_size(&local_size, local_comm); + WHOLEMEMORY_CHECK(world_size % local_size == 0); + wholememory_split_communicator( + &cross_comm_, global_comm, world_rank % local_size, world_rank / local_size); } - void create_memory() override - { - std::unique_lock mlock(local_comm_->mu); - local_memory_handle_ = new wholememory_handle_(); - 
local_memory_handle_->handle_id = negotiate_handle_id_with_comm_locked(local_comm_); - determine_node_size(); - WM_COMM_CHECK_ALL_SAME(local_comm_, WM_MEM_OP_CREATE); - wholememory_create_param wcp(node_partition_strategy_.local_mem_size, - WHOLEMEMORY_MT_CONTINUOUS, - location_, - data_granularity_); - WM_COMM_CHECK_ALL_SAME(local_comm_, wcp); - - // TODO chunkded memory type and nvshmem type are both not supported yet. - if (is_intranode_communicator(local_comm_) || !SupportEGM()) - if (location_ == WHOLEMEMORY_ML_HOST) { - local_memory_handle_->impl = - new global_mapped_host_wholememory_impl(local_memory_handle_, - node_partition_strategy_.local_mem_size, - local_comm_, - WHOLEMEMORY_MT_CONTINUOUS, - location_, - data_granularity_); - } else if (location_ == WHOLEMEMORY_ML_DEVICE) { - local_memory_handle_->impl = - new continuous_device_wholememory_impl(local_memory_handle_, - node_partition_strategy_.local_mem_size, - local_comm_, - WHOLEMEMORY_MT_CONTINUOUS, - location_, - data_granularity_); - } else { - WHOLEMEMORY_ERROR("unsupported memory location"); - } - else { -#if CUDA_VERSION >= 12030 - local_memory_handle_->impl = - new continuous_mnnvl_wholememory_impl(local_memory_handle_, - node_partition_strategy_.local_mem_size, - local_comm_, - WHOLEMEMORY_MT_CONTINUOUS, - location_, - data_granularity_); -#else - WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINOUS is only supported on CUDA version >= 12.3"); -#endif - } - local_memory_handle_->impl->create_memory(); - local_comm_->wholememory_map.insert( - std::pair(local_memory_handle_->handle_id, local_memory_handle_)); - local_node_memory_pointer_ = local_memory_handle_->impl->get_continuous_mapping_pointer(); - } - [[nodiscard]] wholememory_gref_t get_global_reference() const noexcept override - { - wholememory_gref_t gref{}; - gref.pointer = local_node_memory_pointer_; - gref.stride = 0; - return gref; - } - void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const override - { - 
get_local_memory_from_handle(local_ptr, local_size, local_offset, local_memory_handle_); - *local_offset += node_partition_strategy_.local_mem_offset; - return; - } - void get_local_node_memory(void** local_node_ptr, - size_t* local_node_size, - size_t* local_node_offset) - { - *local_node_ptr = local_node_memory_pointer_; - *local_node_size = node_partition_strategy_.local_mem_size; - *local_node_offset = node_partition_strategy_.local_mem_offset; - } - [[nodiscard]] size_t get_partition_stride() const override - { - return local_memory_handle_->impl->get_partition_stride(); - } [[nodiscard]] wholememory_comm_t get_local_comm() const { return local_comm_; } [[nodiscard]] wholememory_comm_t get_cross_comm() const { return cross_comm_; } - void destroy_memory() noexcept override { destroy_wholememory(local_memory_handle_); } - bool contains_pointer(const void* ptr) const override - { - uint64_t int_ptr = reinterpret_cast(ptr); - uint64_t int_start_ptr = reinterpret_cast(local_node_memory_pointer_); - return int_ptr >= int_start_ptr && - int_ptr < int_start_ptr + node_partition_strategy_.local_mem_size; - } protected: - void determine_node_size() - { - size_t node_num = comm_->world_size / local_comm_->world_size; - size_t node_id = comm_->world_rank / local_comm_->world_size; - size_t data_slot_count = total_size_ / data_granularity_; - size_t data_slot_per_rank = determine_entry_partition_plan(data_slot_count, comm_->world_size); - size_t data_slot_per_node = data_slot_per_rank * local_comm_->world_size; - size_t node_data_slot_start = std::min(node_id * data_slot_per_node, data_slot_count); - size_t node_data_slot_end = std::min((node_id + 1) * data_slot_per_node, data_slot_count); - size_t node_data_slot_count = node_data_slot_end - node_data_slot_start; - - node_partition_strategy_.local_mem_size = node_data_slot_count * data_granularity_; - node_partition_strategy_.local_mem_offset = node_data_slot_start * data_granularity_; - 
node_partition_strategy_.partition_mem_stride = data_slot_per_node * data_granularity_; - } - - wholememory_handle_t local_memory_handle_; wholememory_comm_t local_comm_; wholememory_comm_t cross_comm_; - void* local_node_memory_pointer_; - struct partition_strategy { - // size of memory this rank is responsible for - size_t local_mem_size = 0; - // start location of the memory this rank is responsible for - size_t local_mem_offset = 0; - size_t partition_mem_stride = 0; - } node_partition_strategy_; }; wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_handle_ptr, @@ -1958,15 +1845,7 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha data_granularity, rank_entry_partition); } - } else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS || - (memory_type == WHOLEMEMORY_MT_HIERARCHY && is_intranode_communicator(comm)) || - (memory_type == WHOLEMEMORY_MT_HIERARCHY && is_intra_mnnvl_communicator(comm))) { - if (memory_type == WHOLEMEMORY_MT_HIERARCHY) { - WHOLEMEMORY_WARN( - "intra-node or intra-mnnvl HIERARCHY memory type is implemented as CONTINUOUS memory " - "type"); - memory_type = WHOLEMEMORY_MT_CONTINUOUS; - } + } else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS) { if (is_intranode_communicator(comm) || !SupportEGM()) { if (memory_location == WHOLEMEMORY_ML_HOST) { whole_memory_handle->impl = new global_mapped_host_wholememory_impl(whole_memory_handle, @@ -2019,37 +1898,19 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha } } else if (memory_type == WHOLEMEMORY_MT_HIERARCHY) { wholememory_comm_t local_comm; - if (SupportEGM() && is_intra_mnnvl_communicator(comm)) { -#if CUDA_VERSION >= 12030 - clique_info_t* clique_info = nullptr; - wholememory_communicator_get_clique_info(clique_info, comm); - WHOLEMEMORY_CHECK_NOTHROW(clique_info->is_in_clique); - wholememory_split_communicator( - &local_comm, comm, clique_info->clique_id, clique_info->clique_rank); - 
whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle, - total_size, - comm, - local_comm, - memory_type, - memory_location, - data_granularity); -#else - WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3"); -#endif - } else { - int world_rank = -1, local_size = -1; - wholememory_communicator_get_rank(&world_rank, comm); - wholememory_communicator_get_local_size(&local_size, comm); - wholememory_split_communicator( - &local_comm, comm, world_rank / local_size, world_rank % local_size); - whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle, - total_size, - comm, - local_comm, - memory_type, - memory_location, - data_granularity); - } + int world_rank = -1, local_size = -1; + wholememory_communicator_get_rank(&world_rank, comm); + wholememory_communicator_get_local_size(&local_size, comm); + wholememory_split_communicator( + &local_comm, comm, world_rank / local_size, world_rank % local_size); + whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle, + total_size, + comm, + local_comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } else { WHOLEMEMORY_FATAL("Unsupported memory_type (%d) and memory_location (%d).", (int)memory_type, @@ -2194,25 +2055,6 @@ wholememory_error_code_t get_local_memory_from_handle( return WHOLEMEMORY_SUCCESS; } -wholememory_error_code_t get_local_node_memory_from_handle( - void** local_ptr, - size_t* local_size, - size_t* local_offset, - wholememory_handle_t wholememory_handle) noexcept -{ - if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) { - WHOLEMEMORY_ERROR("Only Hierarchy memory type support get_local_node_memory function."); - return WHOLEMEMORY_INVALID_INPUT; - } - if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { - return WHOLEMEMORY_INVALID_INPUT; - } - hierarchy_wholememory_impl* hierarchy_impl = - dynamic_cast(wholememory_handle->impl); - 
hierarchy_impl->get_local_node_memory(local_ptr, local_size, local_offset); - return WHOLEMEMORY_SUCCESS; -} - wholememory_error_code_t get_rank_memory_from_handle( void** rank_memory_ptr, size_t* rank_memory_size, diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index 47da7c735..6f85dec24 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -185,15 +185,6 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr, local_ptr, local_size, local_offset, wholememory_handle); } -wholememory_error_code_t wholememory_get_local_node_memory(void** local_ptr, - size_t* local_size, - size_t* local_offset, - wholememory_handle_t wholememory_handle) -{ - return wholememory::get_local_node_memory_from_handle( - local_ptr, local_size, local_offset, wholememory_handle); -} - wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, size_t* rank_memory_size, size_t* rank_memory_offset, diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu index 7c50d2b2f..ff901d8b5 100644 --- a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu +++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu @@ -33,24 +33,50 @@ namespace wholememory_ops { +template +__device__ __forceinline__ int dest_rank(IndexT entry_idx, + const size_t* embedding_entry_offsets, + int world_size) +{ + size_t total_entry_count = embedding_entry_offsets[world_size]; + size_t estimated_entry_per_rank = max(total_entry_count / world_size, size_t(1)); + int estimated_rank = int(min(size_t(world_size - 1), entry_idx / estimated_entry_per_rank)); + if (embedding_entry_offsets[estimated_rank] > entry_idx) { + for (int i = estimated_rank - 1; i >= 0; i--) { + if (embedding_entry_offsets[i] <= entry_idx) { return i; } + } + } else { + for (int i = estimated_rank + 1; i <= world_size; i++) { + if 
(embedding_entry_offsets[i] > entry_idx) { return i - 1; } + } + } + return 0; +} + template __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices, size_t indice_count, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + const size_t* embedding_entry_offsets, int local_size, + int world_size, int nbucket) { - extern __shared__ int rank_count_shared[]; + extern __shared__ char shared_mem[]; + size_t* embedding_entry_offsets_shared = reinterpret_cast(shared_mem); + int* rank_count_shared = reinterpret_cast(shared_mem + sizeof(size_t) * (world_size + 1)); for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) { rank_count_shared[idx] = 0; } + for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) { + embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx]; + } __syncthreads(); for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count; idx += blockDim.x * gridDim.x) { IndexT node_idx = indices[idx]; if (node_idx < 0) continue; - int rank = node_idx / embedding_entry_count_per_rank; + int rank = dest_rank(node_idx, embedding_entry_offsets_shared, world_size); int bucket = 0; if (BUCKET_CROSS_OR_LOCAL == 0) bucket = rank % local_size; @@ -73,7 +99,7 @@ template void bucket_ids_for_hierarchy_temp_func(const void* indices, wholememory_array_description_t indice_desc, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + const size_t* dev_embedding_entry_offsets, int local_size, int cross_size, int bucket_cross_or_local, @@ -85,29 +111,35 @@ void bucket_ids_for_hierarchy_temp_func(const void* indices, block_count = std::min(block_count, sm_count * 4); const IndexT* indices_ptr = static_cast(indices); indices_ptr += indice_desc.storage_offset; - + int world_size = local_size * cross_size; if (bucket_cross_or_local == 0) { int bucket_size = local_size; cudaMemsetAsync(dev_rank_id_count_ptr, 0, sizeof(int64_t) * bucket_size, stream); bucket_ids_for_hierarchy_kernel - 
<<>>( - indices_ptr, - indice_desc.size, - dev_rank_id_count_ptr, - embedding_entry_count_per_rank, - local_size, - bucket_size); + <<>>(indices_ptr, + indice_desc.size, + dev_rank_id_count_ptr, + dev_embedding_entry_offsets, + local_size, + world_size, + bucket_size); } else { int bucket_size = cross_size; cudaMemsetAsync(dev_rank_id_count_ptr, 0, sizeof(int64_t) * bucket_size, stream); bucket_ids_for_hierarchy_kernel - <<>>( - indices_ptr, - indice_desc.size, - dev_rank_id_count_ptr, - embedding_entry_count_per_rank, - local_size, - bucket_size); + <<>>(indices_ptr, + indice_desc.size, + dev_rank_id_count_ptr, + dev_embedding_entry_offsets, + local_size, + world_size, + bucket_size); } } @@ -119,23 +151,31 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, IndexT* dev_bucket_indices, IndexT* dev_indice_map, const int64_t* dev_rank_id_offset_ptr, - size_t embedding_entry_count_per_rank, + const size_t* embedding_entry_offsets, int local_size, + int world_size, int nbucket, int64_t* dev_bucket_atomic_add_ptr) { constexpr size_t shared_mem_size = 24576; __shared__ char shared_mem[shared_mem_size]; - int* block_bucket_count_shared = reinterpret_cast(shared_mem); - int* block_bucket_atomic_add_shared = reinterpret_cast(shared_mem) + nbucket; + size_t* embedding_entry_offsets_shared = reinterpret_cast(shared_mem); + char* shared_mem_for_bucket = shared_mem + sizeof(size_t) * (world_size + 1); + int* block_bucket_count_shared = reinterpret_cast(shared_mem_for_bucket); + int* block_bucket_atomic_add_shared = reinterpret_cast(shared_mem_for_bucket) + nbucket; IndexT* block_bucket_offset_shared = - reinterpret_cast(shared_mem + 2 * sizeof(int) * nbucket); + reinterpret_cast(shared_mem_for_bucket + 2 * sizeof(int) * nbucket); IndexT* global_bucket_offset_shared = block_bucket_offset_shared + nbucket; - size_t buffer_size = - (shared_mem_size - nbucket * 2 * (sizeof(IndexT) + sizeof(int))) / sizeof(IndexT) / 2; + size_t buffer_size = (shared_mem_size 
- sizeof(size_t) * (world_size + 1) - + nbucket * 2 * (sizeof(IndexT) + sizeof(int))) / + sizeof(IndexT) / 2; buffer_size = (buffer_size / blockDim.x) * blockDim.x; assert(buffer_size > 0); + for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) { + embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx]; + } + __syncthreads(); IndexT* buffer_load = global_bucket_offset_shared + nbucket; IndexT* buffer_store = buffer_load + buffer_size; @@ -156,7 +196,7 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, buffer_load[i] = indice; int bucket_idx = 0; - int rank = indice / embedding_entry_count_per_rank; + int rank = dest_rank(indice, embedding_entry_offsets_shared, world_size); if (BUCKET_CROSS_OR_LOCAL == 0) { bucket_idx = rank % local_size; } else { @@ -188,7 +228,7 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, IndexT load_idx = i + load_offset; if (load_idx >= indice_count) break; int bucket_idx = 0; - int rank = indice / embedding_entry_count_per_rank; + int rank = dest_rank(indice, embedding_entry_offsets_shared, world_size); if (BUCKET_CROSS_OR_LOCAL == 0) { bucket_idx = rank % local_size; } else { @@ -223,7 +263,7 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, void* dev_bucket_indices, void* dev_indice_map, const int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + const size_t* dev_embedding_entry_offsets, int local_size, int cross_size, int bucket_cross_or_local, @@ -241,6 +281,7 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, } else { nbucket = cross_size; } + int world_size = local_size * cross_size; temp_memory_handle dev_rank_id_offset_handle(p_env_fns); int64_t* dev_rank_id_offset_ptr = static_cast(dev_rank_id_offset_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64)); @@ -276,8 +317,9 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, static_cast(dev_bucket_indices), static_cast(dev_indice_map), 
dev_rank_id_offset_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, + world_size, nbucket, dev_bucket_atomic_add_ptr); else @@ -287,8 +329,9 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices, static_cast(dev_bucket_indices), static_cast(dev_indice_map), dev_rank_id_offset_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, + world_size, nbucket, dev_bucket_atomic_add_ptr); ; @@ -302,7 +345,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( void* dev_bucket_indices, void* dev_indice_map, int64_t* host_bucket_id_count, - size_t embedding_entry_count_per_rank, + size_t* dev_embedding_entry_offsets, wholememory_comm_t wm_global_comm, wholememory_comm_t wm_local_comm, int bucket_cross_or_local, @@ -338,7 +381,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( indices, indice_desc, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, cross_size, bucket_cross_or_local, @@ -361,7 +404,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( dev_bucket_indices, dev_indice_map, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, cross_size, bucket_cross_or_local, @@ -384,7 +427,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( wholememory_error_code_t bucket_local_ids_func(void* indices, wholememory_array_description_t indice_desc, int64_t* host_bucket_id_count, - size_t embedding_entry_count_per_rank, + size_t* dev_embedding_entry_offsets, wholememory_comm_t wm_local_comm, wholememory_comm_t wm_cross_comm, wm_thrust_allocator* p_thrust_allocator, @@ -409,7 +452,7 @@ wholememory_error_code_t bucket_local_ids_func(void* indices, indices, indice_desc, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets, local_size, cross_size, 1, diff --git 
a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h index d6d061c5e..60665c9db 100644 --- a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h +++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h @@ -29,7 +29,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( void* dev_bucket_indices, void* dev_indice_map, int64_t* host_bucket_id_count, - size_t embedding_entry_count_per_rank, + size_t* dev_embedding_entry_offsets, wholememory_comm_t wm_global_comm, wholememory_comm_t wm_local_comm, int bucket_cross_or_local, // 0: cross, 1: local @@ -40,7 +40,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( wholememory_error_code_t bucket_local_ids_func(void* indices, wholememory_array_description_t indice_desc, int64_t* host_bucket_id_count, - size_t embedding_entry_count_per_rank, + size_t* dev_embedding_entry_offsets, wholememory_comm_t wm_local_comm, wholememory_comm_t wm_cross_comm, wm_thrust_allocator* p_thrust_allocator, diff --git a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu index 3e29d29fc..543bdfd76 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -39,7 +39,6 @@ static wholememory_error_code_t wholememory_cross_gather( void* output, wholememory_matrix_description_t output_desc, int64_t* host_bucket_id_count_ptr, - size_t embedding_entry_count_per_rank, wholememory_comm_t wm_local_comm, wholememory_comm_t wm_cross_comm, wm_thrust_allocator* p_thrust_allocator, @@ -130,24 +129,10 @@ wholememory_error_code_t wholememory_gather_hierarchy( wholememory_desc.storage_offset + wholememory_desc.sizes[1] > wholememory_desc.stride) { return WHOLEMEMORY_INVALID_INPUT; } - bool sort_unique_indices = true; wm_thrust_allocator thrust_allocator(p_env_fns); - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); - size_t embedding_entry_size = element_size * wholememory_desc.stride; - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_global_comm; int world_size, world_rank; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_global_comm, wholememory_handle)); @@ -158,8 +143,6 @@ wholememory_error_code_t wholememory_gather_hierarchy( int local_size, local_rank; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_local_communicator(&wm_local_comm, wholememory_handle)); - // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( - // &wm_local_comm, wm_global_comm, world_rank / local_size, world_rank % local_size)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&local_rank, wm_local_comm)); @@ -167,11 +150,35 @@ wholememory_error_code_t 
wholememory_gather_hierarchy( int cross_size; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_cross_communicator(&wm_cross_comm, wholememory_handle)); - // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator( - // &wm_cross_comm, wm_global_comm, world_rank % local_size, world_rank / local_size)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); WHOLEMEMORY_CHECK_NOTHROW(world_size == local_size * cross_size); + size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); + size_t embedding_entry_size = element_size * wholememory_desc.stride; + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + std::vector host_embedding_entry_offsets(world_size + 1); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_rank_partition_offsets( + host_embedding_entry_offsets.data(), wholememory_handle)); + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets[i] /= embedding_entry_size; + } + + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets.data(), + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice, + stream)); + temp_memory_handle dev_bucket_indices_handle(p_env_fns); void* dev_bucket_indices_ptr = dev_bucket_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype); @@ -191,7 +198,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( dev_bucket_indices_ptr, dev_bucket_ids_map_ptr, host_bucket_id_count.data(), - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_global_comm, wm_local_comm, 
0, @@ -244,7 +251,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( bucket_local_ids_func(cross_gather_indices_handle.pointer(), cross_gather_indices_desc, host_cross_bucket_id_count.data(), - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_local_comm, wm_cross_comm, &thrust_allocator, @@ -262,7 +269,7 @@ wholememory_error_code_t wholememory_gather_hierarchy( cross_gather_indices_ptr, dev_cross_gather_id_map_ptr, host_cross_bucket_id_count.data(), - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_global_comm, wm_local_comm, 1, @@ -286,7 +293,6 @@ wholememory_error_code_t wholememory_gather_hierarchy( dev_cross_gather_buffer_ptr, cross_gather_buffer_desc, host_cross_bucket_id_count.data(), - embedding_entry_count_per_rank, wm_local_comm, wm_cross_comm, &thrust_allocator, diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu index f86c4b93f..506e21ca0 100644 --- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu @@ -300,9 +300,11 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_indices_count(0), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_indices_count(0), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_indices_count(0), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_indices_count(0), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_memory_location(WHOLEMEMORY_ML_HOST), @@ -312,12 +314,20 @@ INSTANTIATE_TEST_SUITE_P( 
WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_memory_location(WHOLEMEMORY_ML_HOST), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).use_random_partition(), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).use_random_partition(), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).use_random_partition(), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST) .use_random_partition(), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .use_random_partition(), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_memory_location(WHOLEMEMORY_ML_HOST) @@ -353,18 +363,27 @@ INSTANTIATE_TEST_SUITE_P( .set_embedding_dim(11) .set_embedding_stride(12) .set_indices_count(100005), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_embedding_dim(11) + .set_embedding_stride(12) + .set_indices_count(100005), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(128), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(128), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(128), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(128), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(127), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(127), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(127), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(127), 
WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(129), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(129), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(129), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(129), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(513), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(513), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(513), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(513), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_type(WHOLEMEMORY_DT_HALF), @@ -383,6 +402,9 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_output_type(WHOLEMEMORY_DT_HALF), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_output_type(WHOLEMEMORY_DT_HALF), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_type(WHOLEMEMORY_DT_HALF) @@ -395,6 +417,10 @@ INSTANTIATE_TEST_SUITE_P( .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_type(WHOLEMEMORY_DT_HALF) .set_output_type(WHOLEMEMORY_DT_HALF), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_embedding_type(WHOLEMEMORY_DT_HALF) + .set_output_type(WHOLEMEMORY_DT_HALF), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_indices_type(WHOLEMEMORY_DT_INT64), @@ -404,6 +430,9 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_indices_type(WHOLEMEMORY_DT_INT64), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_indices_type(WHOLEMEMORY_DT_INT64), 
WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_stride(33), @@ -411,9 +440,11 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_stride(33), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_output_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_output_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_output_stride(33), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_output_stride(33), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_type(WHOLEMEMORY_DT_HALF) @@ -426,6 +457,10 @@ INSTANTIATE_TEST_SUITE_P( .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_type(WHOLEMEMORY_DT_HALF) .set_embedding_stride(33), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_embedding_type(WHOLEMEMORY_DT_HALF) + .set_embedding_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) #ifdef WITH_NVSHMEM_SUPPORT ,