diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu
index 6ad83a49a..7c50d2b2f 100644
--- a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu
+++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu
@@ -33,16 +33,16 @@

 namespace wholememory_ops {

-template <typename IndexT, int CROSS_OR_LOCAL>
+template <typename IndexT, int BUCKET_CROSS_OR_LOCAL>
 __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices,
                                                 size_t indice_count,
                                                 int64_t* dev_rank_id_count_ptr,
                                                 size_t embedding_entry_count_per_rank,
                                                 int local_size,
-                                                int bucket_size)
+                                                int nbucket)
 {
   extern __shared__ int rank_count_shared[];
-  for (int idx = threadIdx.x; idx < bucket_size; idx += blockDim.x) {
+  for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) {
     rank_count_shared[idx] = 0;
   }
   __syncthreads();
@@ -52,9 +52,9 @@ __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices,
     if (node_idx < 0) continue;
     int rank   = node_idx / embedding_entry_count_per_rank;
     int bucket = 0;
-    if (CROSS_OR_LOCAL == 0)  // bucket cross ranks
+    if (BUCKET_CROSS_OR_LOCAL == 0)
       bucket = rank % local_size;
-    else  // bucket local ranks
+    else
       bucket = rank / local_size;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
     atomicAdd_block(&rank_count_shared[bucket], 1);
 #else
@@ -63,7 +63,7 @@ __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices,
 #endif
   }
   __syncthreads();
-  for (int idx = threadIdx.x; idx < bucket_size; idx += blockDim.x) {
+  for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) {
     atomicAdd(reinterpret_cast<unsigned long long*>(dev_rank_id_count_ptr) + idx,
               static_cast<unsigned long long>(rank_count_shared[idx]));
   }
@@ -113,7 +113,7 @@ void bucket_ids_for_hierarchy_temp_func(const void* indices,

 REGISTER_DISPATCH_ONE_TYPE(BucketIdsForHierarchy, bucket_ids_for_hierarchy_temp_func, SINT3264)

-template <typename IndexT>
+template <typename IndexT, int BUCKET_CROSS_OR_LOCAL>
 __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices,
                                                  size_t indice_count,
                                                  IndexT* dev_bucket_indices,
@@ -121,9 +121,9 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices,
                                                  const int64_t* dev_rank_id_offset_ptr,
                                                  size_t embedding_entry_count_per_rank,
                                                  int local_size,
+                                                 int nbucket,
                                                  int64_t* dev_bucket_atomic_add_ptr)
 {
-  int nbucket                      = local_size;
   constexpr size_t shared_mem_size = 24576;
   __shared__ char shared_mem[shared_mem_size];
   int* block_bucket_count_shared   = reinterpret_cast<int*>(shared_mem);
@@ -155,7 +155,13 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices,
       IndexT indice = indices[load_idx];

       buffer_load[i] = indice;
-      int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size;
+      int bucket_idx = 0;
+      int rank       = indice / embedding_entry_count_per_rank;
+      if (BUCKET_CROSS_OR_LOCAL == 0) {
+        bucket_idx = rank % local_size;
+      } else {
+        bucket_idx = rank / local_size;
+      }
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
       atomicAdd_block(&block_bucket_count_shared[bucket_idx], 1);
 #else
@@ -181,7 +187,13 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices,
       IndexT indice   = buffer_load[i];
       IndexT load_idx = i + load_offset;
       if (load_idx >= indice_count) break;
-      int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size;
+      int bucket_idx = 0;
+      int rank       = indice / embedding_entry_count_per_rank;
+      if (BUCKET_CROSS_OR_LOCAL == 0) {
+        bucket_idx = rank % local_size;
+      } else {
+        bucket_idx = rank / local_size;
+      }
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
       int block_bucket_inc = atomicAdd_block(&block_bucket_atomic_add_shared[bucket_idx], 1);
 #else
@@ -213,6 +225,8 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices,
                                          const int64_t* dev_rank_id_count_ptr,
                                          size_t embedding_entry_count_per_rank,
                                          int local_size,
+                                         int cross_size,
+                                         int bucket_cross_or_local,
                                          wm_thrust_allocator* p_thrust_allocator,
                                          wholememory_env_func_t* p_env_fns,
                                          int sm_count,
@@ -221,44 +235,63 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices,
   WHOLEMEMORY_CHECK(indice_desc.storage_offset == 0);
   WHOLEMEMORY_CHECK(indice_desc.dtype == WHOLEMEMORY_DT_INT ||
                     indice_desc.dtype == WHOLEMEMORY_DT_INT64);
-
+  int nbucket = 0;
+  if (bucket_cross_or_local == 0) {
+    nbucket = local_size;
+  } else {
+    nbucket = cross_size;
+  }
   temp_memory_handle dev_rank_id_offset_handle(p_env_fns);
-  int64_t* dev_rank_id_offset_ptr = static_cast<int64_t*>(
-    dev_rank_id_offset_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64));
+  int64_t* dev_rank_id_offset_ptr =
+    static_cast<int64_t*>(dev_rank_id_offset_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64));
   void* cub_temp_storage    = NULL;
   size_t temp_storage_bytes = 0;
   cub::DeviceScan::ExclusiveSum(cub_temp_storage,
                                 temp_storage_bytes,
                                 dev_rank_id_count_ptr,
                                 dev_rank_id_offset_ptr,
-                                local_size,
+                                nbucket,
                                 stream);
   cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes);
   cub::DeviceScan::ExclusiveSum(cub_temp_storage,
                                 temp_storage_bytes,
                                 dev_rank_id_count_ptr,
                                 dev_rank_id_offset_ptr,
-                                local_size,
+                                nbucket,
                                 stream);
   p_thrust_allocator->deallocate(reinterpret_cast<char*>(cub_temp_storage), temp_storage_bytes);
   temp_memory_handle dev_bucket_atomic_add_handle(p_env_fns);
   int64_t* dev_bucket_atomic_add_ptr = static_cast<int64_t*>(
-    dev_bucket_atomic_add_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64));
-  cudaMemsetAsync((void*)dev_bucket_atomic_add_ptr, 0, sizeof(int64_t) * local_size, stream);
+    dev_bucket_atomic_add_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64));
+  cudaMemsetAsync((void*)dev_bucket_atomic_add_ptr, 0, sizeof(int64_t) * nbucket, stream);
   static constexpr int BLOCK_SIZE = 128;
   int block_count = wholememory::div_rounding_up_unsafe(indice_desc.size, BLOCK_SIZE);
   block_count     = std::min(block_count, sm_count * 4);
-  reorder_ids_for_hierarchy_kernel<<<block_count, BLOCK_SIZE, 0, stream>>>(
-    static_cast<const IndexT*>(indices),
-    indice_desc.size,
-    static_cast<IndexT*>(dev_bucket_indices),
-    static_cast<IndexT*>(dev_indice_map),
-    dev_rank_id_offset_ptr,
-    embedding_entry_count_per_rank,
-    local_size,
-    dev_bucket_atomic_add_ptr);
+  if (bucket_cross_or_local == 0)
+    reorder_ids_for_hierarchy_kernel<IndexT, 0>
+      <<<block_count, BLOCK_SIZE, 0, stream>>>(static_cast<const IndexT*>(indices),
+                                               indice_desc.size,
+                                               static_cast<IndexT*>(dev_bucket_indices),
+                                               static_cast<IndexT*>(dev_indice_map),
+                                               dev_rank_id_offset_ptr,
+                                               embedding_entry_count_per_rank,
+                                               local_size,
+                                               nbucket,
+                                               dev_bucket_atomic_add_ptr);
+  else
+    reorder_ids_for_hierarchy_kernel<IndexT, 1>
+      <<<block_count, BLOCK_SIZE, 0, stream>>>(static_cast<const IndexT*>(indices),
+                                               indice_desc.size,
+                                               static_cast<IndexT*>(dev_bucket_indices),
+                                               static_cast<IndexT*>(dev_indice_map),
+                                               dev_rank_id_offset_ptr,
+                                               embedding_entry_count_per_rank,
+                                               local_size,
+                                               nbucket,
+                                               dev_bucket_atomic_add_ptr);
+  ;
 }

 REGISTER_DISPATCH_ONE_TYPE(ReorderIdsForHierarchy, reorder_ids_for_hierarchy_temp_func, SINT3264)
@@ -272,23 +305,33 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
   size_t embedding_entry_count_per_rank,
   wholememory_comm_t wm_global_comm,
   wholememory_comm_t wm_local_comm,
+  int bucket_cross_or_local,
   wm_thrust_allocator* p_thrust_allocator,
   wholememory_env_func_t* p_env_fns,
   cudaStream_t stream)
 {
   if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; }
-  int world_size, local_size;
+  int world_size, local_size, cross_size;
   WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_global_comm));
   WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm));
   WHOLEMEMORY_CHECK_NOTHROW(world_size % local_size == 0);
+  cross_size = world_size / local_size;
+  WHOLEMEMORY_EXPECTS_NOTHROW(bucket_cross_or_local == 0 || bucket_cross_or_local == 1,
+                              "param bucket_cross_or_local must be 0 or 1, 0: cross, 1: local");
+  int nbucket = 0;
+  if (bucket_cross_or_local == 0) {  // bucket by cross id
+    nbucket = local_size;
+  } else {  // bucket by local id
+    nbucket = cross_size;
+  }
   constexpr int K_DEFAULT_SM_COUNT = 108;
   auto prop                        = get_device_prop(-1);
   int sm_count = (prop != nullptr) ? prop->multiProcessorCount : K_DEFAULT_SM_COUNT;
   temp_memory_handle dev_rank_id_count_handle(p_env_fns);
   int64_t* dev_rank_id_count_ptr =
-    static_cast<int64_t*>(dev_rank_id_count_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64));
-  cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * local_size, stream);
+    static_cast<int64_t*>(dev_rank_id_count_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64));
+  cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * nbucket, stream);
   try {
     DISPATCH_ONE_TYPE(indice_desc.dtype,
                       BucketIdsForHierarchy,
@@ -297,8 +340,8 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
                       dev_rank_id_count_ptr,
                       embedding_entry_count_per_rank,
                       local_size,
-                      0,  // ignore
-                      0,
+                      cross_size,
+                      bucket_cross_or_local,
                       sm_count,
                       stream);
   } catch (wholememory::cuda_error& wce) {
@@ -307,7 +350,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
   }
   WM_CUDA_CHECK_NO_THROW(cudaMemcpyAsync(host_bucket_id_count,
                                          dev_rank_id_count_ptr,
-                                         local_size * sizeof(int64_t),
+                                         nbucket * sizeof(int64_t),
                                          cudaMemcpyDeviceToHost,
                                          stream));
   try {
@@ -320,6 +363,8 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
                       dev_rank_id_count_ptr,
                       embedding_entry_count_per_rank,
                       local_size,
+                      cross_size,
+                      bucket_cross_or_local,
                       p_thrust_allocator,
                       p_env_fns,
                       sm_count,
diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h
index a86a9945e..d6d061c5e 100644
--- a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h
+++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h
@@ -32,6 +32,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
   size_t embedding_entry_count_per_rank,
   wholememory_comm_t wm_global_comm,
   wholememory_comm_t wm_local_comm,
+  int bucket_cross_or_local,  // 0: cross, 1: local
   wm_thrust_allocator* p_thrust_allocator,
   wholememory_env_func_t* p_env_fns,
   cudaStream_t stream);
diff --git a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu
index 808ebe768..3e29d29fc 100644
--- a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu
+++ b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu
@@ -38,6 +38,7 @@ static wholememory_error_code_t wholememory_cross_gather(
   wholememory_array_description_t indice_desc,
   void* output,
   wholememory_matrix_description_t output_desc,
+  int64_t* host_bucket_id_count_ptr,
   size_t embedding_entry_count_per_rank,
   wholememory_comm_t wm_local_comm,
   wholememory_comm_t wm_cross_comm,
@@ -48,27 +49,15 @@ static wholememory_error_code_t wholememory_cross_gather(
 {
   int cross_size;
   WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm));
-  // bucket ids
-  std::vector<int64_t> host_bucket_id_count(cross_size, 0);
   std::vector<int64_t> host_bucket_id_offset(cross_size);
   std::vector<int64_t> host_recv_id_count(cross_size, 0);
   std::vector<int64_t> host_recv_id_offset(cross_size);
-  bucket_local_ids_func(indices,
-                        indice_desc,
-                        host_bucket_id_count.data(),
-                        embedding_entry_count_per_rank,
-                        wm_local_comm,
-                        wm_cross_comm,
-                        p_thrust_allocator,
-                        p_env_fns,
-                        stream);
-  WM_CUDA_CHECK(cudaStreamSynchronize(stream));
   // exchange node count
   wm_cross_comm->host_alltoall(
-    host_bucket_id_count.data(), host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64);
+    host_bucket_id_count_ptr, host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64);
   host_bucket_id_offset[0] = 0;
   for (int i = 1; i < cross_size; i++)
-    host_bucket_id_offset[i] = host_bucket_id_offset[i - 1] + host_bucket_id_count[i - 1];
+    host_bucket_id_offset[i] = host_bucket_id_offset[i - 1] + host_bucket_id_count_ptr[i - 1];
   wm_cross_comm->sync_stream();
   // exchange indices
   int64_t total_recv_count = 0;
@@ -81,7 +70,7 @@ static wholememory_error_code_t wholememory_cross_gather(
     dev_recv_bucket_indices_handle.device_malloc(total_recv_count, indice_desc.dtype);
   wm_cross_comm->alltoallv(indices,
                            dev_recv_bucket_indices_ptr,
-                           reinterpret_cast<const size_t*>(host_bucket_id_count.data()),
+                           reinterpret_cast<const size_t*>(host_bucket_id_count_ptr),
                            reinterpret_cast<const size_t*>(host_bucket_id_offset.data()),
                            reinterpret_cast<const size_t*>(host_recv_id_count.data()),
                            reinterpret_cast<const size_t*>(host_recv_id_offset.data()),
@@ -117,7 +106,7 @@ static wholememory_error_code_t wholememory_cross_gather(
     wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype);
   WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_local_gather_buffer_ptr,
                                                            host_recv_id_count.data(),
-                                                           host_bucket_id_count.data(),
+                                                           host_bucket_id_count_ptr,
                                                            output,
                                                            output_embedding_size,
                                                            wm_cross_comm,
@@ -142,6 +131,8 @@ wholememory_error_code_t wholememory_gather_hierarchy(
     return WHOLEMEMORY_INVALID_INPUT;
   }

+  bool sort_unique_indices = true;
+
   wm_thrust_allocator thrust_allocator(p_env_fns);

   size_t embedding_size_per_rank;
@@ -168,7 +159,7 @@ wholememory_error_code_t wholememory_gather_hierarchy(
   WHOLEMEMORY_RETURN_ON_FAIL(
     wholememory_get_local_communicator(&wm_local_comm, wholememory_handle));
   // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator(
-  //  &wm_local_comm, wm_global_comm, world_rank / local_size, world_rank % local_size));
+  //   &wm_local_comm, wm_global_comm, world_rank / local_size, world_rank % local_size));
   WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm));
   WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&local_rank, wm_local_comm));

@@ -177,7 +168,7 @@ wholememory_error_code_t wholememory_gather_hierarchy(
   WHOLEMEMORY_RETURN_ON_FAIL(
     wholememory_get_cross_communicator(&wm_cross_comm, wholememory_handle));
   // WHOLEMEMORY_RETURN_ON_FAIL(wholememory_split_communicator(
-  //  &wm_cross_comm, wm_global_comm, world_rank % local_size, world_rank / local_size));
+  //   &wm_cross_comm, wm_global_comm, world_rank % local_size, world_rank / local_size));
   WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm));
   WHOLEMEMORY_CHECK_NOTHROW(world_size == local_size * cross_size);

@@ -203,6 +194,7 @@ wholememory_error_code_t wholememory_gather_hierarchy(
                                               embedding_entry_count_per_rank,
                                               wm_global_comm,
                                               wm_local_comm,
+                                              0,
                                               &thrust_allocator,
                                               p_env_fns,
                                               stream));
@@ -235,31 +227,65 @@ wholememory_error_code_t wholememory_gather_hierarchy(
                           stream);
   wm_local_comm->sync_stream(stream);
   WM_CUDA_CHECK(cudaGetLastError());
-  // sort unique recv indices
-  temp_memory_handle sort_unique_indices_handle(p_env_fns);
-  wholememory_array_description_t sort_unique_indice_desc;
-  temp_memory_handle dev_sort_unique_ids_map_handle(p_env_fns);
-  sort_unique_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr,
-                                     recv_bucket_indices_desc,
-                                     &sort_unique_indices_handle,
-                                     &sort_unique_indice_desc,
-                                     &dev_sort_unique_ids_map_handle,
-                                     &thrust_allocator,
-                                     p_env_fns,
-                                     stream);
+  // sort unique / bucket recv indices
+  temp_memory_handle cross_gather_indices_handle(p_env_fns);
+  wholememory_array_description_t cross_gather_indices_desc;
+  temp_memory_handle dev_cross_gather_id_map_handle(p_env_fns);
+  std::vector<int64_t> host_cross_bucket_id_count(cross_size, 0);
+  if (sort_unique_indices) {
+    sort_unique_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr,
+                                       recv_bucket_indices_desc,
+                                       &cross_gather_indices_handle,
+                                       &cross_gather_indices_desc,
+                                       &dev_cross_gather_id_map_handle,
+                                       &thrust_allocator,
+                                       p_env_fns,
+                                       stream);
+    bucket_local_ids_func(cross_gather_indices_handle.pointer(),
+                          cross_gather_indices_desc,
+                          host_cross_bucket_id_count.data(),
+                          embedding_entry_count_per_rank,
+                          wm_local_comm,
+                          wm_cross_comm,
+                          &thrust_allocator,
+                          p_env_fns,
+                          stream);
+  } else {
+    void* cross_gather_indices_ptr = cross_gather_indices_handle.device_malloc(
+      recv_bucket_indices_desc.size, recv_bucket_indices_desc.dtype);
+    void* dev_cross_gather_id_map_ptr = dev_cross_gather_id_map_handle.device_malloc(
+      recv_bucket_indices_desc.size, recv_bucket_indices_desc.dtype);
+    cross_gather_indices_desc = recv_bucket_indices_desc;
+    WHOLEMEMORY_RETURN_ON_FAIL(
+      bucket_and_reorder_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr,
+                                                recv_bucket_indices_desc,
+                                                cross_gather_indices_ptr,
+                                                dev_cross_gather_id_map_ptr,
+                                                host_cross_bucket_id_count.data(),
+                                                embedding_entry_count_per_rank,
+                                                wm_global_comm,
+                                                wm_local_comm,
+                                                1,
+                                                &thrust_allocator,
+                                                p_env_fns,
+                                                stream));
+  }
+  WM_CUDA_CHECK(cudaStreamSynchronize(stream));
   // cross gather
   temp_memory_handle dev_cross_gather_buffer_handle(p_env_fns);
   void* dev_cross_gather_buffer_ptr = dev_cross_gather_buffer_handle.device_malloc(
-    wholememory_desc.sizes[1] * sort_unique_indice_desc.size, output_desc.dtype);
-  int64_t cross_gather_buffer_size[2] = {sort_unique_indice_desc.size, wholememory_desc.sizes[1]};
+    wholememory_desc.sizes[1] * cross_gather_indices_desc.size, output_desc.dtype);
+  int64_t cross_gather_buffer_size[2] = {cross_gather_indices_desc.size,
+                                         wholememory_desc.sizes[1]};
   wholememory_matrix_description_t cross_gather_buffer_desc = wholememory_create_matrix_desc(
     cross_gather_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype);
   wholememory_cross_gather(wholememory_handle,
                            wholememory_desc,
-                           sort_unique_indices_handle.pointer(),
-                           sort_unique_indice_desc,
+                           cross_gather_indices_handle.pointer(),
+                           cross_gather_indices_desc,
                            dev_cross_gather_buffer_ptr,
                            cross_gather_buffer_desc,
+                           host_cross_bucket_id_count.data(),
                            embedding_entry_count_per_rank,
                            wm_local_comm,
                            wm_cross_comm,
@@ -267,7 +293,7 @@ wholememory_error_code_t wholememory_gather_hierarchy(
                            p_env_fns,
                            stream,
                            gather_sms);
-  // sort-unique reorder
+  // cross gather reorder
   temp_memory_handle dev_embedding_map_buffer_handle(p_env_fns);
   void* dev_embedding_map_buffer_ptr = dev_embedding_map_buffer_handle.device_malloc(
     wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype);
@@ -278,7 +304,7 @@ wholememory_error_code_t wholememory_gather_hierarchy(
     wholememory_create_continuous_global_reference(dev_cross_gather_buffer_ptr);
   WHOLEMEMORY_RETURN_ON_FAIL(gather_func(cross_gather_fake_gref,
                                          cross_gather_buffer_desc,
-                                         dev_sort_unique_ids_map_handle.pointer(),
+                                         dev_cross_gather_id_map_handle.pointer(),
                                          recv_bucket_indices_desc,
                                          dev_embedding_map_buffer_ptr,
                                          embedding_map_buffer_desc,