set a knob for sort-unique
zhuofan1123 committed Oct 8, 2024
1 parent 5aed03c commit 2d4fceb
Showing 3 changed files with 140 additions and 68 deletions.
cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu (109 changes: 77 additions & 32 deletions)
@@ -33,16 +33,16 @@

namespace wholememory_ops {

-template <typename IndexT, int CROSS_OR_LOCAL = 0>
+template <typename IndexT, int BUCKET_CROSS_OR_LOCAL = 0>
__global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices,
size_t indice_count,
int64_t* dev_rank_id_count_ptr,
size_t embedding_entry_count_per_rank,
int local_size,
-int bucket_size)
+int nbucket)
{
extern __shared__ int rank_count_shared[];
-for (int idx = threadIdx.x; idx < bucket_size; idx += blockDim.x) {
+for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) {
rank_count_shared[idx] = 0;
}
__syncthreads();
@@ -52,9 +52,9 @@ __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices,
if (node_idx < 0) continue;
int rank = node_idx / embedding_entry_count_per_rank;
int bucket = 0;
-if (CROSS_OR_LOCAL == 0)  // bucket cross ranks
+if (BUCKET_CROSS_OR_LOCAL == 0)
bucket = rank % local_size;
-else  // bucket local ranks
+else
bucket = rank / local_size;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
atomicAdd_block(&rank_count_shared[bucket], 1);
Expand All @@ -63,7 +63,7 @@ __global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices,
#endif
}
__syncthreads();
-for (int idx = threadIdx.x; idx < bucket_size; idx += blockDim.x) {
+for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) {
atomicAdd(reinterpret_cast<unsigned long long*>(dev_rank_id_count_ptr) + idx,
static_cast<unsigned long long>(rank_count_shared[idx]));
}
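
The counting scheme in this kernel is a per-block shared-memory histogram: each block tallies its slice of the indices into on-chip counters, then flushes them to the global count array once per bucket, so global atomic traffic scales with blocks × buckets rather than with the number of indices. A minimal standalone sketch of the pattern (hypothetical kernel name and a simplified id-to-bucket mapping; the compute-capability 7.0 guard for block-scoped atomics mirrors the one above):

```cuda
// Sketch only: per-block shared-memory histogram with a single global flush.
// Launch with nbucket * sizeof(int) bytes of dynamic shared memory.
__global__ void histogram_sketch(const int64_t* ids, size_t n,
                                 unsigned long long* global_counts,
                                 int nbucket, int entries_per_rank)
{
  extern __shared__ int counts[];  // one counter per bucket
  for (int i = threadIdx.x; i < nbucket; i += blockDim.x) counts[i] = 0;
  __syncthreads();
  for (size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < n;
       i += (size_t)gridDim.x * blockDim.x) {
    int bucket = (ids[i] / entries_per_rank) % nbucket;  // simplified mapping
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
    atomicAdd_block(&counts[bucket], 1);  // block-scoped, stays on-chip
#else
    atomicAdd(&counts[bucket], 1);
#endif
  }
  __syncthreads();
  for (int i = threadIdx.x; i < nbucket; i += blockDim.x)
    atomicAdd(&global_counts[i], (unsigned long long)counts[i]);  // one flush per block
}
```
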
@@ -113,17 +113,17 @@ void bucket_ids_for_hierarchy_temp_func(const void* indices,

REGISTER_DISPATCH_ONE_TYPE(BucketIdsForHierarchy, bucket_ids_for_hierarchy_temp_func, SINT3264)

-template <typename IndexT>
+template <typename IndexT, int BUCKET_CROSS_OR_LOCAL = 0>
__global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices,
size_t indice_count,
IndexT* dev_bucket_indices,
IndexT* dev_indice_map,
const int64_t* dev_rank_id_offset_ptr,
size_t embedding_entry_count_per_rank,
int local_size,
+int nbucket,
int64_t* dev_bucket_atomic_add_ptr)
{
-int nbucket = local_size;
constexpr size_t shared_mem_size = 24576;
__shared__ char shared_mem[shared_mem_size];
int* block_bucket_count_shared = reinterpret_cast<int*>(shared_mem);
@@ -155,7 +155,13 @@ __global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices,
IndexT indice = indices[load_idx];

buffer_load[i] = indice;
-int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size;
+int bucket_idx = 0;
+int rank = indice / embedding_entry_count_per_rank;
+if (BUCKET_CROSS_OR_LOCAL == 0) {
+bucket_idx = rank % local_size;
+} else {
+bucket_idx = rank / local_size;
+}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
atomicAdd_block(&block_bucket_count_shared[bucket_idx], 1);
#else
@@ -181,7 +187,13 @@
IndexT indice = buffer_load[i];
IndexT load_idx = i + load_offset;
if (load_idx >= indice_count) break;
-int bucket_idx = (indice / embedding_entry_count_per_rank) % local_size;
+int bucket_idx = 0;
+int rank = indice / embedding_entry_count_per_rank;
+if (BUCKET_CROSS_OR_LOCAL == 0) {
+bucket_idx = rank % local_size;
+} else {
+bucket_idx = rank / local_size;
+}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
int block_bucket_inc = atomicAdd_block(&block_bucket_atomic_add_shared[bucket_idx], 1);
#else
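
Both bucket mappings follow from the usual hierarchical rank layout: with global rank = node_id * local_size + local_id, `rank % local_size` recovers the local id (the `BUCKET_CROSS_OR_LOCAL == 0` branch) and `rank / local_size` recovers the node id (the `== 1` branch). A tiny host-side sanity check with made-up sizes (world_size 8, local_size 4, so cross_size 2):

```cpp
#include <cassert>

int main()
{
  const int world_size = 8, local_size = 4;        // 2 nodes x 4 GPUs (made-up sizes)
  const int cross_size = world_size / local_size;  // = 2
  for (int rank = 0; rank < world_size; ++rank) {
    int bucket_cross = rank % local_size;  // mode 0: one bucket per local id, local_size buckets
    int bucket_local = rank / local_size;  // mode 1: one bucket per node, cross_size buckets
    assert(bucket_cross < local_size && bucket_local < cross_size);
  }
  return 0;  // e.g. rank 6 -> bucket 2 in mode 0, bucket 1 in mode 1
}
```
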
@@ -213,6 +225,8 @@ void reorder_ids_for_hierarchy_temp_func(const void* indices,
const int64_t* dev_rank_id_count_ptr,
size_t embedding_entry_count_per_rank,
int local_size,
+int cross_size,
+int bucket_cross_or_local,
wm_thrust_allocator* p_thrust_allocator,
wholememory_env_func_t* p_env_fns,
int sm_count,
@@ -221,44 +235,63 @@
WHOLEMEMORY_CHECK(indice_desc.storage_offset == 0);
WHOLEMEMORY_CHECK(indice_desc.dtype == WHOLEMEMORY_DT_INT ||
indice_desc.dtype == WHOLEMEMORY_DT_INT64);

+int nbucket = 0;
+if (bucket_cross_or_local == 0) {
+nbucket = local_size;
+} else {
+nbucket = cross_size;
+}
temp_memory_handle dev_rank_id_offset_handle(p_env_fns);
-int64_t* dev_rank_id_offset_ptr = static_cast<int64_t*>(
-dev_rank_id_offset_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64));
+int64_t* dev_rank_id_offset_ptr =
+static_cast<int64_t*>(dev_rank_id_offset_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64));
void* cub_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(cub_temp_storage,
temp_storage_bytes,
dev_rank_id_count_ptr,
dev_rank_id_offset_ptr,
-local_size,
+nbucket,
stream);
cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes);
cub::DeviceScan::ExclusiveSum(cub_temp_storage,
temp_storage_bytes,
dev_rank_id_count_ptr,
dev_rank_id_offset_ptr,
-local_size,
+nbucket,
stream);
p_thrust_allocator->deallocate(reinterpret_cast<char*>(cub_temp_storage), temp_storage_bytes);

temp_memory_handle dev_bucket_atomic_add_handle(p_env_fns);
int64_t* dev_bucket_atomic_add_ptr = static_cast<int64_t*>(
-dev_bucket_atomic_add_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64));
-cudaMemsetAsync((void*)dev_bucket_atomic_add_ptr, 0, sizeof(int64_t) * local_size, stream);
+dev_bucket_atomic_add_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64));
+cudaMemsetAsync((void*)dev_bucket_atomic_add_ptr, 0, sizeof(int64_t) * nbucket, stream);
static constexpr int BLOCK_SIZE = 128;
int block_count = wholememory::div_rounding_up_unsafe(indice_desc.size, BLOCK_SIZE);
block_count = std::min(block_count, sm_count * 4);

-reorder_ids_for_hierarchy_kernel<<<block_count, BLOCK_SIZE, 0, stream>>>(
-static_cast<const IndexT*>(indices),
-indice_desc.size,
-static_cast<IndexT*>(dev_bucket_indices),
-static_cast<IndexT*>(dev_indice_map),
-dev_rank_id_offset_ptr,
-embedding_entry_count_per_rank,
-local_size,
-dev_bucket_atomic_add_ptr);
+if (bucket_cross_or_local == 0)
+reorder_ids_for_hierarchy_kernel<IndexT, 0>
+<<<block_count, BLOCK_SIZE, 0, stream>>>(static_cast<const IndexT*>(indices),
+indice_desc.size,
+static_cast<IndexT*>(dev_bucket_indices),
+static_cast<IndexT*>(dev_indice_map),
+dev_rank_id_offset_ptr,
+embedding_entry_count_per_rank,
+local_size,
+nbucket,
+dev_bucket_atomic_add_ptr);
+else
+reorder_ids_for_hierarchy_kernel<IndexT, 1>
+<<<block_count, BLOCK_SIZE, 0, stream>>>(static_cast<const IndexT*>(indices),
+indice_desc.size,
+static_cast<IndexT*>(dev_bucket_indices),
+static_cast<IndexT*>(dev_indice_map),
+dev_rank_id_offset_ptr,
+embedding_entry_count_per_rank,
+local_size,
+nbucket,
+dev_bucket_atomic_add_ptr);
}

REGISTER_DISPATCH_ONE_TYPE(ReorderIdsForHierarchy, reorder_ids_for_hierarchy_temp_func, SINT3264)
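
The exclusive-sum calls in `reorder_ids_for_hierarchy_temp_func` follow CUB's standard two-phase protocol: a first call with a null temp-storage pointer only reports the workspace size, the caller allocates, and an identical second call performs the scan, turning per-bucket counts into per-bucket output offsets. A freestanding sketch of the same protocol (using plain `cudaMalloc` instead of the thrust allocator used above):

```cpp
#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Turn per-bucket counts into exclusive offsets, e.g. [3, 1, 4] -> [0, 3, 4].
void exclusive_offsets(const int64_t* d_counts, int64_t* d_offsets,
                       int nbucket, cudaStream_t stream)
{
  void* d_temp      = nullptr;
  size_t temp_bytes = 0;
  // First call: d_temp == nullptr, so CUB only writes the required size.
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_counts, d_offsets, nbucket, stream);
  cudaMalloc(&d_temp, temp_bytes);
  // Second call: identical arguments, now performs the scan.
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_counts, d_offsets, nbucket, stream);
  cudaFree(d_temp);
}
```
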
@@ -272,23 +305,33 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
size_t embedding_entry_count_per_rank,
wholememory_comm_t wm_global_comm,
wholememory_comm_t wm_local_comm,
+int bucket_cross_or_local,
wm_thrust_allocator* p_thrust_allocator,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream)
{
if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; }
-int world_size, local_size;
+int world_size, local_size, cross_size;
WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_global_comm));
WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm));
WHOLEMEMORY_CHECK_NOTHROW(world_size % local_size == 0);
+cross_size = world_size / local_size;

+WHOLEMEMORY_EXPECTS_NOTHROW(bucket_cross_or_local == 0 || bucket_cross_or_local == 1,
+"param bucket_cross_or_local must be 0 or 1, 0: cross, 1: local");
+int nbucket = 0;
+if (bucket_cross_or_local == 0) { // bucket by cross id
+nbucket = local_size;
+} else { // bucket by local id
+nbucket = cross_size;
+}
constexpr int K_DEFAULT_SM_COUNT = 108;
auto prop = get_device_prop(-1);
int sm_count = (prop != nullptr) ? prop->multiProcessorCount : K_DEFAULT_SM_COUNT;
temp_memory_handle dev_rank_id_count_handle(p_env_fns);
int64_t* dev_rank_id_count_ptr =
-static_cast<int64_t*>(dev_rank_id_count_handle.device_malloc(local_size, WHOLEMEMORY_DT_INT64));
-cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * local_size, stream);
+static_cast<int64_t*>(dev_rank_id_count_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64));
+cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * nbucket, stream);
try {
DISPATCH_ONE_TYPE(indice_desc.dtype,
BucketIdsForHierarchy,
@@ -297,8 +340,8 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
dev_rank_id_count_ptr,
embedding_entry_count_per_rank,
local_size,
-0, // ignore
-0,
+cross_size,
+bucket_cross_or_local,
sm_count,
stream);
} catch (wholememory::cuda_error& wce) {
@@ -307,7 +350,7 @@
}
WM_CUDA_CHECK_NO_THROW(cudaMemcpyAsync(host_bucket_id_count,
dev_rank_id_count_ptr,
-local_size * sizeof(int64_t),
+nbucket * sizeof(int64_t),
cudaMemcpyDeviceToHost,
stream));
try {
@@ -320,6 +363,8 @@
dev_rank_id_count_ptr,
embedding_entry_count_per_rank,
local_size,
+cross_size,
+bucket_cross_or_local,
p_thrust_allocator,
p_env_fns,
sm_count,
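
One design point worth noting in this file: the runtime knob is turned into a compile-time template argument at the launch site, so the per-element cross/local branch inside the kernels costs nothing at run time. Reduced to a sketch with hypothetical names:

```cuda
#include <cuda_runtime.h>

// The branch on BUCKET_CROSS_OR_LOCAL is resolved at compile time, so each
// instantiation contains only one of the two mappings.
template <int BUCKET_CROSS_OR_LOCAL>
__global__ void bucket_kernel(const int* ranks, int* buckets, int n, int local_size)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  if (BUCKET_CROSS_OR_LOCAL == 0)
    buckets[i] = ranks[i] % local_size;  // mode 0: bucket by local id
  else
    buckets[i] = ranks[i] / local_size;  // mode 1: bucket by node id
}

// Runtime knob -> compile-time instantiation, mirroring the launch site above.
void launch_bucket_kernel(int knob, const int* d_ranks, int* d_buckets, int n,
                          int local_size, cudaStream_t stream)
{
  const int block = 128, grid = (n + block - 1) / block;
  if (knob == 0)
    bucket_kernel<0><<<grid, block, 0, stream>>>(d_ranks, d_buckets, n, local_size);
  else
    bucket_kernel<1><<<grid, block, 0, stream>>>(d_ranks, d_buckets, n, local_size);
}
```
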
@@ -32,6 +32,7 @@ wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func(
size_t embedding_entry_count_per_rank,
wholememory_comm_t wm_global_comm,
wholememory_comm_t wm_local_comm,
+int bucket_cross_or_local, // 0: cross, 1: local
wm_thrust_allocator* p_thrust_allocator,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream);
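
Since the new parameter is a bare `int`, call sites may want named constants for the two documented values; the names below are illustrative only, not part of the API:

```cpp
// Illustrative constants; the API itself takes a plain int.
constexpr int kBucketCross = 0;  // bucket across ranks: nbucket == local_size
constexpr int kBucketLocal = 1;  // bucket by node:      nbucket == cross_size
```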