From 7aefaeb70dda5d028c2ed5c98e8437457dddad86 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 6 Nov 2024 21:45:55 +0800 Subject: [PATCH] update wholegraph --- cpp/bench/CMakeLists.txt | 4 +- cpp/bench/common/wholegraph_benchmark.cpp | 17 + cpp/bench/common/wholegraph_benchmark.hpp | 2 + .../wholememory_ops/gather_scatter_bench.cu | 73 ++- cpp/include/wholememory/device_reference.cuh | 33 +- cpp/include/wholememory/embedding.h | 7 +- cpp/include/wholememory/global_reference.h | 8 +- cpp/include/wholememory/wholememory.h | 104 +++- cpp/include/wholememory/wholememory_tensor.h | 39 +- cpp/src/wholememory/communicator.cpp | 7 + cpp/src/wholememory/communicator.hpp | 3 + cpp/src/wholememory/embedding.cpp | 108 +++- cpp/src/wholememory/embedding.hpp | 3 +- cpp/src/wholememory/memory_handle.cpp | 541 +++++++++++++----- cpp/src/wholememory/memory_handle.hpp | 35 +- cpp/src/wholememory/wholememory.cpp | 68 ++- cpp/src/wholememory/wholememory_tensor.cpp | 124 +++- .../bucket_ids_for_hierarchy_func.cu | 474 +++++++++++++++ .../functions/bucket_ids_for_hierarchy_func.h | 50 ++ .../functions/bucket_ids_func.cu | 50 +- .../functions/bucket_ids_func.h | 2 +- .../functions/embedding_cache_func.cu | 27 +- .../functions/embedding_cache_func.h | 4 +- .../functions/exchange_ids_nccl_func.cu | 4 +- .../functions/exchange_ids_nccl_func.h | 4 +- .../functions/gather_scatter_func.cuh | 35 +- .../functions/map_indices_func.cu | 14 +- .../functions/nvshmem_device_reference.cuh | 86 ++- ...r_func_impl_floating_data_int32_indices.cu | 31 +- ...r_func_impl_floating_data_int64_indices.cu | 31 +- ...er_func_impl_integer_data_int32_indices.cu | 31 +- ...er_func_impl_integer_data_int64_indices.cu | 33 +- .../functions/nvshmem_gather_scatter_func.cuh | 82 +-- ...r_func_impl_floating_data_int32_indices.cu | 31 +- ...r_func_impl_floating_data_int64_indices.cu | 31 +- ...er_func_impl_integer_data_int32_indices.cu | 31 +- ...er_func_impl_integer_data_int64_indices.cu | 31 +- .../sort_unique_ids_for_hierarchy_func.cu | 145 +++++ .../sort_unique_ids_for_hierarchy_func.h | 35 ++ .../functions/sort_unique_indices_func.cu | 118 ++++ .../functions/sort_unique_indices_func.h | 37 ++ cpp/src/wholememory_ops/gather_op.cpp | 13 + cpp/src/wholememory_ops/gather_op_impl.h | 11 + .../gather_op_impl_hierarchy.cu | 360 ++++++++++++ .../wholememory_ops/gather_op_impl_nccl.cu | 43 +- .../wholememory_ops/gather_op_impl_nvshmem.cu | 57 +- .../scatter_op_impl.nvshmem.cu | 60 +- .../wholememory_ops/scatter_op_impl_nccl.cu | 41 +- .../wholememory_ops/embedding_test_utils.cu | 17 + .../wholememory_ops/embedding_test_utils.hpp | 2 + ...lememory_embedding_gradient_apply_tests.cu | 57 +- .../wholememory_embedding_tests.cu | 33 +- .../wholememory_gather_tests.cu | 62 +- .../wholememory_scatter_tests.cu | 27 +- .../binding/wholememory_binding.pyx | 127 ++-- .../pylibwholegraph/test_utils/test_comm.py | 29 +- .../pylibwholegraph/test_wholememory_io.py | 48 +- .../test_wholememory_tensor.py | 20 +- .../ops/test_wholegraph_gather_scatter.py | 24 +- .../pylibwholegraph/torch/common_options.py | 5 +- .../pylibwholegraph/torch/embedding.py | 39 +- .../pylibwholegraph/torch/tensor.py | 22 +- .../pylibwholegraph/torch/utils.py | 4 +- .../pylibwholegraph/torch/wholegraph_env.py | 6 - 64 files changed, 3036 insertions(+), 664 deletions(-) create mode 100644 cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu create mode 100644 cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h create mode 100644 
cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu create mode 100644 cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h create mode 100644 cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu create mode 100644 cpp/src/wholememory_ops/functions/sort_unique_indices_func.h create mode 100644 cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 89092a9..7736c04 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -35,7 +35,9 @@ function(ConfigureBench) rmm::rmm pthread ) - + if(BUILD_WITH_NVSHMEM) + target_compile_definitions(${BENCH_NAME} PRIVATE WITH_NVSHMEM_SUPPORT) + endif() set_target_properties( ${BENCH_NAME} PROPERTIES # set target compile options diff --git a/cpp/bench/common/wholegraph_benchmark.cpp b/cpp/bench/common/wholegraph_benchmark.cpp index e65d2b2..0be6855 100644 --- a/cpp/bench/common/wholegraph_benchmark.cpp +++ b/cpp/bench/common/wholegraph_benchmark.cpp @@ -52,6 +52,23 @@ void host_random_init_integer_indices(void* indices, } } +void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count) +{ + std::default_random_engine random_engine(0); + std::uniform_int_distribution uniform(90, 100); + size_t acc_size = 0; + size_t random_sum = 0; + for (int i = 0; i < partition_count; i++) { + partition_sizes[i] = (size_t)uniform(random_engine); + random_sum += partition_sizes[i]; + } + for (int i = 0; i < partition_count; i++) { + partition_sizes[i] = (size_t)((partition_sizes[i] / (double)random_sum) * total_size); + acc_size += partition_sizes[i]; + } + partition_sizes[0] += total_size - acc_size; +} + void MultiProcessMeasurePerformance(std::function run_fn, wholememory_comm_t& wm_comm, const PerformanceMeter& meter, diff --git a/cpp/bench/common/wholegraph_benchmark.hpp b/cpp/bench/common/wholegraph_benchmark.hpp index a870b8a..a3af9b1 100644 --- a/cpp/bench/common/wholegraph_benchmark.hpp +++ b/cpp/bench/common/wholegraph_benchmark.hpp @@ -35,6 +35,8 @@ void host_random_init_integer_indices(void* indices, wholememory_array_description_t indices_desc, int64_t max_indices); +void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count); + struct Metric { Metric(const std::string& metrics_name, const std::string& metrics_unit, diff --git a/cpp/bench/wholememory_ops/gather_scatter_bench.cu b/cpp/bench/wholememory_ops/gather_scatter_bench.cu index a84035c..b081048 100644 --- a/cpp/bench/wholememory_ops/gather_scatter_bench.cu +++ b/cpp/bench/wholememory_ops/gather_scatter_bench.cu @@ -77,6 +77,8 @@ typedef struct GatherScatterBenchParam { int64_t get_embedding_dim() const { return embedding_dim; } wholememory_dtype_t get_embedding_type() const { return embedding_type; } + int get_partition_method() const { return partition_method; } + std::string get_distributed_backend() const { return distributed_backend; } GatherScatterBenchParam& set_memory_type(wholememory_memory_type_t new_memory_type) { @@ -153,6 +155,18 @@ typedef struct GatherScatterBenchParam { return *this; } + GatherScatterBenchParam& set_partition_method(int new_partition_method) + { + partition_method = new_partition_method; + return *this; + } + + GatherScatterBenchParam& set_distributed_backend(std::string new_distributed_backend) + { + distributed_backend = new_distributed_backend; + return *this; + } + private: int64_t get_embedding_entry_count() const { @@ -196,6 +210,8 @@ typedef struct GatherScatterBenchParam { 
int64_t embedding_dim = 32; int loop_count = 20; std::string test_type = "gather"; // gather or scatter + int partition_method = 0; + std::string distributed_backend = "nccl"; // nccl or nvshmem std::string server_addr = "localhost"; int server_port = 24987; @@ -256,7 +272,15 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) wholememory_comm_t wm_comm = create_communicator_by_socket(side_band_communicator, world_rank, world_size); - + std::string distributed_backend = params.get_distributed_backend(); +#ifdef WITH_NVSHMEM_SUPPORT + if (distributed_backend.compare("nvshmem") == 0) + WHOLEMEMORY_CHECK_NOTHROW(wholememory_communicator_set_distributed_backend( + wm_comm, WHOLEMEMORY_DB_NVSHMEM) == WHOLEMEMORY_SUCCESS); +#else + distributed_backend = "nccl"; + params.set_distributed_backend("nccl"); +#endif ShutDownSidebandCommunicator(side_band_communicator); auto embedding_desc = params.get_embedding_desc(); @@ -268,12 +292,17 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) wholememory_tensor_t embedding_tensor; wholememory_tensor_description_t embedding_tensor_desc; wholememory_copy_matrix_desc_to_tensor(&embedding_tensor_desc, &embedding_desc); + std::vector rank_partition(world_size); + wholegraph::bench::host_random_partition( + rank_partition.data(), embedding_tensor_desc.sizes[0], world_size); WHOLEMEMORY_CHECK_NOTHROW(wholememory_create_tensor(&embedding_tensor, &embedding_tensor_desc, wm_comm, params.get_memory_type(), - params.get_memory_location()) == - WHOLEMEMORY_SUCCESS); + params.get_memory_location(), + params.get_partition_method() == 1 + ? rank_partition.data() + : nullptr) == WHOLEMEMORY_SUCCESS); cudaStream_t stream; WM_CUDA_CHECK_NO_THROW(cudaStreamCreate(&stream)); @@ -318,8 +347,8 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) double gather_size_mb = (double)params.get_gather_size() / 1024.0 / 1024.0; if (local_rank == 0) { printf( - "%s, world_size=%d, memoryType=%s, memoryLocation=%s, elt_size=%ld, embeddingDim=%ld, " - "embeddingTableSize=%.2lf MB, gatherSize=%.2lf MB\n", + "%s, worldSize=%d, memoryType=%s, memoryLocation=%s, eltSize=%ld, embeddingDim=%ld, " + "embeddingTableSize=%.2lf MB, gatherSize=%.2lf MB, distributedBackend=%s\n", test_type.c_str(), world_size, get_memory_type_string(params.get_memory_type()).c_str(), @@ -327,7 +356,8 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) wholememory_dtype_get_element_size(params.get_embedding_type()), params.get_embedding_dim(), emb_size_mb, - gather_size_mb); + gather_size_mb, + distributed_backend.c_str()); } PerformanceMeter meter; @@ -388,7 +418,7 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) int main(int argc, char** argv) { wholegraph::bench::gather_scatter::GatherScatterBenchParam params; - const char* optstr = "ht:l:e:g:d:c:f:a:p:r:s:n:"; + const char* optstr = "ht:l:e:g:d:c:f:a:p:r:s:n:m:b:"; struct option opts[] = { {"help", no_argument, NULL, 'h'}, {"memory_type", @@ -405,8 +435,9 @@ int main(int argc, char** argv) {"node_size", required_argument, NULL, 's'}, // node_size {"num_gpu", required_argument, NULL, 'n'}, // num gpu per node {"server_addr", required_argument, NULL, 'a'}, // server_addr - {"server_port", required_argument, NULL, 'p'} // server_port - }; + {"server_port", required_argument, NULL, 'p'}, // server_port + {"partition_method", required_argument, NULL, 'm'}, + {"distributed_backend", required_argument, NULL, 'b'}}; const char* usage = "Usage: %s [options]\n" @@ -424,7 +455,9 @@ int main(int argc, 
char** argv) " -s, --node_size node_size or process count\n" " -n, --num_gpu num_gpu per process\n" " -a, --server_addr specify sideband server address\n" - " -p, --server_port specify sideband server port\n"; + " -p, --server_port specify sideband server port\n" + " -m, --partition_method specify rank partition method, 0: Default, 1: Random\n" + " -b, --distributed_backend specify distributed backend: nccl or nvshmem\n"; int c; bool has_option = false; @@ -536,6 +569,26 @@ int main(int argc, char** argv) } params.set_num_gpu(val); break; + case 'm': + val = std::atoi(optarg); + if (val != 0 && val != 1) { + printf("Invalid argument for option -m\n"); + printf(usage, argv[0]); + exit(EXIT_FAILURE); + } + params.set_partition_method(val); + break; + case 'b': + if (strcmp(optarg, "nccl") == 0) { + params.set_distributed_backend("nccl"); + } else if (strcmp(optarg, "nvshmem") == 0) { + params.set_distributed_backend("nvshmem"); + } else { + printf("Invalid argument for option -b\n"); + printf(usage, argv[0]); + exit(EXIT_FAILURE); + } + break; default: printf("Invalid or unrecognized option\n"); printf(usage, argv[0]); diff --git a/cpp/include/wholememory/device_reference.cuh b/cpp/include/wholememory/device_reference.cuh index 538ae27..8f2146a 100644 --- a/cpp/include/wholememory/device_reference.cuh +++ b/cpp/include/wholememory/device_reference.cuh @@ -26,23 +26,48 @@ class device_reference { public: __device__ __forceinline__ explicit device_reference(const wholememory_gref_t& gref) : pointer_(static_cast<DataTypeT*>(gref.pointer)), - typed_stride_(gref.stride / sizeof(DataTypeT)) + typed_stride_(gref.stride / sizeof(DataTypeT)), + world_size_(gref.world_size), + same_chunk_(gref.same_chunk) { assert(gref.stride % sizeof(DataTypeT) == 0); + if (typed_stride_ != 0 && !same_chunk_) { + assert(world_size_ <= 8); // intra-node WHOLEMEMORY_MT_CHUNKED + for (int i = 0; i < world_size_ + 1; i++) { + assert(gref.rank_memory_offsets[i] % sizeof(DataTypeT) == 0); + typed_rank_mem_offsets_[i] = gref.rank_memory_offsets[i] / sizeof(DataTypeT); + } + } } __device__ device_reference() = delete; __device__ __forceinline__ DataTypeT& operator[](size_t index) { if (typed_stride_ == 0) { return pointer_[index]; } - size_t rank = index / typed_stride_; - return static_cast<DataTypeT**>( - static_cast<void*>(pointer_))[rank][index - rank * typed_stride_]; + if (same_chunk_) { + size_t rank = index / typed_stride_; + return static_cast<DataTypeT**>( + static_cast<void*>(pointer_))[rank][index - rank * typed_stride_]; + } else { + size_t rank = 0; + for (int i = 1; i < world_size_ + 1; i++) { + if (index < typed_rank_mem_offsets_[i]) { + rank = i - 1; + break; + } + } + return static_cast<DataTypeT**>( + static_cast<void*>(pointer_))[rank][index - typed_rank_mem_offsets_[rank]]; + } } private: DataTypeT* pointer_; + int world_size_; size_t typed_stride_; + + bool same_chunk_; + size_t typed_rank_mem_offsets_[8 + 1]; }; } // namespace wholememory diff --git a/cpp/include/wholememory/embedding.h b/cpp/include/wholememory/embedding.h index 08cd73e..1853742 100644 --- a/cpp/include/wholememory/embedding.h +++ b/cpp/include/wholememory/embedding.h @@ -129,6 +129,8 @@ wholememory_error_code_t wholememory_destroy_embedding_cache_policy( * @param memory_type : Memory Type of the underlying WholeMemory * @param memory_location : Memory Location of the underlying WholeMemory * @param cache_policy : Cache policy for this embedding, if don't use cache, use nullptr + * @param embedding_entry_partition : Embedding entry count of each rank, the length must be + * world_size * @param 
user_defined_sms : User-defined sms number for raw embedding gather/scatter * @param round_robin_size : continuous embedding size in each rank under round-robin shard mode * @return : wholememory_error_code_t @@ -140,8 +142,9 @@ wholememory_error_code_t wholememory_create_embedding( wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, wholememory_embedding_cache_policy_t cache_policy, - int user_defined_sms = -1, - int round_robin_size = 0); + size_t* embedding_entry_partition = nullptr, + int user_defined_sms = -1, + int round_robin_size = 0); /** * Destroy WholeMemory Embedding diff --git a/cpp/include/wholememory/global_reference.h b/cpp/include/wholememory/global_reference.h index 767d174..7f17475 100644 --- a/cpp/include/wholememory/global_reference.h +++ b/cpp/include/wholememory/global_reference.h @@ -24,14 +24,18 @@ extern "C" { /** * @brief Global reference of a WholeMemory object * - * A global reference is for Continuous of Chunked WholeMemory Type, in these types, each rank can + * A global reference is for Continuous or Chunked WholeMemory Type, in these types, each rank can * directly access all memory from all ranks. The global reference is used to do this direct access. */ struct wholememory_gref_t { void* pointer; /*!< pointer to data for CONTINUOUS WholeMemory or pointer to data pointer array for CHUNKED WholeMemory */ + size_t* + rank_memory_offsets; /*!< memory offset of each rank, and the length must be world_size+1 */ + int world_size; size_t stride; /*!< must be 0 for CONTINUOUS WholeMemory or memory size in byte for each pointer */ + bool same_chunk; /*!< if true, rank can be got by offset/stride */ }; /** @@ -43,9 +47,11 @@ wholememory_gref_t wholememory_create_continuous_global_reference(void* ptr); struct wholememory_nvshmem_ref_t { void* pointer; + size_t* rank_memory_offsets; size_t stride; int world_rank; int world_size; + bool same_chunk; }; #ifdef __cplusplus diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index f6baccc..58aa611 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -63,6 +63,7 @@ enum wholememory_memory_type_t { WHOLEMEMORY_MT_CONTINUOUS, /*!< Memory from all ranks are mapped in continuous address space */ WHOLEMEMORY_MT_CHUNKED, /*!< Memory from all ranks are mapped in chunked address space */ WHOLEMEMORY_MT_DISTRIBUTED, /*!< Memory from other ranks are not mapped. 
*/ + WHOLEMEMORY_MT_HIERARCHY, /*!< Memory from other ranks are mapped in a hierarchical address space */ }; /** @@ -206,6 +207,23 @@ wholememory_error_code_t wholememory_communicator_get_rank(int* rank, wholememor */ wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememory_comm_t comm); +/** + * Get the local size (number of ranks on the same node) of the current process in the WholeMemory Communicator + * @param local_size : returned local rank count + * @param comm : WholeMemory Communicator + * @return : wholememory_error_code_t + */ + +wholememory_error_code_t wholememory_communicator_get_local_size(int* local_size, + wholememory_comm_t comm); + +/** + * Get the clique info of WholeMemory Communicator + * @param clique_info : returned clique info + * @param comm : WholeMemory Communicator + * @return : wholememory_error_code_t + */ + wholememory_error_code_t wholememory_communicator_get_clique_info(clique_info_t* clique_info, wholememory_comm_t comm); @@ -238,6 +256,7 @@ typedef struct wholememory_handle_* wholememory_handle_t; * @param memory_type : WholeMemory type * @param memory_location : memory location, host or device * @param data_granularity : granularity size of data, which is guaranteed not to be partitioned. + * @param rank_entry_partition : entry count of each rank (each entry has size data_granularity) * @return : wholememory_error_code_t */ wholememory_error_code_t wholememory_malloc(wholememory_handle_t* wholememory_handle_ptr, @@ -245,7 +264,8 @@ wholememory_error_code_t wholememory_malloc(wholememory_ha wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity); + size_t data_granularity, + size_t* rank_entry_partition = nullptr); /** * Free allocated WholeMemory Handle @@ -263,6 +283,25 @@ wholememory_error_code_t wholememory_free(wholememory_handl wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t* comm, wholememory_handle_t wholememory_handle); +/** + * Get underlying WholeMemory Local Communicator for "Hierarchy" memory type from WholeMemory Handle + * @param comm : returned Local WholeMemory Communicator + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_get_local_communicator( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle); + +/** + * Get underlying WholeMemory Cross Communicator for "Hierarchy" memory type from WholeMemory Handle + * One communicator includes all ranks with the same local rank id across different nodes + * @param comm : returned Cross WholeMemory Communicator + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_get_cross_communicator( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle); + /** * Get WholeMemory Type * @param wholememory_handle : WholeMemory Handle @@ -309,6 +348,24 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr, size_t* local_offset, wholememory_handle_t wholememory_handle); +/** + * Get local memory size of current rank from WholeMemory Handle + * @param local_size : returned local memory size + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_get_local_size(size_t* local_size, + wholememory_handle_t wholememory_handle); + +/** + * Get local memory offset of current rank from WholeMemory Handle + * 
@param local_offset : returned local memory offset + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_get_local_offset(size_t* local_offset, + wholememory_handle_t wholememory_handle); + /** * Get local memory of specified rank from WholeMemory Handle * @param rank_memory_ptr : returned local memory pointer of specified rank @@ -324,6 +381,17 @@ wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, int rank, wholememory_handle_t wholememory_handle); +/** + * Get the equal partition plan WholeMemory uses by default + * @param entry_per_rank : returned entry count per rank + * @param total_entry_count : total entry count + * @param world_size : communicator world size + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_equal_entry_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size); + /** * Get global memory pointer from WholeMemory Handle. * Only Continuous memory type or Chunked Host memory has global pointer. @@ -345,38 +413,22 @@ wholememory_error_code_t wholememory_get_global_reference(wholememory_gref_t* wh wholememory_handle_t wholememory_handle); /** - * Get the partition plan WholeMemory will use - * @param size_per_rank : returned size per rank - * @param total_size : total size - * @param data_granularity : data granularity - * @param world_size : communicator world size - * @return : wholememory_error_code_t - */ -wholememory_error_code_t wholememory_determine_partition_plan(size_t* size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size); - -/** - * Get the partition plan WholeMemory will use based on entry count. - * Entry is number of data granularity - * @param entry_per_rank : returned entry count per rank - * @param total_entry_count : total entry count - * @param world_size : communicator world size + * Get memory size of each rank from WholeMemory Handle + * @param rank_mem_sizes : returned memory size of each rank + * @param wholememory_handle : WholeMemory Handle * @return : wholememory_error_code_t */ -wholememory_error_code_t wholememory_determine_entry_partition_plan(size_t* entry_per_rank, - size_t total_entry_count, - int world_size); +wholememory_error_code_t wholememory_get_rank_partition_sizes( + size_t* rank_mem_sizes, wholememory_handle_t wholememory_handle); /** - * Get the partition plan used in WholeMemory Handle - * @param size_per_rank : returned size per rank + * Get memory offset of each rank from WholeMemory Handle + * @param rank_mem_offsets : returned memory offset of each rank * @param wholememory_handle : WholeMemory Handle * @return : wholememory_error_code_t */ -wholememory_error_code_t wholememory_get_partition_plan(size_t* size_per_rank, - wholememory_handle_t wholememory_handle); +wholememory_error_code_t wholememory_get_rank_partition_offsets( + size_t* rank_mem_offsets, wholememory_handle_t wholememory_handle); /** * Fork a new process and get device count. 
Should be called before other CUDA call diff --git a/cpp/include/wholememory/wholememory_tensor.h b/cpp/include/wholememory/wholememory_tensor.h index aa757f6..9acd3e5 100644 --- a/cpp/include/wholememory/wholememory_tensor.h +++ b/cpp/include/wholememory/wholememory_tensor.h @@ -37,6 +37,7 @@ typedef struct wholememory_tensor_* wholememory_tensor_t; * @param comm : WholeMemory Communicator * @param memory_type : Memory Type of the underlying WholeMemory * @param memory_location : Memory Location of the underlying WholeMemory + * @param tensor_entry_partition : Tensor entry count of each rank, the length must be world_size. * @return : wholememory_error_code_t */ wholememory_error_code_t wholememory_create_tensor( @@ -44,7 +45,8 @@ wholememory_error_code_t wholememory_create_tensor( wholememory_tensor_description_t* tensor_description, wholememory_comm_t comm, wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location); + wholememory_memory_location_t memory_location, + size_t* tensor_entry_partition = nullptr); /** * Destroy WholeMemory Tensor @@ -131,11 +133,40 @@ wholememory_error_code_t wholememory_tensor_map_local_tensor( void* wholememory_tensor_get_data_pointer(wholememory_tensor_t wholememory_tensor); /** - * Get entry count per rank of a WholeMemory Tensor + * Get entry offset of each rank from WholeMemory Tensor + * @param entry_offsets : returned entry offset of each rank * @param wholememory_tensor : WholeMemory Tensor - * @return : entry count per rank + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_tensor_get_entry_offsets( + size_t* entry_offsets, wholememory_tensor_t wholememory_tensor); + +/** + * Get entry count of each rank from WholeMemory Tensor + * @param entry_partition : returned entry count of each rank + * @param wholememory_tensor : WholeMemory Tensor + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_tensor_get_entry_partition_sizes( + size_t* entry_partition, wholememory_tensor_t wholememory_tensor); + +/** + * Get entry count of current rank from WholeMemory Tensor + * @param local_entry_count : returned entry count of current rank + * @param wholememory_tensor : WholeMemory Tensor + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_tensor_get_local_entry_count( + size_t* local_entry_count, wholememory_tensor_t wholememory_tensor); + +/** + * Get entry start of current rank from WholeMemory Tensor + * @param local_entry_start : returned entry start id of current rank + * @param wholememory_tensor : WholeMemory Tensor + * @return : wholememory_error_code_t */ -size_t wholememory_tensor_get_entry_per_partition(wholememory_tensor_t wholememory_tensor); +wholememory_error_code_t wholememory_tensor_get_local_entry_start( + size_t* local_entry_start, wholememory_tensor_t wholememory_tensor); /** * Get sub tensor of a WholeMemory Tensor diff --git a/cpp/src/wholememory/communicator.cpp b/cpp/src/wholememory/communicator.cpp index dabb9ba..f76a4c7 100644 --- a/cpp/src/wholememory/communicator.cpp +++ b/cpp/src/wholememory/communicator.cpp @@ -897,6 +897,13 @@ wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t com return WHOLEMEMORY_SUCCESS; } +wholememory_error_code_t communicator_get_local_size(int* local_size, + wholememory_comm_t comm) noexcept +{ + *local_size = comm->intra_node_rank_num; + return WHOLEMEMORY_SUCCESS; +} + // wholememory_error_code_t communicator_get_clique_rank(int* clique_rank, // wholememory_comm_t 
comm) noexcept // { diff --git a/cpp/src/wholememory/communicator.hpp b/cpp/src/wholememory/communicator.hpp index b48d66b..709965c 100644 --- a/cpp/src/wholememory/communicator.hpp +++ b/cpp/src/wholememory/communicator.hpp @@ -291,6 +291,9 @@ wholememory_error_code_t communicator_get_rank(int* rank, wholememory_comm_t com wholememory_error_code_t communicator_get_size(int* size, wholememory_comm_t comm) noexcept; +wholememory_error_code_t communicator_get_local_size(int* local_size, + wholememory_comm_t comm) noexcept; + wholememory_error_code_t communicator_get_clique_info(clique_info_t* clique_info, wholememory_comm_t comm) noexcept; diff --git a/cpp/src/wholememory/embedding.cpp b/cpp/src/wholememory/embedding.cpp index f1a868a..7d6aae8 100644 --- a/cpp/src/wholememory/embedding.cpp +++ b/cpp/src/wholememory/embedding.cpp @@ -89,7 +89,8 @@ wholememory_error_code_t embedding_base::allocate( wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - wholememory_embedding_cache_policy_t policy) noexcept + wholememory_embedding_cache_policy_t policy, + size_t* embedding_entry_partition) noexcept { cache_policy = policy; raw_embedding_comm_ = comm; @@ -109,6 +110,7 @@ wholememory_error_code_t embedding_base::allocate( comm, memory_type, memory_location)); + embedding_entry_partition = nullptr; } else { wholememory_copy_matrix_desc_to_tensor(&padded_embedding_tensor_description, embedding_description); @@ -123,7 +125,8 @@ wholememory_error_code_t embedding_base::allocate( &padded_embedding_tensor_description, comm, memory_type, - memory_location)); + memory_location, + embedding_entry_partition)); int64_t starts[2] = {0, 0}; int64_t ends[2] = {embedding_description->sizes[0], embedding_description->sizes[1]}; WHOLEMEMORY_RETURN_ON_FAIL( @@ -155,8 +158,6 @@ wholememory_error_code_t embedding_base::gather_gradient_apply(wholememory_tenso host_rank_id_count_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_recv_indices_buffer_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_raw_indice_handle(p_env_fns); - size_t const embedding_entry_count_per_rank = - wholememory_tensor_get_entry_per_partition(allocated_embedding); wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); int world_size = -1, world_rank = -1; int64_t* host_recv_rank_id_count_ptr = nullptr; @@ -174,6 +175,21 @@ wholememory_error_code_t embedding_base::gather_gradient_apply(wholememory_tenso wholememory_array_description_t indice_array_desc; WHOLEMEMORY_CHECK_NOTHROW( wholememory_convert_tensor_desc_to_array(&indice_array_desc, indice_desc)); + + wholememory_ops::temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + wholememory_ops::temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_entry_offsets(host_embedding_entry_offsets_ptr, allocated_embedding)); + WM_CUDA_CHECK_NO_THROW(cudaMemcpy(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice)); + WHOLEMEMORY_RETURN_ON_FAIL( wholememory_ops::bucket_and_exchange_ids_func(wholememory_tensor_get_data_pointer(indices), indice_array_desc, @@ -181,7 
+197,7 @@ wholememory_error_code_t embedding_base::gather_gradient_apply(wholememory_tenso host_rank_id_count_ptr, &dev_recv_indices_buffer_handle, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, raw_embedding_comm_, &thrust_allocator, p_env_fns, @@ -308,9 +324,9 @@ wholememory_error_code_t embedding_base::gather_gradient_apply(wholememory_tenso wholememory_error_code_t embedding_base::create_optimizer_states() noexcept { + wholememory_handle_t wm_handle = wholememory_tensor_get_memory_handle(allocated_embedding); wholememory_comm_t wm_raw_comm; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator( - &wm_raw_comm, wholememory_tensor_get_memory_handle(allocated_embedding))); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_raw_comm, wm_handle)); int world_rank, world_size; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&world_rank, wm_raw_comm)); @@ -321,12 +337,16 @@ wholememory_error_code_t embedding_base::create_optimizer_states() noexcept int64_t start[2] = {0, 0}; int64_t end[2] = {user_tensor_desc->sizes[1], -1}; - size_t entry_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_determine_entry_partition_plan( - &entry_per_rank, allocated_tensor_desc->sizes[0], world_size)); + std::vector allocated_tensor_entry_partition(world_size); + std::vector user_tensor_entry_partition(world_size); + wholememory_tensor_get_entry_partition_sizes(allocated_tensor_entry_partition.data(), + allocated_embedding); + wholememory_tensor_get_entry_partition_sizes(user_tensor_entry_partition.data(), user_embedding); optimizer_state_ = std::make_unique(); - optimizer_state_->local_start_index = entry_per_rank * world_rank; + optimizer_state_->local_start_index = 0; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_local_entry_start( + (size_t*)&optimizer_state_->local_start_index, allocated_embedding)); optimizer_impl_base_->create_optimizer_states(optimizer_state_.get(), user_tensor_desc->sizes[1]); bool const need_cachable_states = !optimizer_state_->cachable_states.empty(); wholememory_tensor_description_t cachable_state_desc; @@ -363,7 +383,8 @@ wholememory_error_code_t embedding_base::create_optimizer_states() noexcept raw_embedding_comm_, memory_type, memory_location, - cache_policy)); + cache_policy, + user_tensor_entry_partition.data())); optimizer_state_->global_cachable_raw_user_tensor = wholememory_embedding_get_embedding_tensor(optimizer_state_->cachable_state_embedding); @@ -391,7 +412,8 @@ wholememory_error_code_t embedding_base::create_optimizer_states() noexcept &uc_desc, wm_raw_comm, WHOLEMEMORY_MT_DISTRIBUTED, - WHOLEMEMORY_ML_DEVICE)); + WHOLEMEMORY_ML_DEVICE, + allocated_tensor_entry_partition.data())); start[0] = 0; start[1] = 0; end[0] = user_tensor_desc->sizes[0]; @@ -564,8 +586,6 @@ wholememory_error_code_t device_cached_host_embedding::gather(wholememory_tensor host_rank_id_count_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_recv_indices_buffer_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_raw_indice_handle(p_env_fns); - size_t const embedding_entry_count_per_rank = - wholememory_tensor_get_entry_per_partition(allocated_embedding); wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); int world_size = -1, world_rank = -1; int64_t* host_recv_rank_id_count_ptr = nullptr; @@ -584,6 +604,20 @@ wholememory_error_code_t device_cached_host_embedding::gather(wholememory_tensor wholememory_array_description_t indice_array_desc; WHOLEMEMORY_CHECK_NOTHROW( 
wholememory_convert_tensor_desc_to_array(&indice_array_desc, indice_desc)); + + wholememory_ops::temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + wholememory_ops::temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_entry_offsets(host_embedding_entry_offsets_ptr, allocated_embedding)); + WM_CUDA_CHECK_NO_THROW(cudaMemcpy(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice)); WHOLEMEMORY_RETURN_ON_FAIL( wholememory_ops::bucket_and_exchange_ids_func(wholememory_tensor_get_data_pointer(indices), indice_array_desc, @@ -591,7 +625,7 @@ wholememory_error_code_t device_cached_host_embedding::gather(wholememory_tensor host_rank_id_count_ptr, &dev_recv_indices_buffer_handle, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, raw_embedding_comm_, &thrust_allocator, p_env_fns, @@ -642,8 +676,10 @@ wholememory_error_code_t device_cached_host_embedding::gather(wholememory_tensor wholememory_gref_t cache_line_tag_gref; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_global_reference( cache_ptr_->get_cache_local_data()->cache_line_tag_, &cache_line_tag_gref)); - int64_t const rank_start_gid = - wholememory_tensor_get_entry_per_partition(allocated_embedding) * world_rank; + + size_t rank_start_gid = 0; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_local_entry_start(&rank_start_gid, allocated_embedding)); wholememory_tensor_description_t recv_indices_desc; auto recv_indices_array_desc = wholememory_create_array_desc(total_recv_count, 0, indice_desc->dtype); @@ -756,8 +792,6 @@ wholememory_error_code_t local_cached_global_readonly_embedding::gather( host_rank_id_count_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_recv_indices_buffer_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_raw_indice_handle(p_env_fns); - size_t const embedding_entry_count_per_rank = - wholememory_tensor_get_entry_per_partition(cache_ptr_->access_count_wm_tensor_); wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); int cache_world_size = -1, cache_world_rank = -1; int64_t* host_recv_rank_id_count_ptr = nullptr; @@ -779,6 +813,19 @@ wholememory_error_code_t local_cached_global_readonly_embedding::gather( wholememory_array_description_t indice_array_desc; WHOLEMEMORY_CHECK_NOTHROW( wholememory_convert_tensor_desc_to_array(&indice_array_desc, indice_desc)); + + wholememory_ops::temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(cache_world_size + 1, WHOLEMEMORY_DT_INT64)); + wholememory_ops::temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(cache_world_size + 1, WHOLEMEMORY_DT_INT64)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_entry_offsets( + host_embedding_entry_offsets_ptr, cache_ptr_->access_count_wm_tensor_)); + WM_CUDA_CHECK_NO_THROW(cudaMemcpy(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (cache_world_size + 1) 
* sizeof(size_t), + cudaMemcpyHostToDevice)); WHOLEMEMORY_RETURN_ON_FAIL( wholememory_ops::bucket_and_exchange_ids_func(wholememory_tensor_get_data_pointer(indices), indice_array_desc, @@ -786,7 +833,7 @@ wholememory_error_code_t local_cached_global_readonly_embedding::gather( host_rank_id_count_ptr, &dev_recv_indices_buffer_handle, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, cache_policy->cache_comm, &thrust_allocator, p_env_fns, @@ -804,7 +851,7 @@ wholememory_error_code_t local_cached_global_readonly_embedding::gather( update_indice_desc, allocated_embedding, cache_policy->cache_comm, - embedding_entry_count_per_rank, + host_embedding_entry_offsets_ptr, cache_ptr_->get_cache_local_data(), cache_ptr_->get_cache_set_coverage(), p_env_fns, @@ -903,6 +950,7 @@ wholememory_error_code_t wholememory_create_embedding( wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, wholememory_embedding_cache_policy_t cache_policy, + size_t* embedding_entry_partition, int user_defined_sms, int round_robin_size) { @@ -916,6 +964,9 @@ wholememory_error_code_t wholememory_create_embedding( int embedding_world_size = 1; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&embedding_world_size, comm)); if (cache_policy != nullptr) { + if (memory_type == WHOLEMEMORY_MT_HIERARCHY) { + WHOLEMEMORY_ERROR("Cache is not supported now in hierarchy memory type."); + } if (cache_policy->cache_comm == comm) { if (cache_policy->cache_memory_location != WHOLEMEMORY_ML_DEVICE) { WHOLEMEMORY_ERROR( @@ -961,15 +1012,24 @@ wholememory_error_code_t wholememory_create_embedding( } embedding_impl_ptr = new wholememory::local_cached_global_readonly_embedding(); } + embedding_entry_partition = nullptr; } else { embedding_impl_ptr = new wholememory::noncached_embedding(); } + if (embedding_entry_partition) { + if (round_robin_size != 0) { WHOLEMEMORY_WARN("Parameter 'round_robin_size' is ignored."); } + round_robin_size = 0; + } WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&embedding_world_size, comm)); embedding_impl_ptr->set_shard_method( &embedding_matrix_description, embedding_world_size, round_robin_size); embedding_impl_ptr->set_gather_sms(user_defined_sms); - WHOLEMEMORY_RETURN_ON_FAIL(embedding_impl_ptr->allocate( - &embedding_matrix_description, comm, memory_type, memory_location, cache_policy)); + WHOLEMEMORY_RETURN_ON_FAIL(embedding_impl_ptr->allocate(&embedding_matrix_description, + comm, + memory_type, + memory_location, + cache_policy, + embedding_entry_partition)); *wholememory_embedding = static_cast(embedding_impl_ptr); return WHOLEMEMORY_SUCCESS; } diff --git a/cpp/src/wholememory/embedding.hpp b/cpp/src/wholememory/embedding.hpp index f593c36..616667c 100644 --- a/cpp/src/wholememory/embedding.hpp +++ b/cpp/src/wholememory/embedding.hpp @@ -45,7 +45,8 @@ class embedding_base : public wholememory_embedding_ { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - wholememory_embedding_cache_policy_t policy) noexcept; + wholememory_embedding_cache_policy_t policy, + size_t* embedding_entry_partition) noexcept; void deallocate() noexcept; virtual wholememory_error_code_t gather(wholememory_tensor_t indices, wholememory_tensor_t output, diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index 16ed437..f5cbd62 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -57,7 +57,8 @@ 
class wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) : handle_(wholememory_handle), comm_(comm), type_(memory_type), @@ -65,6 +66,16 @@ class wholememory_impl { total_size_(total_size), data_granularity_(data_granularity) { + if (rank_entry_partition != nullptr) { + rank_partition_strategy_.partition_sizes_.resize(comm_->world_size, 0); + rank_partition_strategy_.partition_offsets_.resize(comm_->world_size + 1, 0); + for (int i = 0; i < comm_->world_size; i++) { + rank_partition_strategy_.partition_sizes_[i] = rank_entry_partition[i] * data_granularity_; + rank_partition_strategy_.partition_offsets_[i + 1] = + rank_partition_strategy_.partition_offsets_[i] + + rank_entry_partition[i] * data_granularity_; + } + } distrubuted_backend_ = WHOLEMEMORY_DB_NCCL; } wholememory_impl() = delete; @@ -89,16 +100,17 @@ class wholememory_impl { [[nodiscard]] virtual wholememory_gref_t get_global_reference() const noexcept { wholememory_gref_t gref{}; - gref.pointer = nullptr; - gref.stride = 0; + gref.pointer = nullptr; + gref.stride = 0; + gref.world_size = comm_->world_size; return gref; } virtual bool contains_pointer(const void* ptr) const = 0; - void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const + virtual void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const { if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_; - if (local_size != nullptr) *local_size = rank_partition_strategy_.local_mem_size; - if (local_offset != nullptr) *local_offset = rank_partition_strategy_.local_mem_offset; + if (local_size != nullptr) *local_size = get_local_size(); + if (local_offset != nullptr) *local_offset = get_local_offset(); if (location_ == WHOLEMEMORY_ML_HOST && (type_ == WHOLEMEMORY_MT_CONTINUOUS) && (!(comm_->is_intranode()))) { WHOLEMEMORY_WARN( @@ -116,10 +128,23 @@ class wholememory_impl { *rank_memory_offset = 0; return false; } - [[nodiscard]] size_t get_partition_stride() const + [[nodiscard]] virtual size_t get_partition_stride() const { return rank_partition_strategy_.partition_mem_stride; } + [[nodiscard]] size_t get_local_size() const + { + return rank_partition_strategy_.partition_sizes_[comm_->world_rank]; + } + [[nodiscard]] size_t get_local_offset() const + { + return rank_partition_strategy_.partition_offsets_[comm_->world_rank]; + } + std::vector get_rank_sizes() const { return rank_partition_strategy_.partition_sizes_; } + std::vector get_rank_offsets() const + { + return rank_partition_strategy_.partition_offsets_; + } protected: // In WholeMemory, memory is first allocated by one or all ranks, and then partition the whole @@ -136,18 +161,19 @@ class wholememory_impl { // first rank responsible for all memory allocation, continuous or chunked host shared memory may // use this mode. void first_rank_allocate_all_strategy(); - // each rank allocate exactly the same size, chunked device memory or nccl memory may use this - // mode. - void each_rank_same_chunk_strategy(); + // each rank allocate different size, chunked device memory or nccl memory may use this + // mode. If rank_entry_partition isn't set, each rank allocate exactly the same size. + void each_rank_different_chunk_strategy(); // each rank allocate a multiple of pages, and map the whole memory by page, continuous device // memory use this mode. 
void each_rank_multiple_page_strategy(); // For now, memory rank partitioning strategy is the same for all WholeMemory types. - // Each rank is response for memory of size local_mem_size_ starting from local_mem_offset_. - // And local_mem_offset_ can also be got by rank_mem_stride_ * rank for ranks with local_mem_size_ - // != 0 That means for a valid memory offset position, offset / rank_mem_stride_ can be used to - // get the rank which is responsible for it. + // Each rank is responsible for memory of size local_mem_size starting from local_mem_offset. + // local_mem_size can be obtained by calling get_local_size(), and local_mem_offset can be + // obtained by calling get_local_offset(). rank_partition_strategy_.partition_sizes_ and + // rank_partition_strategy_.partition_offsets_ record the memory size and memory offset of + // all ranks. void generate_rank_partition_strategy(); /* @@ -182,12 +208,12 @@ class wholememory_impl { } alloc_strategy_; struct partition_strategy { - // size of memory this rank is responsible for - size_t local_mem_size = 0; - // start location of the memory this rank is responsible for - size_t local_mem_offset = 0; + std::vector<size_t> partition_sizes_; + std::vector<size_t> partition_offsets_; size_t partition_mem_stride = 0; + bool same_chunk; } rank_partition_strategy_; + void* local_partition_memory_pointer_ = nullptr; void get_rank_partition_info(size_t* rank_mem_size, @@ -195,12 +221,9 @@ class wholememory_impl { int rank) const noexcept { WHOLEMEMORY_CHECK_NOTHROW(rank >= 0 && rank <= comm_->world_size); - size_t rank_mem_part_start = - std::min(rank_partition_strategy_.partition_mem_stride * rank, total_size_); - size_t rank_mem_part_end = - std::min(rank_partition_strategy_.partition_mem_stride * (rank + 1), total_size_); - if (rank_mem_size != nullptr) *rank_mem_size = rank_mem_part_end - rank_mem_part_start; - if (rank_mem_start != nullptr) *rank_mem_start = rank_mem_part_start; + if (rank_mem_size != nullptr) *rank_mem_size = rank_partition_strategy_.partition_sizes_[rank]; + if (rank_mem_start != nullptr) + *rank_mem_start = rank_partition_strategy_.partition_offsets_[rank]; } static constexpr size_t HUGE_PAGE_THRESHOLD = 16UL * 1024UL * 1024UL * 1024UL; @@ -293,16 +316,22 @@ class distributed_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) - { - WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED); + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) + { + WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED || type_ == WHOLEMEMORY_MT_HIERARCHY); } void create_memory() override { - each_rank_same_chunk_strategy(); generate_rank_partition_strategy(); + each_rank_different_chunk_strategy(); create_local_cuda_runtime_memory(); register_private_memory(); } @@ -387,17 +416,23 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, 
+ total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS || type_ == WHOLEMEMORY_MT_CHUNKED); WHOLEMEMORY_CHECK(location_ == WHOLEMEMORY_ML_HOST); } void create_memory() override { - first_rank_allocate_all_strategy(); generate_rank_partition_strategy(); + first_rank_allocate_all_strategy(); create_and_map_shared_host_memory(); register_host_memory(); } @@ -413,8 +448,9 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { [[nodiscard]] wholememory_gref_t get_global_reference() const noexcept override { wholememory_gref_t gref{}; - gref.pointer = get_continuous_mapping_pointer(); - gref.stride = 0; + gref.pointer = get_continuous_mapping_pointer(); + gref.stride = 0; + gref.world_size = comm_->world_size; return gref; } bool contains_pointer(const void* ptr) const override @@ -423,6 +459,7 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { uint64_t int_start_ptr = reinterpret_cast(shared_host_handle_.shared_host_memory_ptr); return int_ptr >= int_start_ptr && int_ptr < int_start_ptr + total_size_; } + bool get_rank_memory(void** rank_memory_ptr, size_t* rank_memory_size, size_t* rank_memory_offset, @@ -531,9 +568,7 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { nullptr, alloc_strategy_.total_alloc_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); WHOLEMEMORY_CHECK(mmap_ptr != (void*)-1); } - memset(static_cast(mmap_ptr) + rank_partition_strategy_.local_mem_offset, - 0, - rank_partition_strategy_.local_mem_size); + memset(static_cast(mmap_ptr) + get_local_offset(), 0, get_local_size()); WM_CUDA_CHECK_NO_THROW( cudaHostRegister(mmap_ptr, alloc_strategy_.total_alloc_size, cudaHostRegisterDefault)); if (!use_systemv_shm_) WHOLEMEMORY_CHECK(close(shm_fd) == 0); @@ -541,8 +576,7 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { WM_CUDA_CHECK_NO_THROW(cudaHostGetDevicePointer(&dev_ptr, mmap_ptr, 0)); WHOLEMEMORY_CHECK(dev_ptr == mmap_ptr); shared_host_handle_.shared_host_memory_ptr = dev_ptr; - local_partition_memory_pointer_ = - static_cast(dev_ptr) + rank_partition_strategy_.local_mem_offset; + local_partition_memory_pointer_ = static_cast(dev_ptr) + get_local_offset(); } void unmap_and_destroy_shared_host_memory() noexcept @@ -603,17 +637,29 @@ class continuous_device_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) - { + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) + { + // printf( + // "while in continuous device wholememory creation, the memory_type (%d) and memory_location + // " + // "(%d).\n", + // (int)memory_type, + // (int)memory_location); WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS); } void create_memory() override { WHOLEMEMORY_CHECK(location_ == WHOLEMEMORY_ML_DEVICE); - each_rank_multiple_page_strategy(); generate_rank_partition_strategy(); + each_rank_multiple_page_strategy(); create_and_map_driver_device_memory(); register_continuous_device_memory(); } @@ -629,8 +675,9 @@ class continuous_device_wholememory_impl : public wholememory_impl { [[nodiscard]] wholememory_gref_t 
get_global_reference() const noexcept override { wholememory_gref_t gref{}; - gref.pointer = get_continuous_mapping_pointer(); - gref.stride = 0; + gref.pointer = get_continuous_mapping_pointer(); + gref.stride = 0; + gref.world_size = comm_->world_size; return gref; } bool contains_pointer(const void* ptr) const override @@ -960,8 +1007,8 @@ class continuous_device_wholememory_impl : public wholememory_impl { close_unix_domain_sockets(); map_driver_device_memory_handles(&recv_ipc_sharable_cu_handles); communicator_barrier(comm_); - local_partition_memory_pointer_ = static_cast(cu_alloc_handle_.mapped_whole_memory) + - rank_partition_strategy_.local_mem_offset; + local_partition_memory_pointer_ = + static_cast(cu_alloc_handle_.mapped_whole_memory) + get_local_offset(); } void unmap_and_destroy_driver_device_memory() noexcept { @@ -1017,17 +1064,23 @@ class chunked_device_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CHUNKED); WHOLEMEMORY_CHECK(location_ == WHOLEMEMORY_ML_DEVICE); } void create_memory() override { - each_rank_same_chunk_strategy(); generate_rank_partition_strategy(); + each_rank_different_chunk_strategy(); create_and_map_runtime_device_memory(); register_chunked_device_memory(); } @@ -1044,7 +1097,7 @@ class chunked_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); uint64_t int_start_ptr = reinterpret_cast(cuda_ipc_handle_.mapped_ptrs[i]); if (int_ptr >= int_start_ptr && int_ptr < int_start_ptr + mem_size_for_current_rank) { return true; @@ -1074,7 +1127,7 @@ class chunked_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); if (mem_size_for_current_rank > 0) { register_wholememory_vma_range_locked( cuda_ipc_handle_.mapped_ptrs[i], mem_size_for_current_rank, handle_); @@ -1089,7 +1142,7 @@ class chunked_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); if (mem_size_for_current_rank > 0) { unregister_wholememory_vma_range_locked( cuda_ipc_handle_.mapped_ptrs[i], mem_size_for_current_rank, handle_); @@ -1124,12 +1177,20 @@ class chunked_device_wholememory_impl : public wholememory_impl { cuda_ipc_handle_.mapped_ptrs.data(), 
sizeof(void*) * comm_->world_size, cudaMemcpyHostToDevice)); - gref_.stride = rank_partition_strategy_.partition_mem_stride; + WM_CUDA_CHECK(cudaMalloc(&gref_.rank_memory_offsets, sizeof(size_t) * (comm_->world_size + 1))); + WM_CUDA_CHECK(cudaMemcpy(gref_.rank_memory_offsets, + get_rank_offsets().data(), + sizeof(size_t) * (comm_->world_size + 1), + cudaMemcpyHostToDevice)); + gref_.world_size = comm_->world_size; + gref_.stride = rank_partition_strategy_.partition_mem_stride; + gref_.same_chunk = rank_partition_strategy_.same_chunk; } void unmap_and_destroy_runtime_device_memory() noexcept { try { WM_CUDA_CHECK(cudaFree(gref_.pointer)); + WM_CUDA_CHECK(cudaFree(gref_.rank_memory_offsets)); gref_.pointer = nullptr; for (int i = 0; i < comm_->world_size; i++) { if (i != comm_->world_rank) { @@ -1164,9 +1225,15 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED); WHOLEMEMORY_CHECK(location_ == WHOLEMEMORY_ML_DEVICE); @@ -1179,8 +1246,8 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { void create_memory() override { - each_rank_same_chunk_strategy(); generate_rank_partition_strategy(); + each_rank_different_chunk_strategy(); nvshmem_malloc_device_memory(); register_nvshmem_device_memory(); } @@ -1198,7 +1265,7 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); acc_size += mem_size_for_current_rank; uint64_t int_start_ptr = reinterpret_cast(nvshmem_ptr(nvshmem_memory_handle_.local_alloc_mem_ptr, i)); @@ -1226,6 +1293,8 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { return true; } + [[nodiscard]] wholememory_gref_t get_global_reference() const noexcept override { return gref_; } + protected: void register_nvshmem_device_memory() { @@ -1234,7 +1303,7 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); if (mem_size_for_current_rank > 0) { void* ptr = nvshmem_ptr(nvshmem_memory_handle_.local_alloc_mem_ptr, i); if (ptr != nullptr) { @@ -1251,7 +1320,7 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); if 
(mem_size_for_current_rank > 0) { void* ptr = nvshmem_ptr(nvshmem_memory_handle_.local_alloc_mem_ptr, i); if (ptr != nullptr) { @@ -1270,10 +1339,22 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { nvshmem_memory_handle_.local_alloc_mem_ptr = nvshmem_malloc(alloc_size); local_partition_memory_pointer_ = nvshmem_memory_handle_.local_alloc_mem_ptr; distrubuted_backend_ = WHOLEMEMORY_DB_NVSHMEM; + + WM_CUDA_CHECK(cudaMalloc(&gref_.rank_memory_offsets, sizeof(size_t) * (comm_->world_size + 1))); + WM_CUDA_CHECK(cudaMemcpy(gref_.rank_memory_offsets, + get_rank_offsets().data(), + sizeof(size_t) * (comm_->world_size + 1), + cudaMemcpyHostToDevice)); + gref_.pointer = local_partition_memory_pointer_; + gref_.world_size = comm_->world_size; + gref_.stride = rank_partition_strategy_.partition_mem_stride; + gref_.same_chunk = rank_partition_strategy_.same_chunk; } void nvshmem_free_device_memory() { + WM_CUDA_CHECK(cudaFree(gref_.rank_memory_offsets)); + gref_.pointer = nullptr; if (nvshmem_memory_handle_.local_alloc_mem_ptr) { nvshmem_free(nvshmem_memory_handle_.local_alloc_mem_ptr); @@ -1307,6 +1388,8 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { void* local_alloc_mem_ptr = nullptr; } nvshmem_memory_handle_; inline static bool has_set_nvshmem_heap = false; + + wholememory_gref_t gref_; }; #endif // Implementation for MNNVL wholememory that use cuda driver api. @@ -1320,9 +1403,15 @@ class continuous_mnnvl_wholememory_impl : public continuous_device_wholememory_i wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : continuous_device_wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : continuous_device_wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_INFO("Using continuous_mnnvl_wholememory_impl"); WHOLEMEMORY_CHECK_NOTHROW(type_ == WHOLEMEMORY_MT_CONTINUOUS); @@ -1335,8 +1424,8 @@ class continuous_mnnvl_wholememory_impl : public continuous_device_wholememory_i void create_memory() override { check_valid(); - each_rank_multiple_page_strategy(); generate_rank_partition_strategy(); + each_rank_multiple_page_strategy(); create_and_map_driver_memory(); register_continuous_mnnvl_memory(); } @@ -1474,8 +1563,8 @@ class continuous_mnnvl_wholememory_impl : public continuous_device_wholememory_i &cu_alloc_handle_.local_ipc_fabric_handle); map_driver_memory_handles(&recv_ipc_sharable_cu_fabric_handles); - local_partition_memory_pointer_ = static_cast(cu_alloc_handle_.mapped_whole_memory) + - rank_partition_strategy_.local_mem_offset; + local_partition_memory_pointer_ = + static_cast(cu_alloc_handle_.mapped_whole_memory) + get_local_offset(); } void unmap_and_destroy_driver_host_memory() noexcept { @@ -1513,16 +1602,37 @@ class continuous_mnnvl_wholememory_impl : public continuous_device_wholememory_i void wholememory_impl::generate_rank_partition_strategy() { - size_t data_slot_count = total_size_ / data_granularity_; - size_t data_slot_per_rank = determine_entry_partition_plan(data_slot_count, comm_->world_size); - size_t rank_data_slot_start = std::min(comm_->world_rank * data_slot_per_rank, data_slot_count); - size_t rank_data_slot_end = - std::min((comm_->world_rank + 1) * data_slot_per_rank, data_slot_count); - size_t rank_data_slot_count = rank_data_slot_end 
- rank_data_slot_start; - - rank_partition_strategy_.local_mem_size = rank_data_slot_count * data_granularity_; - rank_partition_strategy_.local_mem_offset = rank_data_slot_start * data_granularity_; + if (!rank_partition_strategy_.partition_sizes_.empty()) { + rank_partition_strategy_.partition_mem_stride = total_size_ / comm_->world_size; + bool check_same = true; + for (int i = 0; i < comm_->world_size - 2; i++) { // ignore the last rank + if (rank_partition_strategy_.partition_sizes_[i] != + rank_partition_strategy_.partition_sizes_[i + 1]) { + check_same = false; + break; + } + } + rank_partition_strategy_.same_chunk = check_same; + return; + } + size_t data_slot_count = total_size_ / data_granularity_; + + size_t data_slot_per_rank = 0; + equal_partition_plan(&data_slot_per_rank, data_slot_count, comm_->world_size); + + rank_partition_strategy_.partition_sizes_.resize(comm_->world_size, 0); + rank_partition_strategy_.partition_offsets_.resize(comm_->world_size + 1, 0); + for (int i = 0; i < comm_->world_size; i++) { + size_t tmp_slot_start = std::min(i * data_slot_per_rank, data_slot_count); + size_t tmp_slot_end = std::min((i + 1) * data_slot_per_rank, data_slot_count); + rank_partition_strategy_.partition_sizes_[i] = + (tmp_slot_end - tmp_slot_start) * data_granularity_; + rank_partition_strategy_.partition_offsets_[i] = tmp_slot_start * data_granularity_; + } + rank_partition_strategy_.partition_offsets_[comm_->world_size] = + data_slot_count * data_granularity_; rank_partition_strategy_.partition_mem_stride = data_slot_per_rank * data_granularity_; + rank_partition_strategy_.same_chunk = true; } void wholememory_impl::first_rank_allocate_all_strategy() @@ -1542,28 +1652,33 @@ void wholememory_impl::first_rank_allocate_all_strategy() alloc_strategy_.alloc_sizes[0] = alloc_strategy_.total_alloc_size; } -void wholememory_impl::each_rank_same_chunk_strategy() +void wholememory_impl::each_rank_different_chunk_strategy() { - size_t data_slot_count = total_size_ / data_granularity_; - size_t data_slot_per_rank = determine_entry_partition_plan(data_slot_count, comm_->world_size); - // each rank allocate same size - alloc_strategy_.local_alloc_size = data_slot_per_rank * data_granularity_; - alloc_strategy_.alignment = comm_->alloc_granularity; - if (total_size_ > HUGE_PAGE_THRESHOLD) { - alloc_strategy_.local_alloc_size = - round_up_unsafe(alloc_strategy_.local_alloc_size, HUGE_PAGE_SIZE); - alloc_strategy_.alignment = HUGE_PAGE_SIZE; - } - alloc_strategy_.total_alloc_size = alloc_strategy_.local_alloc_size * comm_->world_size; - alloc_strategy_.alloc_offsets.clear(); alloc_strategy_.alloc_offsets.resize(comm_->world_size, 0); + alloc_strategy_.alloc_sizes.clear(); + alloc_strategy_.alloc_sizes.resize(comm_->world_size, 0); + + size_t rank_local_alloc_offset = 0; for (int i = 0; i < comm_->world_size; i++) { - alloc_strategy_.alloc_offsets[i] = alloc_strategy_.local_alloc_size * i; + size_t rank_local_alloc_size = rank_partition_strategy_.partition_sizes_[i]; + size_t rank_alignment; + if (total_size_ > HUGE_PAGE_THRESHOLD) { + rank_local_alloc_size = round_up_unsafe(rank_local_alloc_size, HUGE_PAGE_SIZE); + rank_alignment = HUGE_PAGE_SIZE; + } else { + rank_local_alloc_size = round_up_unsafe(rank_local_alloc_size, comm_->alloc_granularity); + rank_alignment = comm_->alloc_granularity; + } + if (i == comm_->world_rank) { + alloc_strategy_.local_alloc_size = rank_local_alloc_size; + alloc_strategy_.alignment = rank_alignment; + } + alloc_strategy_.alloc_offsets[i] = 
rank_local_alloc_offset; + alloc_strategy_.alloc_sizes[i] = rank_local_alloc_size; + rank_local_alloc_offset += rank_local_alloc_size; } - - alloc_strategy_.alloc_sizes.clear(); - alloc_strategy_.alloc_sizes.resize(comm_->world_size, alloc_strategy_.local_alloc_size); + alloc_strategy_.total_alloc_size = rank_local_alloc_offset; } void wholememory_impl::each_rank_multiple_page_strategy() @@ -1638,16 +1753,68 @@ struct wholememory_create_param { size_t min_granularity; }; +class hierarchy_wholememory_impl : public distributed_wholememory_impl { + public: + hierarchy_wholememory_impl(wholememory_handle_t wholememory_handle, + size_t total_size, + wholememory_comm_t global_comm, + wholememory_comm_t local_comm, + wholememory_memory_type_t memory_type, + wholememory_memory_location_t memory_location, + size_t data_granularity, + size_t* rank_entry_partition) + : distributed_wholememory_impl(wholememory_handle, + total_size, + global_comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) + { + WHOLEMEMORY_CHECK(memory_type == WHOLEMEMORY_MT_HIERARCHY); + local_comm_ = local_comm; + int world_rank = -1, world_size = -1, local_size = -1; + wholememory_communicator_get_rank(&world_rank, global_comm); + wholememory_communicator_get_size(&world_size, global_comm); + wholememory_communicator_get_size(&local_size, local_comm); + WHOLEMEMORY_CHECK(world_size % local_size == 0); + wholememory_split_communicator( + &cross_comm_, global_comm, world_rank % local_size, world_rank / local_size); + } + + [[nodiscard]] wholememory_comm_t get_local_comm() const { return local_comm_; } + [[nodiscard]] wholememory_comm_t get_cross_comm() const { return cross_comm_; } + + protected: + wholememory_comm_t local_comm_; + wholememory_comm_t cross_comm_; +}; + wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_handle_ptr, size_t total_size, wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) noexcept + size_t data_granularity, + size_t* rank_entry_partition) noexcept { try { if (total_size % data_granularity != 0) return WHOLEMEMORY_INVALID_VALUE; - + if (rank_entry_partition != nullptr) { + int64_t total_slot_count = 0; + for (int i = 0; i < comm->world_size; i++) { + WM_COMM_CHECK_ALL_SAME(comm, rank_entry_partition[i]); + if (rank_entry_partition[i] <= 0) { return WHOLEMEMORY_INVALID_VALUE; } + total_slot_count += rank_entry_partition[i]; + } + if (total_slot_count * data_granularity != total_size) { + WHOLEMEMORY_ERROR("total slot count * data granularity (%ld*%ld) != total size (%ld)", + total_slot_count, + data_granularity, + total_size); + return WHOLEMEMORY_INVALID_VALUE; + } + } *wholememory_handle_ptr = nullptr; std::unique_lock mlock(comm->mu); auto* whole_memory_handle = new wholememory_handle_(); @@ -1660,27 +1827,52 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha if (memory_type == WHOLEMEMORY_MT_DISTRIBUTED) { #ifdef WITH_NVSHMEM_SUPPORT if (comm->bind_to_nvshmem) { - whole_memory_handle->impl = new nvshmem_device_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new nvshmem_device_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } else #endif { - whole_memory_handle->impl = new distributed_wholememory_impl( - whole_memory_handle, total_size, comm, 
memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new distributed_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } } else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS) { if (is_intranode_communicator(comm) || !SupportEGM()) { if (memory_location == WHOLEMEMORY_ML_HOST) { - whole_memory_handle->impl = new global_mapped_host_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new global_mapped_host_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } else { - whole_memory_handle->impl = new continuous_device_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new continuous_device_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } } else { #if CUDA_VERSION >= 12030 - whole_memory_handle->impl = new continuous_mnnvl_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new continuous_mnnvl_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); #else WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3"); #endif @@ -1688,12 +1880,37 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha } else if (memory_type == WHOLEMEMORY_MT_CHUNKED) { WHOLEMEMORY_CHECK_NOTHROW(is_intranode_communicator(comm)); if (memory_location == WHOLEMEMORY_ML_HOST) { - whole_memory_handle->impl = new global_mapped_host_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new global_mapped_host_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } else { - whole_memory_handle->impl = new chunked_device_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new chunked_device_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } + } else if (memory_type == WHOLEMEMORY_MT_HIERARCHY) { + wholememory_comm_t local_comm; + int world_rank = -1, local_size = -1; + wholememory_communicator_get_rank(&world_rank, comm); + wholememory_communicator_get_local_size(&local_size, comm); + wholememory_split_communicator( + &local_comm, comm, world_rank / local_size, world_rank % local_size); + whole_memory_handle->impl = new hierarchy_wholememory_impl(whole_memory_handle, + total_size, + comm, + local_comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } else { WHOLEMEMORY_FATAL("Unsupported memory_type (%d) and memory_location (%d).", (int)memory_type, @@ -1769,6 +1986,36 @@ wholememory_error_code_t get_communicator_from_handle( return WHOLEMEMORY_SUCCESS; } +wholememory_error_code_t get_local_communicator_from_handle( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept +{ + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + 
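+ // Local/cross communicators exist only for handles created with
+ // WHOLEMEMORY_MT_HIERARCHY; a null handle (or one whose impl was never
+ // created) is rejected here, before the memory-type check below.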
return WHOLEMEMORY_INVALID_INPUT; + } + if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) { + return WHOLEMEMORY_NOT_SUPPORTED; + } + hierarchy_wholememory_impl* hierarchy_impl = + dynamic_cast(wholememory_handle->impl); + *comm = hierarchy_impl->get_local_comm(); + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t get_cross_communicator_from_handle( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept +{ + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + if (get_memory_type(wholememory_handle) != WHOLEMEMORY_MT_HIERARCHY) { + return WHOLEMEMORY_NOT_SUPPORTED; + } + hierarchy_wholememory_impl* hierarchy_impl = + dynamic_cast(wholememory_handle->impl); + *comm = hierarchy_impl->get_cross_comm(); + return WHOLEMEMORY_SUCCESS; +} + wholememory_memory_type_t get_memory_type(wholememory_handle_t wholememory_handle) noexcept { return wholememory_handle->impl->get_type(); @@ -1858,45 +2105,69 @@ wholememory_error_code_t get_nvshmem_reference_frome_handle( (wholememory_handle->impl->get_distributed_backend() != WHOLEMEMORY_DB_NVSHMEM)) { return WHOLEMEMORY_INVALID_INPUT; } - *wholememory_nvshmem_ref = wholememory_nvshmem_ref_t{}; - size_t local_size, local_offset; - void* pointer; - - wholememory_handle->impl->get_local_memory(&pointer, &local_size, &local_offset); - wholememory_nvshmem_ref->pointer = pointer; - wholememory_nvshmem_ref->stride = wholememory_handle->impl->get_partition_stride(); - wholememory_nvshmem_ref->world_rank = wholememory_handle->impl->get_comm()->world_rank; - wholememory_nvshmem_ref->world_size = wholememory_handle->impl->get_comm()->world_size; + wholememory_gref_t wholememory_gref_tmp = wholememory_handle->impl->get_global_reference(); + *wholememory_nvshmem_ref = wholememory_nvshmem_ref_t{}; + wholememory_nvshmem_ref->pointer = wholememory_gref_tmp.pointer; + wholememory_nvshmem_ref->rank_memory_offsets = wholememory_gref_tmp.rank_memory_offsets; + wholememory_nvshmem_ref->world_size = wholememory_gref_tmp.world_size; + wholememory_nvshmem_ref->world_rank = wholememory_handle->impl->get_comm()->world_rank; + wholememory_nvshmem_ref->stride = wholememory_gref_tmp.stride; + wholememory_nvshmem_ref->same_chunk = wholememory_gref_tmp.same_chunk; return (wholememory_nvshmem_ref->pointer == nullptr) ? 
WHOLEMEMORY_INVALID_INPUT : WHOLEMEMORY_SUCCESS; } #endif -wholememory_error_code_t determine_partition_plan(size_t* size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size) noexcept +wholememory_error_code_t equal_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size) noexcept +{ + *entry_per_rank = div_rounding_up_safe(total_entry_count, world_size); + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t get_rank_partition_sizes_from_handle( + size_t* rank_sizes, wholememory_handle_t wholememory_handle) noexcept +{ + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + std::vector rank_sizes_ = wholememory_handle->impl->get_rank_sizes(); + for (int i = 0; i < rank_sizes_.size(); i++) + rank_sizes[i] = rank_sizes_[i]; + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t get_rank_partition_offsets_from_handle( + size_t* rank_offsets, wholememory_handle_t wholememory_handle) noexcept { - if (total_size % data_granularity != 0) { return WHOLEMEMORY_INVALID_VALUE; } - if (size_per_rank == nullptr) { return WHOLEMEMORY_INVALID_INPUT; } - size_t entry_per_rank = 0; - *size_per_rank = determine_entry_partition_plan(total_size / data_granularity, world_size); + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + std::vector rank_offsets_ = wholememory_handle->impl->get_rank_offsets(); + for (int i = 0; i < rank_offsets_.size(); i++) + rank_offsets[i] = rank_offsets_[i]; return WHOLEMEMORY_SUCCESS; } -size_t determine_entry_partition_plan(size_t total_entry_count, int world_size) noexcept +wholememory_error_code_t get_local_size_from_handle( + size_t* rank_size, wholememory_handle_t wholememory_handle) noexcept { - return div_rounding_up_safe(total_entry_count, world_size); + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + *rank_size = wholememory_handle->impl->get_local_size(); + return WHOLEMEMORY_SUCCESS; } -wholememory_error_code_t get_partition_plan_from_handle( - size_t* size_per_rank, wholememory_handle_t wholememory_handle) noexcept +wholememory_error_code_t get_local_offset_from_handle( + size_t* local_offset, wholememory_handle_t wholememory_handle) noexcept { if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { return WHOLEMEMORY_INVALID_INPUT; } - *size_per_rank = wholememory_handle->impl->get_partition_stride(); + *local_offset = wholememory_handle->impl->get_local_offset(); return WHOLEMEMORY_SUCCESS; } diff --git a/cpp/src/wholememory/memory_handle.hpp b/cpp/src/wholememory/memory_handle.hpp index 159d6b4..a5ef211 100644 --- a/cpp/src/wholememory/memory_handle.hpp +++ b/cpp/src/wholememory/memory_handle.hpp @@ -40,7 +40,8 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) noexcept; + size_t data_granularity, + size_t* rank_entry_partition = nullptr) noexcept; wholememory_error_code_t destroy_wholememory_with_comm_locked( wholememory_handle_t wholememory_handle) noexcept; @@ -50,6 +51,12 @@ wholememory_error_code_t destroy_wholememory(wholememory_handle_t wholememory_ha wholememory_error_code_t get_communicator_from_handle( wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept; +wholememory_error_code_t 
get_local_communicator_from_handle( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept; + +wholememory_error_code_t get_cross_communicator_from_handle( + wholememory_comm_t* comm, wholememory_handle_t wholememory_handle) noexcept; + wholememory_memory_type_t get_memory_type(wholememory_handle_t wholememory_handle) noexcept; wholememory_memory_location_t get_memory_location(wholememory_handle_t wholememory_handle) noexcept; @@ -64,6 +71,12 @@ wholememory_error_code_t get_local_memory_from_handle( size_t* local_offset, wholememory_handle_t wholememory_handle) noexcept; +wholememory_error_code_t get_local_node_memory_from_handle( + void** local_ptr, + size_t* local_size, + size_t* local_offset, + wholememory_handle_t wholememory_handle) noexcept; + wholememory_error_code_t get_rank_memory_from_handle( void** rank_memory_ptr, size_t* rank_memory_size, @@ -71,21 +84,27 @@ wholememory_error_code_t get_rank_memory_from_handle( int rank, wholememory_handle_t wholememory_handle) noexcept; +wholememory_error_code_t get_local_size_from_handle( + size_t* size, wholememory_handle_t wholememory_handle) noexcept; + +wholememory_error_code_t get_local_offset_from_handle( + size_t* offset, wholememory_handle_t wholememory_handle) noexcept; + wholememory_error_code_t get_global_pointer_from_handle( void** global_ptr, wholememory_handle_t wholememory_handle) noexcept; wholememory_error_code_t get_global_reference_from_handle( wholememory_gref_t* wholememory_gref, wholememory_handle_t wholememory_handle) noexcept; -wholememory_error_code_t determine_partition_plan(size_t* size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size) noexcept; +wholememory_error_code_t equal_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size) noexcept; -size_t determine_entry_partition_plan(size_t total_entry_count, int world_size) noexcept; +wholememory_error_code_t get_rank_partition_sizes_from_handle( + size_t* rank_sizes, wholememory_handle_t wholememory_handle) noexcept; -wholememory_error_code_t get_partition_plan_from_handle( - size_t* size_per_rank, wholememory_handle_t wholememory_handle) noexcept; +wholememory_error_code_t get_rank_partition_offsets_from_handle( + size_t* rank_offsets, wholememory_handle_t wholememory_handle) noexcept; wholememory_distributed_backend_t get_distributed_backend_t( wholememory_handle_t wholememory_handle) noexcept; diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index 814e900..6f85dec 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -75,6 +75,13 @@ wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememor { return wholememory::communicator_get_size(size, comm); } + +wholememory_error_code_t wholememory_communicator_get_local_size(int* local_size, + wholememory_comm_t comm) +{ + return wholememory::communicator_get_local_size(local_size, comm); +} + bool wholememory_communicator_is_bind_to_nvshmem(wholememory_comm_t comm) { #ifdef WITH_NVSHMEM_SUPPORT @@ -107,10 +114,16 @@ wholememory_error_code_t wholememory_malloc(wholememory_handle_t* wholememory_ha wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) { - return wholememory::create_wholememory( - wholememory_handle_ptr, total_size, comm, memory_type, memory_location, data_granularity); + return 
wholememory::create_wholememory(wholememory_handle_ptr, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } wholememory_error_code_t wholememory_free(wholememory_handle_t wholememory_handle) @@ -124,6 +137,18 @@ wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t* comm, return wholememory::get_communicator_from_handle(comm, wholememory_handle); } +wholememory_error_code_t wholememory_get_local_communicator(wholememory_comm_t* comm, + wholememory_handle_t wholememory_handle) +{ + return wholememory::get_local_communicator_from_handle(comm, wholememory_handle); +} + +wholememory_error_code_t wholememory_get_cross_communicator(wholememory_comm_t* comm, + wholememory_handle_t wholememory_handle) +{ + return wholememory::get_cross_communicator_from_handle(comm, wholememory_handle); +} + wholememory_memory_type_t wholememory_get_memory_type(wholememory_handle_t wholememory_handle) { return wholememory::get_memory_type(wholememory_handle); @@ -170,6 +195,13 @@ wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, rank_memory_ptr, rank_memory_size, rank_memory_offset, rank, wholememory_handle); } +wholememory_error_code_t wholememory_equal_entry_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size) +{ + return wholememory::equal_partition_plan(entry_per_rank, total_entry_count, world_size); +} + wholememory_error_code_t wholememory_get_global_pointer(void** global_ptr, wholememory_handle_t wholememory_handle) { @@ -193,28 +225,28 @@ wholememory_error_code_t wholememory_get_nvshmem_reference( #endif -wholememory_error_code_t wholememory_determine_partition_plan(size_t* size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size) +wholememory_error_code_t wholememory_get_rank_partition_sizes( + size_t* rank_sizes, wholememory_handle_t wholememory_handle) { - return wholememory::determine_partition_plan( - size_per_rank, total_size, data_granularity, world_size); + return wholememory::get_rank_partition_sizes_from_handle(rank_sizes, wholememory_handle); } -wholememory_error_code_t wholememory_determine_entry_partition_plan(size_t* entry_per_rank, - size_t total_entry_count, - int world_size) +wholememory_error_code_t wholememory_get_rank_partition_offsets( + size_t* rank_offsets, wholememory_handle_t wholememory_handle) { - if (entry_per_rank == nullptr) { return WHOLEMEMORY_INVALID_INPUT; } - *entry_per_rank = wholememory::determine_entry_partition_plan(total_entry_count, world_size); - return WHOLEMEMORY_SUCCESS; + return wholememory::get_rank_partition_offsets_from_handle(rank_offsets, wholememory_handle); } -wholememory_error_code_t wholememory_get_partition_plan(size_t* size_per_rank, - wholememory_handle_t wholememory_handle) +wholememory_error_code_t wholememory_get_local_size(size_t* local_size, + wholememory_handle_t wholememory_handle) +{ + return wholememory::get_local_size_from_handle(local_size, wholememory_handle); +} + +wholememory_error_code_t wholememory_get_local_offset(size_t* local_size, + wholememory_handle_t wholememory_handle) { - return wholememory::get_partition_plan_from_handle(size_per_rank, wholememory_handle); + return wholememory::get_local_offset_from_handle(local_size, wholememory_handle); } int fork_get_device_count() diff --git a/cpp/src/wholememory/wholememory_tensor.cpp b/cpp/src/wholememory/wholememory_tensor.cpp index a02fdaa..41ba109 100644 --- a/cpp/src/wholememory/wholememory_tensor.cpp +++ 
b/cpp/src/wholememory/wholememory_tensor.cpp @@ -53,7 +53,8 @@ wholememory_error_code_t wholememory_create_tensor( wholememory_tensor_description_t* tensor_description, wholememory_comm_t comm, wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location) + wholememory_memory_location_t memory_location, + size_t* tensor_entry_partition) { if (p_wholememory_tensor == nullptr) { WHOLEMEMORY_ERROR("p_wholememory_tensor is nullptr"); @@ -98,7 +99,8 @@ wholememory_error_code_t wholememory_create_tensor( comm, memory_type, memory_location, - granularity); + granularity, + tensor_entry_partition); inc_tensor_count(); if (ret_code != WHOLEMEMORY_SUCCESS) { free(wholememory_tensor); } return ret_code; @@ -259,16 +261,10 @@ wholememory_error_code_t wholememory_tensor_map_local_tensor( wholememory_get_local_memory(&local_ptr, &local_size, &local_offset, handle)); size_t const element_size = wholememory_dtype_get_element_size(wm_desc->dtype); size_t const gran_size = wm_desc->dim == 1 ? element_size : element_size * wm_desc->strides[0]; - size_t size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_partition_plan(&size_per_rank, handle)); - WHOLEMEMORY_CHECK_NOTHROW(size_per_rank % gran_size == 0); - size_t entry_per_rank = size_per_rank / gran_size; - int64_t local_start = std::min(entry_per_rank * world_rank, wm_desc->sizes[0]); - int64_t local_end = std::min(entry_per_rank * (world_rank + 1), wm_desc->sizes[0]); + local_size = std::min(local_size, wm_desc->sizes[0] * gran_size - local_offset); if (local_size % gran_size != 0) return WHOLEMEMORY_LOGIC_ERROR; wholememory_tensor_description_t local_desc = *wm_desc; - // local_desc.sizes[0] = local_size / gran_size; - local_desc.sizes[0] = (local_end - local_start); + local_desc.sizes[0] = local_size / gran_size; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_make_tensor_from_pointer(local_tensor, local_ptr, &local_desc)); @@ -297,36 +293,120 @@ void* wholememory_tensor_get_data_pointer(wholememory_tensor_t wholememory_tenso wholememory_tensor->tensor_description.storage_offset; } -size_t wholememory_tensor_get_entry_per_partition(wholememory_tensor_t wholememory_tensor) +wholememory_error_code_t wholememory_tensor_get_entry_offsets( + size_t* entry_offsets, wholememory_tensor_t wholememory_tensor) { wholememory_tensor_t root_tensor = wholememory_tensor_get_root(wholememory_tensor); WHOLEMEMORY_CHECK_NOTHROW( (root_tensor->tensor_description.dim == 1 || root_tensor->tensor_description.dim == 2)); if (wholememory_tensor->is_wholememory) { - size_t size_per_rank; - wholememory_get_partition_plan(&size_per_rank, - wholememory_tensor_get_memory_handle(root_tensor)); size_t embedding_stride = 1; size_t const element_size = wholememory_dtype_get_element_size(wholememory_tensor->tensor_description.dtype); if (root_tensor->tensor_description.dim == 2) { embedding_stride = root_tensor->tensor_description.strides[0]; } - WHOLEMEMORY_CHECK_NOTHROW(size_per_rank % (embedding_stride * element_size) == 0); - size_t det_entry_per_rank; int world_size; wholememory_comm_t comm; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator( &comm, wholememory_tensor_get_memory_handle(wholememory_tensor))); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, comm)); - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_determine_entry_partition_plan( - &det_entry_per_rank, root_tensor->tensor_description.sizes[0], world_size)); - WHOLEMEMORY_CHECK_NOTHROW(det_entry_per_rank == - size_per_rank / (embedding_stride * element_size)); - return 
det_entry_per_rank; + + wholememory_get_rank_partition_offsets( + entry_offsets, wholememory_tensor_get_memory_handle(wholememory_tensor)); + for (int i = 0; i < world_size + 1; i++) { + WHOLEMEMORY_CHECK_NOTHROW(entry_offsets[i] % (embedding_stride * element_size) == 0); + entry_offsets[i] /= (embedding_stride * element_size); + } + return WHOLEMEMORY_SUCCESS; } - return root_tensor->tensor_description.sizes[0]; + entry_offsets[0] = 0; + entry_offsets[1] = root_tensor->tensor_description.sizes[0]; + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t wholememory_tensor_get_entry_partition_sizes( + size_t* entry_partition, wholememory_tensor_t wholememory_tensor) +{ + wholememory_tensor_t root_tensor = wholememory_tensor_get_root(wholememory_tensor); + WHOLEMEMORY_CHECK_NOTHROW( + (root_tensor->tensor_description.dim == 1 || root_tensor->tensor_description.dim == 2)); + if (wholememory_tensor->is_wholememory) { + size_t embedding_stride = 1; + size_t const element_size = + wholememory_dtype_get_element_size(wholememory_tensor->tensor_description.dtype); + if (root_tensor->tensor_description.dim == 2) { + embedding_stride = root_tensor->tensor_description.strides[0]; + } + + int world_size; + wholememory_comm_t comm; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator( + &comm, wholememory_tensor_get_memory_handle(wholememory_tensor))); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, comm)); + + wholememory_get_rank_partition_sizes(entry_partition, + wholememory_tensor_get_memory_handle(wholememory_tensor)); + for (int i = 0; i < world_size; i++) { + WHOLEMEMORY_CHECK_NOTHROW(entry_partition[i] % (embedding_stride * element_size) == 0); + entry_partition[i] /= (embedding_stride * element_size); + } + return WHOLEMEMORY_SUCCESS; + } + entry_partition[0] = root_tensor->tensor_description.sizes[0]; + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t wholememory_tensor_get_local_entry_count( + size_t* local_entry_count, wholememory_tensor_t wholememory_tensor) +{ + wholememory_tensor_t root_tensor = wholememory_tensor_get_root(wholememory_tensor); + WHOLEMEMORY_CHECK_NOTHROW( + (root_tensor->tensor_description.dim == 1 || root_tensor->tensor_description.dim == 2)); + if (wholememory_tensor->is_wholememory) { + size_t embedding_stride = 1; + size_t const element_size = + wholememory_dtype_get_element_size(wholememory_tensor->tensor_description.dtype); + if (root_tensor->tensor_description.dim == 2) { + embedding_stride = root_tensor->tensor_description.strides[0]; + } + + size_t entry_cnt; + wholememory_get_local_size(&entry_cnt, + wholememory_tensor_get_memory_handle(wholememory_tensor)); + WHOLEMEMORY_CHECK_NOTHROW(entry_cnt % (embedding_stride * element_size) == 0); + entry_cnt /= (embedding_stride * element_size); + *local_entry_count = entry_cnt; + return WHOLEMEMORY_SUCCESS; + } + *local_entry_count = root_tensor->tensor_description.sizes[0]; + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t wholememory_tensor_get_local_entry_start( + size_t* local_entry_start, wholememory_tensor_t wholememory_tensor) +{ + wholememory_tensor_t root_tensor = wholememory_tensor_get_root(wholememory_tensor); + WHOLEMEMORY_CHECK_NOTHROW( + (root_tensor->tensor_description.dim == 1 || root_tensor->tensor_description.dim == 2)); + if (wholememory_tensor->is_wholememory) { + size_t embedding_stride = 1; + size_t const element_size = + wholememory_dtype_get_element_size(wholememory_tensor->tensor_description.dtype); + if 
(root_tensor->tensor_description.dim == 2) { + embedding_stride = root_tensor->tensor_description.strides[0]; + } + size_t entry_start; + wholememory_get_local_offset(&entry_start, + wholememory_tensor_get_memory_handle(wholememory_tensor)); + WHOLEMEMORY_CHECK_NOTHROW(entry_start % (embedding_stride * element_size) == 0); + entry_start /= (embedding_stride * element_size); + *local_entry_start = entry_start; + return WHOLEMEMORY_SUCCESS; + } + *local_entry_start = 0; + return WHOLEMEMORY_SUCCESS; } wholememory_error_code_t wholememory_tensor_get_subtensor( diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu new file mode 100644 index 0000000..ff901d8 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.cu @@ -0,0 +1,474 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include +#include +#include +#include + +#include + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "logger.hpp" +#include "wholememory/integer_utils.hpp" +#include "wholememory_ops/register.hpp" +#include "wholememory_ops/temp_memory_handle.hpp" +#include + +namespace wholememory_ops { + +template +__device__ __forceinline__ int dest_rank(IndexT entry_idx, + const size_t* embedding_entry_offsets, + int world_size) +{ + size_t total_entry_count = embedding_entry_offsets[world_size]; + size_t estimated_entry_per_rank = total_entry_count / world_size; + int estimated_rank = max(world_size - 1, int(entry_idx / estimated_entry_per_rank)); + if (embedding_entry_offsets[estimated_rank] > entry_idx) { + for (int i = estimated_rank - 1; i >= 0; i--) { + if (embedding_entry_offsets[i] <= entry_idx) { return i; } + } + } else { + for (int i = estimated_rank + 1; i <= world_size; i++) { + if (embedding_entry_offsets[i] > entry_idx) { return i - 1; } + } + } + return 0; +} + +template +__global__ void bucket_ids_for_hierarchy_kernel(const IndexT* indices, + size_t indice_count, + int64_t* dev_rank_id_count_ptr, + const size_t* embedding_entry_offsets, + int local_size, + int world_size, + int nbucket) +{ + extern __shared__ char shared_mem[]; + size_t* embedding_entry_offsets_shared = reinterpret_cast(shared_mem); + int* rank_count_shared = reinterpret_cast(shared_mem + sizeof(size_t) * (world_size + 1)); + for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) { + rank_count_shared[idx] = 0; + } + for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) { + embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx]; + } + __syncthreads(); + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count; + idx += blockDim.x * gridDim.x) { + IndexT node_idx = indices[idx]; + if (node_idx < 0) continue; + int rank = dest_rank(node_idx, embedding_entry_offsets_shared, world_size); + int bucket = 0; + if (BUCKET_CROSS_OR_LOCAL == 0) + bucket = rank % local_size; + else + bucket = 
rank / local_size; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + atomicAdd_block(&rank_count_shared[bucket], 1); +#else + atomicAdd(&rank_count_shared[bucket], 1); +#endif + } + __syncthreads(); + for (int idx = threadIdx.x; idx < nbucket; idx += blockDim.x) { + atomicAdd(reinterpret_cast(dev_rank_id_count_ptr) + idx, + static_cast(rank_count_shared[idx])); + } +} + +template +void bucket_ids_for_hierarchy_temp_func(const void* indices, + wholememory_array_description_t indice_desc, + int64_t* dev_rank_id_count_ptr, + const size_t* dev_embedding_entry_offsets, + int local_size, + int cross_size, + int bucket_cross_or_local, + int sm_count, + cudaStream_t stream) +{ + static constexpr int BLOCK_SIZE = 128; + int block_count = wholememory::div_rounding_up_unsafe(indice_desc.size, BLOCK_SIZE); + block_count = std::min(block_count, sm_count * 4); + const IndexT* indices_ptr = static_cast(indices); + indices_ptr += indice_desc.storage_offset; + int world_size = local_size * cross_size; + if (bucket_cross_or_local == 0) { + int bucket_size = local_size; + cudaMemsetAsync(dev_rank_id_count_ptr, 0, sizeof(int64_t) * bucket_size, stream); + bucket_ids_for_hierarchy_kernel + <<>>(indices_ptr, + indice_desc.size, + dev_rank_id_count_ptr, + dev_embedding_entry_offsets, + local_size, + world_size, + bucket_size); + } else { + int bucket_size = cross_size; + cudaMemsetAsync(dev_rank_id_count_ptr, 0, sizeof(int64_t) * bucket_size, stream); + bucket_ids_for_hierarchy_kernel + <<>>(indices_ptr, + indice_desc.size, + dev_rank_id_count_ptr, + dev_embedding_entry_offsets, + local_size, + world_size, + bucket_size); + } +} + +REGISTER_DISPATCH_ONE_TYPE(BucketIdsForHierarchy, bucket_ids_for_hierarchy_temp_func, SINT3264) + +template +__global__ void reorder_ids_for_hierarchy_kernel(const IndexT* indices, + size_t indice_count, + IndexT* dev_bucket_indices, + IndexT* dev_indice_map, + const int64_t* dev_rank_id_offset_ptr, + const size_t* embedding_entry_offsets, + int local_size, + int world_size, + int nbucket, + int64_t* dev_bucket_atomic_add_ptr) +{ + constexpr size_t shared_mem_size = 24576; + __shared__ char shared_mem[shared_mem_size]; + size_t* embedding_entry_offsets_shared = reinterpret_cast(shared_mem); + char* shared_mem_for_bucket = shared_mem + sizeof(size_t) * (world_size + 1); + int* block_bucket_count_shared = reinterpret_cast(shared_mem_for_bucket); + int* block_bucket_atomic_add_shared = reinterpret_cast(shared_mem_for_bucket) + nbucket; + IndexT* block_bucket_offset_shared = + reinterpret_cast(shared_mem_for_bucket + 2 * sizeof(int) * nbucket); + IndexT* global_bucket_offset_shared = block_bucket_offset_shared + nbucket; + size_t buffer_size = (shared_mem_size - sizeof(size_t) * (world_size + 1) - + nbucket * 2 * (sizeof(IndexT) + sizeof(int))) / + sizeof(IndexT) / 2; + buffer_size = (buffer_size / blockDim.x) * blockDim.x; + assert(buffer_size > 0); + + for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) { + embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx]; + } + __syncthreads(); + IndexT* buffer_load = global_bucket_offset_shared + nbucket; + IndexT* buffer_store = buffer_load + buffer_size; + + int warp_idx = threadIdx.x / warpSize; + int lane_idx = threadIdx.x % warpSize; + int nwarp = blockDim.x / warpSize; + for (IndexT load_offset = buffer_size * blockIdx.x; load_offset < indice_count; + load_offset += gridDim.x * buffer_size) { + for (int i = threadIdx.x; i < nbucket; i += blockDim.x) { + block_bucket_count_shared[i] = 0; + 
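+ // Per tile of indices: block_bucket_count_shared[] accumulates how many
+ // indices of this tile fall into each destination bucket, while
+ // block_bucket_atomic_add_shared[] is the running per-bucket cursor used
+ // later when the indices are scattered, in bucket order, into the
+ // buffer_store staging area before being flushed to dev_bucket_indices.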
block_bucket_atomic_add_shared[i] = 0; + } + __syncthreads(); + for (IndexT i = threadIdx.x; i < buffer_size; i += blockDim.x) { + IndexT load_idx = i + load_offset; + if (load_idx >= indice_count) break; + IndexT indice = indices[load_idx]; + + buffer_load[i] = indice; + int bucket_idx = 0; + int rank = dest_rank(indice, embedding_entry_offsets_shared, world_size); + if (BUCKET_CROSS_OR_LOCAL == 0) { + bucket_idx = rank % local_size; + } else { + bucket_idx = rank / local_size; + } +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + atomicAdd_block(&block_bucket_count_shared[bucket_idx], 1); +#else + atomicAdd(&block_bucket_count_shared[bucket_idx], 1); +#endif + } + __syncthreads(); + if (threadIdx.x == blockDim.x - 1) { + IndexT bucket_offset_tmp = 0; + for (int bi = 0; bi < nbucket; bi++) { + block_bucket_offset_shared[bi] = bucket_offset_tmp; + bucket_offset_tmp += block_bucket_count_shared[bi]; + } + } + if (threadIdx.x < nbucket) { + int bucket_idx = threadIdx.x; + global_bucket_offset_shared[bucket_idx] = + atomicAdd(reinterpret_cast(dev_bucket_atomic_add_ptr) + bucket_idx, + block_bucket_count_shared[bucket_idx]); + } + __syncthreads(); + for (IndexT i = threadIdx.x; i < buffer_size; i += blockDim.x) { + IndexT indice = buffer_load[i]; + IndexT load_idx = i + load_offset; + if (load_idx >= indice_count) break; + int bucket_idx = 0; + int rank = dest_rank(indice, embedding_entry_offsets_shared, world_size); + if (BUCKET_CROSS_OR_LOCAL == 0) { + bucket_idx = rank % local_size; + } else { + bucket_idx = rank / local_size; + } +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + int block_bucket_inc = atomicAdd_block(&block_bucket_atomic_add_shared[bucket_idx], 1); +#else + int block_bucket_inc = atomicAdd(&block_bucket_atomic_add_shared[bucket_idx], 1); +#endif + buffer_store[block_bucket_offset_shared[bucket_idx] + block_bucket_inc] = indice; + dev_indice_map[load_idx] = dev_rank_id_offset_ptr[bucket_idx] + + global_bucket_offset_shared[bucket_idx] + block_bucket_inc; + } + __syncthreads(); + for (int bucket_idx = warp_idx; bucket_idx < nbucket; bucket_idx += nwarp) { + int bucket_length = block_bucket_count_shared[bucket_idx]; + IndexT global_bucket_offset = + dev_rank_id_offset_ptr[bucket_idx] + global_bucket_offset_shared[bucket_idx]; + for (int idx = lane_idx; idx < bucket_length; idx += warpSize) { + dev_bucket_indices[global_bucket_offset + idx] = + buffer_store[block_bucket_offset_shared[bucket_idx] + idx]; + } + } + __syncthreads(); + } +} + +template +void reorder_ids_for_hierarchy_temp_func(const void* indices, + wholememory_array_description_t indice_desc, + void* dev_bucket_indices, + void* dev_indice_map, + const int64_t* dev_rank_id_count_ptr, + const size_t* dev_embedding_entry_offsets, + int local_size, + int cross_size, + int bucket_cross_or_local, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + int sm_count, + cudaStream_t stream) +{ + WHOLEMEMORY_CHECK(indice_desc.storage_offset == 0); + WHOLEMEMORY_CHECK(indice_desc.dtype == WHOLEMEMORY_DT_INT || + indice_desc.dtype == WHOLEMEMORY_DT_INT64); + int nbucket = 0; + if (bucket_cross_or_local == 0) { + nbucket = local_size; + } else { + nbucket = cross_size; + } + int world_size = local_size * cross_size; + temp_memory_handle dev_rank_id_offset_handle(p_env_fns); + int64_t* dev_rank_id_offset_ptr = + static_cast(dev_rank_id_offset_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64)); + void* cub_temp_storage = NULL; + size_t temp_storage_bytes = 0; + 
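+ // The CUB scan is issued twice on purpose: the first ExclusiveSum call runs
+ // with a null temporary buffer and only reports the required
+ // temp_storage_bytes; after allocating that buffer, the second call performs
+ // the actual exclusive prefix sum, turning the per-bucket index counts into
+ // per-bucket start offsets (dev_rank_id_offset_ptr) in the reordered output.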
cub::DeviceScan::ExclusiveSum(cub_temp_storage, + temp_storage_bytes, + dev_rank_id_count_ptr, + dev_rank_id_offset_ptr, + nbucket, + stream); + cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes); + cub::DeviceScan::ExclusiveSum(cub_temp_storage, + temp_storage_bytes, + dev_rank_id_count_ptr, + dev_rank_id_offset_ptr, + nbucket, + stream); + p_thrust_allocator->deallocate(reinterpret_cast(cub_temp_storage), temp_storage_bytes); + + temp_memory_handle dev_bucket_atomic_add_handle(p_env_fns); + int64_t* dev_bucket_atomic_add_ptr = static_cast( + dev_bucket_atomic_add_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64)); + cudaMemsetAsync((void*)dev_bucket_atomic_add_ptr, 0, sizeof(int64_t) * nbucket, stream); + static constexpr int BLOCK_SIZE = 128; + int block_count = wholememory::div_rounding_up_unsafe(indice_desc.size, BLOCK_SIZE); + block_count = std::min(block_count, sm_count * 4); + + if (bucket_cross_or_local == 0) + reorder_ids_for_hierarchy_kernel + <<>>(static_cast(indices), + indice_desc.size, + static_cast(dev_bucket_indices), + static_cast(dev_indice_map), + dev_rank_id_offset_ptr, + dev_embedding_entry_offsets, + local_size, + world_size, + nbucket, + dev_bucket_atomic_add_ptr); + else + reorder_ids_for_hierarchy_kernel + <<>>(static_cast(indices), + indice_desc.size, + static_cast(dev_bucket_indices), + static_cast(dev_indice_map), + dev_rank_id_offset_ptr, + dev_embedding_entry_offsets, + local_size, + world_size, + nbucket, + dev_bucket_atomic_add_ptr); + ; +} + +REGISTER_DISPATCH_ONE_TYPE(ReorderIdsForHierarchy, reorder_ids_for_hierarchy_temp_func, SINT3264) + +wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + void* dev_bucket_indices, + void* dev_indice_map, + int64_t* host_bucket_id_count, + size_t* dev_embedding_entry_offsets, + wholememory_comm_t wm_global_comm, + wholememory_comm_t wm_local_comm, + int bucket_cross_or_local, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } + int world_size, local_size, cross_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_global_comm)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); + WHOLEMEMORY_CHECK_NOTHROW(world_size % local_size == 0); + cross_size = world_size / local_size; + + WHOLEMEMORY_EXPECTS_NOTHROW(bucket_cross_or_local == 0 || bucket_cross_or_local == 1, + "param bucket_cross_or_local must be 0 or 1, 0: cross, 1: local"); + int nbucket = 0; + if (bucket_cross_or_local == 0) { // bucket by cross id + nbucket = local_size; + } else { // bucket by local id + nbucket = cross_size; + } + constexpr int K_DEFAULT_SM_COUNT = 108; + auto prop = get_device_prop(-1); + int sm_count = (prop != nullptr) ? 
prop->multiProcessorCount : K_DEFAULT_SM_COUNT; + temp_memory_handle dev_rank_id_count_handle(p_env_fns); + int64_t* dev_rank_id_count_ptr = + static_cast(dev_rank_id_count_handle.device_malloc(nbucket, WHOLEMEMORY_DT_INT64)); + cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * nbucket, stream); + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + BucketIdsForHierarchy, + indices, + indice_desc, + dev_rank_id_count_ptr, + dev_embedding_entry_offsets, + local_size, + cross_size, + bucket_cross_or_local, + sm_count, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("bucket_ids_for_hierarchy_func CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } + WM_CUDA_CHECK_NO_THROW(cudaMemcpyAsync(host_bucket_id_count, + dev_rank_id_count_ptr, + nbucket * sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + ReorderIdsForHierarchy, + indices, + indice_desc, + dev_bucket_indices, + dev_indice_map, + dev_rank_id_count_ptr, + dev_embedding_entry_offsets, + local_size, + cross_size, + bucket_cross_or_local, + p_thrust_allocator, + p_env_fns, + sm_count, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("reorder_ids_for_hierarchy CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } catch (wholememory::logic_error& wle) { + WHOLEMEMORY_ERROR("reorder_ids_for_hierarchy LOGIC Error %s\n", wle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_UNKNOW_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t bucket_local_ids_func(void* indices, + wholememory_array_description_t indice_desc, + int64_t* host_bucket_id_count, + size_t* dev_embedding_entry_offsets, + wholememory_comm_t wm_local_comm, + wholememory_comm_t wm_cross_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } + int cross_size, local_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); + + constexpr int K_DEFAULT_SM_COUNT = 108; + auto prop = get_device_prop(-1); + int sm_count = (prop != nullptr) ? 
prop->multiProcessorCount : K_DEFAULT_SM_COUNT; + temp_memory_handle dev_rank_id_count_handle(p_env_fns); + int64_t* dev_rank_id_count_ptr = + static_cast(dev_rank_id_count_handle.device_malloc(cross_size, WHOLEMEMORY_DT_INT64)); + cudaMemsetAsync((void*)dev_rank_id_count_ptr, 0, sizeof(int64_t) * cross_size, stream); + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + BucketIdsForHierarchy, + indices, + indice_desc, + dev_rank_id_count_ptr, + dev_embedding_entry_offsets, + local_size, + cross_size, + 1, + sm_count, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("bucket_ids_for_hierarchy CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } + WM_CUDA_CHECK_NO_THROW(cudaMemcpyAsync(host_bucket_id_count, + dev_rank_id_count_ptr, + cross_size * sizeof(int64_t), + cudaMemcpyDeviceToHost, + stream)); + WM_CUDA_CHECK(cudaGetLastError()); + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h new file mode 100644 index 0000000..60665c9 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/bucket_ids_for_hierarchy_func.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include + +#include "wholememory_ops/temp_memory_handle.hpp" + +namespace wholememory_ops { + +wholememory_error_code_t bucket_and_reorder_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + void* dev_bucket_indices, + void* dev_indice_map, + int64_t* host_bucket_id_count, + size_t* dev_embedding_entry_offsets, + wholememory_comm_t wm_global_comm, + wholememory_comm_t wm_local_comm, + int bucket_cross_or_local, // 0: cross, 1: local + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +wholememory_error_code_t bucket_local_ids_func(void* indices, + wholememory_array_description_t indice_desc, + int64_t* host_bucket_id_count, + size_t* dev_embedding_entry_offsets, + wholememory_comm_t wm_local_comm, + wholememory_comm_t wm_cross_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_func.cu index c338065..6bd6b6c 100644 --- a/cpp/src/wholememory_ops/functions/bucket_ids_func.cu +++ b/cpp/src/wholememory_ops/functions/bucket_ids_func.cu @@ -28,23 +28,50 @@ namespace wholememory_ops { +template +__device__ __forceinline__ int dest_rank(IndexT entry_idx, + size_t total_entry_count, + const size_t* embedding_entry_offsets, + int world_size) +{ + size_t estimated_entry_per_rank = total_entry_count / world_size; + int estimated_rank = max(world_size - 1, int(entry_idx / estimated_entry_per_rank)); + if (embedding_entry_offsets[estimated_rank] > entry_idx) { + for (int i = estimated_rank - 1; i >= 0; i--) { + if (embedding_entry_offsets[i] <= entry_idx) { return i; } + } + } else { + for (int i = estimated_rank + 1; i <= world_size; i++) { + if (embedding_entry_offsets[i] > entry_idx) { return i - 1; } + } + } + return 0; +} + template __global__ void bucket_ids_for_ranks_kernel(const IndexT* indices, size_t indice_count, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, int world_size) { - extern __shared__ int rank_count_shared[]; + extern __shared__ char shmem[]; + int* rank_count_shared = reinterpret_cast(shmem); for (int idx = threadIdx.x; idx < world_size; idx += blockDim.x) { rank_count_shared[idx] = 0; } + size_t* embedding_entry_offsets_shared = + reinterpret_cast(shmem + sizeof(size_t) * world_size); + for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) { + embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx]; + } __syncthreads(); + size_t total_entry_count = embedding_entry_offsets_shared[world_size]; for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count; idx += blockDim.x * gridDim.x) { IndexT node_idx = indices[idx]; if (node_idx < 0) continue; - int rank = node_idx / embedding_entry_count_per_rank; + int rank = dest_rank(node_idx, total_entry_count, embedding_entry_offsets_shared, world_size); assert(rank >= 0 && rank < world_size); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 atomicAdd_block(&rank_count_shared[rank], 1); @@ -63,7 +90,7 @@ template void bucket_ids_for_ranks_temp_fn(void* indices, wholememory_array_description_t indice_desc, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, int world_size, int sm_count, cudaStream_t stream) @@ -73,12 +100,11 @@ void 
bucket_ids_for_ranks_temp_fn(void* indices, block_count = std::min(block_count, sm_count * 4); IndexT* indices_ptr = static_cast(indices); indices_ptr += indice_desc.storage_offset; - bucket_ids_for_ranks_kernel<<>>( - indices_ptr, - indice_desc.size, - dev_rank_id_count_ptr, - embedding_entry_count_per_rank, - world_size); + bucket_ids_for_ranks_kernel<<>>( + indices_ptr, indice_desc.size, dev_rank_id_count_ptr, embedding_entry_offsets, world_size); } REGISTER_DISPATCH_ONE_TYPE(BucketIdForRanks, bucket_ids_for_ranks_temp_fn, SINT3264) @@ -86,7 +112,7 @@ REGISTER_DISPATCH_ONE_TYPE(BucketIdForRanks, bucket_ids_for_ranks_temp_fn, SINT3 wholememory_error_code_t bucket_ids_for_ranks(void* indices, wholememory_array_description_t indice_desc, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, int world_size, cudaDeviceProp* prop, cudaStream_t stream) @@ -101,7 +127,7 @@ wholememory_error_code_t bucket_ids_for_ranks(void* indices, indices, indice_desc, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + embedding_entry_offsets, world_size, sm_count, stream); diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_func.h b/cpp/src/wholememory_ops/functions/bucket_ids_func.h index d38a77b..a8443e3 100644 --- a/cpp/src/wholememory_ops/functions/bucket_ids_func.h +++ b/cpp/src/wholememory_ops/functions/bucket_ids_func.h @@ -23,7 +23,7 @@ namespace wholememory_ops { wholememory_error_code_t bucket_ids_for_ranks(void* indices, wholememory_array_description_t indice_desc, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, int world_size, cudaDeviceProp* prop, cudaStream_t stream); diff --git a/cpp/src/wholememory_ops/functions/embedding_cache_func.cu b/cpp/src/wholememory_ops/functions/embedding_cache_func.cu index cbe5060..fd9e846 100644 --- a/cpp/src/wholememory_ops/functions/embedding_cache_func.cu +++ b/cpp/src/wholememory_ops/functions/embedding_cache_func.cu @@ -348,10 +348,9 @@ wholememory_error_code_t update_cache_direct_same_comm( auto* raw_embedding_desc = wholememory_tensor_get_tensor_description(wholememory_tensor_get_root(wm_raw_memory_embedding)); - size_t embedding_entry_count_per_rank = 0; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_determine_entry_partition_plan( - &embedding_entry_count_per_rank, raw_embedding_desc->sizes[0], world_size)); - + size_t embedding_entry_start = 0; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_local_entry_start(&embedding_entry_start, wm_raw_memory_embedding)); int indices_num_run = 0; temp_memory_handle unique_indice_handle(p_env_fns), unique_count_handle(p_env_fns); try { @@ -380,7 +379,7 @@ wholememory_error_code_t update_cache_direct_same_comm( &unique_cache_set_start_handle, &unique_cache_set_count_handle, &cache_set_num_run, - world_rank * embedding_entry_count_per_rank, + embedding_entry_start, cache_set_coverage, &thrust_allocator, p_env_fns, @@ -414,7 +413,7 @@ wholememory_error_code_t update_cache_direct_same_comm( static_cast(embedding_local_pointer), embedding_dim_in_int4, cache_set_num_run, - world_rank * embedding_entry_count_per_rank, + embedding_entry_start, cache_set_coverage, stream); @@ -525,7 +524,7 @@ wholememory_error_code_t update_cache_different_comm( wholememory_array_description_t indice_desc, wholememory_tensor_t wm_raw_memory_embedding, wholememory_comm_t cache_comm, - size_t embedding_entry_count_per_cache_rank, + size_t* embedding_entry_offsets, const wholememory::embedding_cache_local_data* 
cache_local_data, int cache_set_coverage, wholememory_env_func_t* p_env_fns, @@ -554,7 +553,6 @@ wholememory_error_code_t update_cache_different_comm( WHOLEMEMORY_ERROR("SortUniqueLocalIndicesTempFunc failed."); return WHOLEMEMORY_LOGIC_ERROR; } - temp_memory_handle unique_cache_set_lid_handle(p_env_fns), unique_cache_set_start_handle(p_env_fns), unique_cache_set_count_handle(p_env_fns); int cache_set_num_run; @@ -566,7 +564,7 @@ wholememory_error_code_t update_cache_different_comm( &unique_cache_set_start_handle, &unique_cache_set_count_handle, &cache_set_num_run, - cache_world_rank * embedding_entry_count_per_cache_rank, + embedding_entry_offsets[cache_world_rank], cache_set_coverage, &thrust_allocator, p_env_fns, @@ -595,7 +593,7 @@ wholememory_error_code_t update_cache_different_comm( static_cast(wholememory_tensor_get_data_pointer(cache_local_data->access_count_)), local_write_cache_index_ptr, global_load_gid_ptr, - cache_world_rank * embedding_entry_count_per_cache_rank, + embedding_entry_offsets[cache_world_rank], cache_set_coverage, cache_set_num_run, stream); @@ -697,11 +695,12 @@ wholememory_error_code_t writeback_cache_direct_same_comm( auto* raw_embedding_desc = wholememory_tensor_get_tensor_description(wholememory_tensor_get_root(wm_raw_memory_embedding)); - size_t embedding_entry_count_per_rank = 0; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_determine_entry_partition_plan( - &embedding_entry_count_per_rank, raw_embedding_desc->sizes[0], world_size)); - WHOLEMEMORY_CHECK_NOTHROW(embedding_entry_count_per_rank % cache_set_coverage == 0); + size_t embedding_entry_count = 0; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_local_entry_count(&embedding_entry_count, wm_raw_memory_embedding)); + WHOLEMEMORY_CHECK_NOTHROW(embedding_entry_count % cache_set_coverage == 0); + wholememory_tensor_t raw_local_tensor; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_tensor_map_local_tensor(wm_raw_memory_embedding, &raw_local_tensor)); diff --git a/cpp/src/wholememory_ops/functions/embedding_cache_func.h b/cpp/src/wholememory_ops/functions/embedding_cache_func.h index 1f0d5ac..edf71e7 100644 --- a/cpp/src/wholememory_ops/functions/embedding_cache_func.h +++ b/cpp/src/wholememory_ops/functions/embedding_cache_func.h @@ -55,7 +55,7 @@ wholememory_error_code_t update_cache_direct_same_comm( * @param wm_raw_memory_embedding : the WholeMemory Tensor that is to be cached which stores all * embeddings. 
* @param cache_comm : communicator of cache - * @param embedding_entry_count_per_cache_rank : embedding entries covered by each cache rank + * @param embedding_entry_offsets : embedding entry offset of each cache rank * @param cache_local_data : embedding_cache_local_data of wm_raw_memory_embedding * @param cache_set_coverage : cache set coverage * @param p_env_fns : env fns @@ -67,7 +67,7 @@ wholememory_error_code_t update_cache_different_comm( wholememory_array_description_t indice_desc, wholememory_tensor_t wm_raw_memory_embedding, wholememory_comm_t cache_comm, - size_t embedding_entry_count_per_cache_rank, + size_t* embedding_entry_offsets, const wholememory::embedding_cache_local_data* cache_local_data, int cache_set_coverage, wholememory_env_func_t* p_env_fns, diff --git a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu index 137b104..1739488 100644 --- a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu +++ b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu @@ -161,7 +161,7 @@ wholememory_error_code_t bucket_and_exchange_ids_func( int64_t* host_rank_id_count_ptr, temp_memory_handle* dev_recv_indices_buffer_handle, int64_t* dev_raw_indice_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_comm_t wm_comm, wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, @@ -178,7 +178,7 @@ wholememory_error_code_t bucket_and_exchange_ids_func( WHOLEMEMORY_RETURN_ON_FAIL(bucket_ids_for_ranks(indices, indice_desc, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + embedding_entry_offsets, world_size, get_device_prop(-1), stream)); diff --git a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.h b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.h index 15b7cf4..69a2d92 100644 --- a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.h +++ b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.h @@ -34,7 +34,7 @@ namespace wholememory_ops { * @param dev_recv_indices_buffer_handle : temp_memory_handle to create buffer for received indices. 
* @param dev_raw_indice_ptr : pointer to allocated int64_t array to storage raw indices mapping of * sort - * @param embedding_entry_count_per_rank : entry count of embedding count per rank + * @param embedding_entry_offsets : embedding entry offsets * @param wm_comm : WholeMemory Communicator * @param p_thrust_allocator : thrust allocator * @param p_env_fns : EnvFns @@ -48,7 +48,7 @@ wholememory_error_code_t bucket_and_exchange_ids_func( int64_t* host_rank_id_count_ptr, temp_memory_handle* dev_recv_indices_buffer_handle, int64_t* dev_raw_indice_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_comm_t wm_comm, wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh index a4979f7..140b257 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh @@ -260,23 +260,25 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, OutputT* output, wholememory_matrix_description_t output_desc) { - auto block = cooperative_groups::this_thread_block(); - auto mywarp = cooperative_groups::tiled_partition<32>(block); - __shared__ char shm_in_char[16384]; - OutputT* all_sh = reinterpret_cast(shm_in_char); - OutputT* my_shared; + auto block = cooperative_groups::this_thread_block(); + auto mywarp = cooperative_groups::tiled_partition<32>(block); + constexpr size_t shm_max_size = 16384; + __shared__ char shm_in_char[shm_max_size]; int warp_id = (threadIdx.x + blockIdx.x * blockDim.x) / 32; int lane_id = threadIdx.x % 32; int embedding_size = embedding_desc.sizes[1]; int64_t embedding_stride = embedding_desc.stride; int64_t output_stride = output_desc.stride; - int shm_size = 16384 / sizeof(OutputT); + wholememory::device_reference embedding_dev_ref(embedding_gref); typed_data_vector embeddings; typed_data_vector outputs; + int shm_size = shm_max_size / sizeof(OutputT); + OutputT* all_sh = reinterpret_cast(shm_in_char); + OutputT* my_shared; bool use_shm = true; if (shm_size / (blockDim.x / 32) < output_desc.sizes[1]) { // use_shm = false; @@ -342,6 +344,7 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, int sub_warp_num = subwarp.meta_group_size() * gridDim.x; int lane_id_in_sub_warp = subwarp.thread_rank(); + wholememory::device_reference embedding_dev_ref(embedding_gref); int embedding_size = embedding_desc.sizes[1]; @@ -358,11 +361,10 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, if (embedding_table_idx < 0) continue; int64_t embedding_offset = embedding_desc.storage_offset + embedding_table_idx * embedding_stride; - + EmbeddingT* emb_ptr = &embedding_dev_ref[embedding_offset]; for (int emb_idx = lane_id_in_sub_warp * ALIGNMENT; emb_idx < embedding_size; emb_idx += ALIGNMENT * SUB_WARP_SIZE) { - mov_data(&embeddings, - &embedding_dev_ref[embedding_offset + emb_idx]); + mov_data(&embeddings, &emb_ptr[emb_idx]); #pragma unroll for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) { typed_data_vector_at(outputs, sub_idx) = @@ -522,11 +524,10 @@ __global__ void scatter_func_kernel(const InputT* input, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc) { - auto block = cooperative_groups::this_thread_block(); - auto mywarp = cooperative_groups::tiled_partition<32>(block); - __shared__ char shm_in_char[24576]; - InputT* all_sh = 
reinterpret_cast(shm_in_char); - InputT* my_shared; + auto block = cooperative_groups::this_thread_block(); + auto mywarp = cooperative_groups::tiled_partition<32>(block); + constexpr size_t shm_max_size = 24576; + __shared__ char shm_in_char[shm_max_size]; int warp_id = (threadIdx.x + blockIdx.x * blockDim.x) / 32; int lane_id = threadIdx.x % 32; @@ -535,11 +536,13 @@ __global__ void scatter_func_kernel(const InputT* input, int64_t input_stride = input_desc.stride; int async_copy_align = sizeof(InputT) > 4 ? 1 : 4 / sizeof(InputT); - int shm_size = 24576 / sizeof(InputT); + wholememory::device_reference embedding_dev_ref(embedding_gref); + int shm_size = shm_max_size / sizeof(InputT); + InputT* all_sh = reinterpret_cast(shm_in_char); + InputT* my_shared; int batch_size = (shm_size / (blockDim.x / 32) - async_copy_align) / input_stride; // indices batch size in lines - wholememory::device_reference embedding_dev_ref(embedding_gref); typed_data_vector embeddings; typed_data_vector inputs; diff --git a/cpp/src/wholememory_ops/functions/map_indices_func.cu b/cpp/src/wholememory_ops/functions/map_indices_func.cu index 1a14181..e07ac40 100644 --- a/cpp/src/wholememory_ops/functions/map_indices_func.cu +++ b/cpp/src/wholememory_ops/functions/map_indices_func.cu @@ -28,7 +28,7 @@ __global__ void storage_idx2wm_emb_idx_kernel(IndexT* indice, IndexT* mapped_indice, int64_t indice_size, int world_size, - int64_t entry_per_rank, + int64_t entry_start, int round_robin_size) { int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -38,7 +38,7 @@ __global__ void storage_idx2wm_emb_idx_kernel(IndexT* indice, IndexT table_off = target_idx % round_robin_size; int rank_id = table_idx % world_size; int rank_table_idx = table_idx / world_size; - IndexT wmidx = entry_per_rank * rank_id + round_robin_size * rank_table_idx + table_off; + IndexT wmidx = entry_start + round_robin_size * rank_table_idx + table_off; mapped_indice[i] = wmidx; } return; @@ -49,7 +49,7 @@ void storage_idx2wm_emb_idx_temp_fn(void* indice_ptr, void* mapped_indice_ptr, int64_t indice_size, int world_size, - int64_t entry_per_rank, + int64_t entry_start, int round_robin_size, cudaStream_t stream) { @@ -59,7 +59,7 @@ void storage_idx2wm_emb_idx_temp_fn(void* indice_ptr, IndexT* indice = static_cast(indice_ptr); IndexT* mapped_indice = static_cast(mapped_indice_ptr); storage_idx2wm_emb_idx_kernel<<>>( - indice, mapped_indice, indice_size, world_size, entry_per_rank, round_robin_size); + indice, mapped_indice, indice_size, world_size, entry_start, round_robin_size); WM_CUDA_CHECK(cudaStreamSynchronize(stream)); return; } @@ -85,14 +85,16 @@ wholememory_error_code_t storage_index2wm_embedding_index(wholememory_tensor_t i WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, handle)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_comm)); - int64_t entry_per_rank = wholememory_tensor_get_entry_per_partition(allocated_embedding); + size_t entry_start = 0; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_local_entry_start(&entry_start, allocated_embedding)); DISPATCH_ONE_TYPE(indice_desc->dtype, storageidx2wmembidx, indice_ptr, mapped_indice_ptr, indice_size, world_size, - entry_per_rank, + entry_start, round_robin_size, (cudaStream_t)stream_int); WM_CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh index bfbdb63..1c0cbb8 100644 --- 
a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh
+++ b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh
@@ -27,41 +27,107 @@ class nvshmem_device_reference {
   __device__ __forceinline__ explicit nvshmem_device_reference(
     const wholememory_nvshmem_ref_t& nvshmem_ref)
     : pointer_(static_cast<DataTypeT*>(nvshmem_ref.pointer)),
-      typed_stride_(nvshmem_ref.stride / sizeof(DataTypeT))
+      typed_stride_(nvshmem_ref.stride / sizeof(DataTypeT)),
+      rank_memory_offsets_(nvshmem_ref.rank_memory_offsets),
+      world_size_(nvshmem_ref.world_size),
+      same_chunk_(nvshmem_ref.same_chunk)
   {
     assert(nvshmem_ref.stride % sizeof(DataTypeT) == 0);
+    if (!same_chunk_) {
+      estimated_stride_ = rank_memory_offsets_[world_size_] / world_size_;
+      cache_rank_       = 0;
+      cache_offset_     = 0;
+      cache_size_       = rank_memory_offsets_[1] - rank_memory_offsets_[0];
+    }
   }
 
   __device__ nvshmem_device_reference() = delete;
 
   __device__ __forceinline__ DataTypeT load(size_t index)
   {
-    size_t rank = index / typed_stride_;
-
-    return nvshmem_get(pointer_ + index - rank * typed_stride_, rank);
+    size_t rank = dest_rank(index);
+    if (same_chunk_)
+      return nvshmem_get(pointer_ + index - rank * typed_stride_, rank);
+    else
+      return nvshmem_get(
+        pointer_ + index - rank_memory_offsets_[rank] / sizeof(DataTypeT), rank);
   }
 
   __device__ __forceinline__ void store(size_t index, DataTypeT val)
   {
-    size_t rank = index / typed_stride_;
-    return nvshmem_put(pointer_ + index - rank * typed_stride_, val, rank);
+    size_t rank = dest_rank(index);
+    if (same_chunk_)
+      return nvshmem_put(pointer_ + index - rank * typed_stride_, val, rank);
+    else
+      return nvshmem_put(
+        pointer_ + index - rank_memory_offsets_[rank] / sizeof(DataTypeT), val, rank);
   }
 
   __device__ __forceinline__ DataTypeT* symmetric_address(size_t index)
   {
-    size_t rank = index / typed_stride_;
-    return pointer_ + index - rank * typed_stride_;
+    size_t rank = dest_rank(index);
+    if (same_chunk_)
+      return pointer_ + index - rank * typed_stride_;
+    else
+      return pointer_ + index - rank_memory_offsets_[rank] / sizeof(DataTypeT);
+  }
+
+  __device__ __forceinline__ void mov_offsets_to_shmem(char* shmem)
+  {
+    if (same_chunk_) return;
+    size_t* shmem_offsets = reinterpret_cast<size_t*>(shmem);
+    for (int i = threadIdx.x; i <= world_size_; i += blockDim.x) {
+      shmem_offsets[i] = rank_memory_offsets_[i];
+    }
+    __syncthreads();
+    rank_memory_offsets_ = shmem_offsets;
   }
 
   __device__ __forceinline__ size_t dest_rank(size_t index)
   {
-    size_t rank = index / typed_stride_;
-    return rank;
+    if (same_chunk_) {
+      return index / typed_stride_;
+    } else {
+      size_t rank   = 0;
+      size_t offset = index * sizeof(DataTypeT);
+      if (offset >= cache_offset_ && offset < cache_offset_ + cache_size_) {
+        rank = cache_rank_;
+      } else {
+        int estimated_rank = min(world_size_ - 1, int(offset / estimated_stride_));
+        if (rank_memory_offsets_[estimated_rank] > offset) {
+          for (int i = estimated_rank - 1; i >= 0; i--) {
+            if (rank_memory_offsets_[i] <= offset) {
+              rank = i;
+              break;
+            }
+          }
+        } else {
+          for (int i = estimated_rank + 1; i <= world_size_; i++) {
+            if (rank_memory_offsets_[i] > offset) {
+              rank = i - 1;
+              break;
+            }
+          }
+        }
+        cache_rank_   = rank;
+        cache_offset_ = rank_memory_offsets_[rank];
+        cache_size_   = rank_memory_offsets_[rank + 1] - rank_memory_offsets_[rank];
+      }
+      return rank;
+    }
   }
 
  private:
   DataTypeT* pointer_;
   size_t typed_stride_;
+  size_t* rank_memory_offsets_;
+  int world_size_;
+
+  size_t estimated_stride_;
+  bool same_chunk_;
+  int cache_rank_;
+  size_t cache_offset_;
+  size_t cache_size_;
 };
 
 }  // namespace
wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu index 0275ad5..7ce32b2 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu @@ -31,24 +31,23 @@ void nvshmem_gather_floating_int32_temp_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - nvshmem_gather_temp_get_mem_sort_idx_func( - wm_comm, - embeding_nvshmem_ptr, - embedding_desc, - indices, - indice_count, - output, - temp_output, - output_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - gather_sms); + nvshmem_gather_temp_get_mem_sort_idx_func(wm_comm, + embeding_nvshmem_ptr, + embedding_desc, + indices, + indice_count, + output, + temp_output, + output_desc, + embedding_entry_offsets, + p_env_fns, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncFloatingInt32, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_gather_floating_int32_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_gather_floating_int32_func( output, temp_output, output_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, gather_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu index a794522..6a5f42b 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu @@ -31,24 +31,23 @@ void nvshmem_gather_floating_int64_temp_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - nvshmem_gather_temp_get_mem_sort_idx_func( - wm_comm, - embeding_nvshmem_ptr, - embedding_desc, - indices, - indice_count, - output, - temp_output, - output_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - gather_sms); + nvshmem_gather_temp_get_mem_sort_idx_func(wm_comm, + embeding_nvshmem_ptr, + embedding_desc, + indices, + indice_count, + output, + temp_output, + output_desc, + embedding_entry_offsets, + p_env_fns, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncFloatingInt64, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_gather_floating_int64_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_gather_floating_int64_func( output, temp_output, output_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, 
stream, gather_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu index 0979627..65b9c59 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu @@ -31,24 +31,23 @@ void nvshmem_gather_integer_int32_temp_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - nvshmem_gather_temp_get_mem_sort_idx_func( - wm_comm, - embeding_nvshmem_ptr, - embedding_desc, - indices, - indice_count, - output, - temp_output, - output_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - gather_sms); + nvshmem_gather_temp_get_mem_sort_idx_func(wm_comm, + embeding_nvshmem_ptr, + embedding_desc, + indices, + indice_count, + output, + temp_output, + output_desc, + embedding_entry_offsets, + p_env_fns, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncIntegerInt32, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_gather_integer_int32_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_gather_integer_int32_func( output, temp_output, output_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, gather_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu index a2d49b1..9cfad1b 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu @@ -31,24 +31,23 @@ void nvshmem_gather_integer_int64_temp_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - nvshmem_gather_temp_get_mem_sort_idx_func( - wm_comm, - embeding_nvshmem_ptr, - embedding_desc, - indices, - indice_count, - output, - temp_output, - output_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - gather_sms); + nvshmem_gather_temp_get_mem_sort_idx_func(wm_comm, + embeding_nvshmem_ptr, + embedding_desc, + indices, + indice_count, + output, + temp_output, + output_desc, + embedding_entry_offsets, + p_env_fns, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncIntegerInt64, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_gather_integer_int64_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_gather_integer_int64_func( output, temp_output, output_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, 
gather_sms); @@ -113,7 +112,7 @@ __global__ void scatter_func_with_nvshmem_sort_idxs_kernel( const int max_blocks_for_local, const int intra_node_ranks, const int node_rank, - size_t embedding_entry_per_rank, + size_t* embedding_entry_offsets, const int threads_per_group); }; // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh index a0091c3..8dbee95 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh @@ -122,22 +122,23 @@ __global__ void gather_func_with_nvshmem_sort_idxs_kernel( const int max_blocks_for_local, const int intra_node_ranks, const int node_rank, - size_t embedding_entry_per_rank, + size_t* embedding_entry_offsets, EmbeddingT* __restrict__ temp_output, wholememory_matrix_description_t output_desc, const int threads_per_group) { - const int64_t local_index_lowerbound = node_rank * intra_node_ranks * embedding_entry_per_rank; + const int64_t local_index_lowerbound = embedding_entry_offsets[node_rank * intra_node_ranks]; const int64_t local_index_upperbound = - (node_rank + 1) * intra_node_ranks * embedding_entry_per_rank; + embedding_entry_offsets[(node_rank + 1) * intra_node_ranks]; const int64_t local_index_start = LowerBound(sorted_index, indice_count, local_index_lowerbound); const int64_t local_index_length = UpperBound( sorted_index + local_index_start, indice_count - local_index_start, local_index_upperbound - 1); - int embedding_size = embedding_desc.sizes[1]; int64_t embedding_stride = embedding_desc.stride; int64_t output_stride = output_desc.stride; + extern __shared__ char shmem[]; nvshmem_device_reference embedding_nvshmem_device_ref{embeding_nvshmem_ref}; + embedding_nvshmem_device_ref.mov_offsets_to_shmem(shmem); if (blockIdx.x >= max_blocks_for_local) { const int64_t thread_id = (blockIdx.x - max_blocks_for_local) * blockDim.x + threadIdx.x; for (int64_t row_id = thread_id; row_id < indice_count - local_index_length; @@ -313,7 +314,7 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -336,7 +337,6 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, wm_comm, &thrust_allocator, stream); - int intra_node_rank_num = wm_comm->intra_node_rank_num; int node_id = wm_comm->world_rank / wm_comm->intra_node_rank_num; @@ -357,7 +357,7 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, const int, const int, const int, - size_t, + size_t*, EmbeddingT*, wholememory_matrix_description_t, const int) = nullptr; @@ -416,19 +416,21 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, block_threshold = 1; if (num_blocks == 1) num_blocks = 2; } - - gather_nvshmem_kernel_fn<<>>(embeding_nvshmem_ptr, - embedding_desc, - sorted_index, - dev_raw_indice_ptr, - indice_count, - block_threshold, - intra_node_rank_num, - node_id, - embedding_entry_count_per_rank, - ret_data, - temp_output_desc, - num_threads_per_feature); + size_t shared_mem_size = + embeding_nvshmem_ptr.same_chunk ? 
0 : ((embeding_nvshmem_ptr.world_size + 1) * sizeof(size_t)); + gather_nvshmem_kernel_fn<<>>( + embeding_nvshmem_ptr, + embedding_desc, + sorted_index, + dev_raw_indice_ptr, + indice_count, + block_threshold, + intra_node_rank_num, + node_id, + embedding_entry_offsets, + ret_data, + temp_output_desc, + num_threads_per_feature); if (!use_ibgda_flag) { nvshmemx_quiet_on_stream(stream); // wait transfer } @@ -467,12 +469,12 @@ __global__ void scatter_func_with_nvshmem_sort_idxs_kernel( const int max_blocks_for_local, const int intra_node_ranks, const int node_rank, - size_t embedding_entry_per_rank, + size_t* embedding_entry_offsets, const int threads_per_group) { - const int64_t local_index_lowerbound = node_rank * intra_node_ranks * embedding_entry_per_rank; + const int64_t local_index_lowerbound = embedding_entry_offsets[node_rank * intra_node_ranks]; const int64_t local_index_upperbound = - (node_rank + 1) * intra_node_ranks * embedding_entry_per_rank; + embedding_entry_offsets[(node_rank + 1) * intra_node_ranks]; const int64_t local_index_start = LowerBound(sorted_index, indice_count, local_index_lowerbound); const int64_t local_index_length = UpperBound( sorted_index + local_index_start, indice_count - local_index_start, local_index_upperbound - 1); @@ -480,7 +482,9 @@ __global__ void scatter_func_with_nvshmem_sort_idxs_kernel( int embedding_size = embedding_desc.sizes[1]; int64_t embedding_stride = embedding_desc.stride; int64_t input_stride = temp_input_desc.stride; + extern __shared__ char shmem[]; nvshmem_device_reference embedding_nvshmem_device_ref{embeding_nvshmem_ref}; + embedding_nvshmem_device_ref.mov_offsets_to_shmem(shmem); if (blockIdx.x >= max_blocks_for_local) { const int64_t thread_id = (blockIdx.x - max_blocks_for_local) * blockDim.x + threadIdx.x; for (int64_t row_id = thread_id; row_id < indice_count - local_index_length; @@ -554,7 +558,7 @@ void nvshmem_scatter_temp_put_mem_sort_idx_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -620,7 +624,7 @@ void nvshmem_scatter_temp_put_mem_sort_idx_func(wholememory_comm_t wm_comm, const int, const int, const int, - size_t, + size_t*, const int) = nullptr; switch (alignment) { @@ -679,19 +683,21 @@ void nvshmem_scatter_temp_put_mem_sort_idx_func(wholememory_comm_t wm_comm, if (num_blocks == 1) num_blocks = 2; } - scatter_nvshmem_kernel_fn<<>>(temp_input_data, - temp_input_desc, - embeding_nvshmem_ptr, - embedding_desc, - sorted_index, - dev_raw_indice_ptr, - indice_count, - block_threshold, - intra_node_rank_num, - node_id, - embedding_entry_count_per_rank, - - num_threads_per_feature); + size_t shared_mem_size = + embeding_nvshmem_ptr.same_chunk ? 
0 : ((embeding_nvshmem_ptr.world_size + 1) * sizeof(size_t)); + scatter_nvshmem_kernel_fn<<>>( + temp_input_data, + temp_input_desc, + embeding_nvshmem_ptr, + embedding_desc, + sorted_index, + dev_raw_indice_ptr, + indice_count, + block_threshold, + intra_node_rank_num, + node_id, + embedding_entry_offsets, + num_threads_per_feature); if (!use_ibgda_flag) { nvshmemx_quiet_on_stream(stream); // wait transfer } diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu index a946e43..760df03 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu @@ -31,24 +31,23 @@ void nvshmem_scatter_floating_int32_temp_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) { - nvshmem_scatter_temp_put_mem_sort_idx_func( - wm_comm, - input, - temp_input, - input_desc, - indices, - indice_count, - embeding_nvshmem_ptr, - embedding_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - scatter_sms); + nvshmem_scatter_temp_put_mem_sort_idx_func(wm_comm, + input, + temp_input, + input_desc, + indices, + indice_count, + embeding_nvshmem_ptr, + embedding_desc, + embedding_entry_offsets, + p_env_fns, + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncFloatingInt32, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int32_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int32_func( indices_desc.size, embeding_nvshmem_ptr, embedding_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu index 8930898..5c22a39 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu @@ -31,24 +31,23 @@ void nvshmem_scatter_floating_int64_temp_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) { - nvshmem_scatter_temp_put_mem_sort_idx_func( - wm_comm, - input, - temp_input, - input_desc, - indices, - indice_count, - embeding_nvshmem_ptr, - embedding_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - scatter_sms); + nvshmem_scatter_temp_put_mem_sort_idx_func(wm_comm, + input, + temp_input, + input_desc, + indices, + indice_count, + embeding_nvshmem_ptr, + embedding_desc, + embedding_entry_offsets, + p_env_fns, + 
stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncFloatingInt64, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int64_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int64_func( indices_desc.size, embeding_nvshmem_ptr, embedding_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu index 5e93509..7532332 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu @@ -31,24 +31,23 @@ void nvshmem_scatter_integer_int32_temp_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) { - nvshmem_scatter_temp_put_mem_sort_idx_func( - wm_comm, - input, - temp_input, - input_desc, - indices, - indice_count, - embeding_nvshmem_ptr, - embedding_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - scatter_sms); + nvshmem_scatter_temp_put_mem_sort_idx_func(wm_comm, + input, + temp_input, + input_desc, + indices, + indice_count, + embeding_nvshmem_ptr, + embedding_desc, + embedding_entry_offsets, + p_env_fns, + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncIntegerInt32, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int32_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int32_func( indices_desc.size, embeding_nvshmem_ptr, embedding_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu index 7952e43..a17b49d 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu @@ -31,24 +31,23 @@ void nvshmem_scatter_integer_int64_temp_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) { - nvshmem_scatter_temp_put_mem_sort_idx_func( - wm_comm, - input, - temp_input, - input_desc, - indices, - indice_count, - embeding_nvshmem_ptr, - embedding_desc, - 
embedding_entry_count_per_rank, - p_env_fns, - stream, - scatter_sms); + nvshmem_scatter_temp_put_mem_sort_idx_func(wm_comm, + input, + temp_input, + input_desc, + indices, + indice_count, + embeding_nvshmem_ptr, + embedding_desc, + embedding_entry_offsets, + p_env_fns, + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncIntegerInt64, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func( indices_desc.size, embeding_nvshmem_ptr, embedding_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu new file mode 100644 index 0000000..caa9667 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.cu @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "sort_unique_ids_for_hierarchy_func.h" +#include "sort_unique_indices_func.h" + +#include +#include + +#include +#include +#include +#include + +#include + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "logger.hpp" +#include "wholememory/communicator.hpp" +#include "wholememory/integer_utils.hpp" +#include "wholememory_ops/register.hpp" +#include "wholememory_ops/temp_memory_handle.hpp" +#include + +namespace wholememory_ops { + +template +__global__ void SortUniqueIndiceMapKernel(IndexT* indice_map, + size_t indice_count, + const IndexT* sort_raw_indices, + const int* unique_count_ptr, + const IndexT* unique_offset_ptr, + size_t num_unique) +{ + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count; + idx += blockDim.x * gridDim.x) { + if (idx >= num_unique) break; + IndexT offset = unique_offset_ptr[idx]; + int count = unique_count_ptr[idx]; + for (IndexT i = offset; i < offset + count; i++) { + indice_map[sort_raw_indices[i]] = idx; + } + } +} + +template +void SortUniqueIndicesMapTempFunc(void* indice_map, + wholememory_array_description_t indice_desc, + const void* sort_raw_indices, + const int* unique_count_ptr, + size_t num_unique, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + static constexpr int BLOCK_SIZE = 128; + int block_count = wholememory::div_rounding_up_unsafe(num_unique, BLOCK_SIZE); + + temp_memory_handle dev_unique_offset_handle(p_env_fns); + IndexT* unique_offset_ptr = + static_cast(dev_unique_offset_handle.device_malloc(num_unique, indice_desc.dtype)); + IndexT* indice_map_ptr = static_cast(indice_map); + const IndexT* sort_raw_indices_ptr = static_cast(sort_raw_indices); + + void* cub_temp_storage = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceScan::ExclusiveSum( + cub_temp_storage, temp_storage_bytes, unique_count_ptr, unique_offset_ptr, num_unique, stream); + cub_temp_storage = p_thrust_allocator->allocate(temp_storage_bytes); + cub::DeviceScan::ExclusiveSum( + cub_temp_storage, temp_storage_bytes, unique_count_ptr, unique_offset_ptr, num_unique, stream); + SortUniqueIndiceMapKernel<<>>(indice_map_ptr, + indice_desc.size, + sort_raw_indices_ptr, + unique_count_ptr, + unique_offset_ptr, + num_unique); + p_thrust_allocator->deallocate(reinterpret_cast(cub_temp_storage), temp_storage_bytes); +} + +REGISTER_DISPATCH_ONE_TYPE(SortUniqueIndicesMapTempFunc, SortUniqueIndicesMapTempFunc, SINT3264) + +wholememory_error_code_t sort_unique_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + temp_memory_handle* output_indices_handle, + wholememory_array_description_t* output_indices_desc, + temp_memory_handle* dev_indice_map_handle, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + if (indice_desc.size == 0) { + *output_indices_desc = wholememory_create_array_desc(0, 0, indice_desc.dtype); + return WHOLEMEMORY_SUCCESS; + } + int num_runs = 0; + temp_memory_handle unique_count_handle(p_env_fns); + temp_memory_handle dev_sort_raw_indices_handle(p_env_fns); + void* dev_sort_raw_indices_ptr = + dev_sort_raw_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype); + sort_unique_indices_func(indices, + indice_desc, + dev_sort_raw_indices_ptr, + &num_runs, + output_indices_handle, + &unique_count_handle, + p_thrust_allocator, + p_env_fns, + stream); + *output_indices_desc = wholememory_create_array_desc(num_runs, 0, indice_desc.dtype); + void* dev_indice_map_ptr = 
+ dev_indice_map_handle->device_malloc(indice_desc.size, indice_desc.dtype); + WM_CUDA_CHECK(cudaGetLastError()); + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + SortUniqueIndicesMapTempFunc, + dev_indice_map_ptr, + indice_desc, + dev_sort_raw_indices_ptr, + static_cast(unique_count_handle.pointer()), + num_runs, + p_thrust_allocator, + p_env_fns, + stream); + } catch (...) { + WHOLEMEMORY_FAIL_NOTHROW("map indices failed"); + } + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h new file mode 100644 index 0000000..8491e58 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "wholememory_ops/temp_memory_handle.hpp" +#include +#include +#include + +namespace wholememory_ops { + +wholememory_error_code_t sort_unique_ids_for_hierarchy_func( + void* indices, + wholememory_array_description_t indice_desc, + temp_memory_handle* output_indices_handle, + wholememory_array_description_t* output_indices_desc, + temp_memory_handle* dev_indice_map_handle, // indice_desc + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu new file mode 100644 index 0000000..a3d3fc6 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.cu @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "sort_indices_func.h" +#include "sort_unique_indices_func.h" + +#include +#include +#include + +#include "cuda_macros.hpp" +#include "error.hpp" +#include "logger.hpp" +#include "wholememory_ops/register.hpp" + +namespace wholememory_ops { + +template +void SortUniqueIndicesTempFunc(const void* indices, + wholememory_array_description_t indice_desc, + void* sort_raw_indices, + int* num_runs, + temp_memory_handle* unique_indices_handle, + temp_memory_handle* unique_count_handle, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + if (indice_desc.size == 0) return; + wm_thrust_allocator& allocator = *p_thrust_allocator; + WHOLEMEMORY_CHECK_NOTHROW(indice_desc.storage_offset == 0); + temp_memory_handle sorted_indices_handle(p_env_fns); + sorted_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype); + IndexT* sorted_indices = static_cast(sorted_indices_handle.pointer()); + + sort_indices_func( + indices, indice_desc, sorted_indices, sort_raw_indices, p_thrust_allocator, p_env_fns, stream); + + unique_indices_handle->device_malloc(indice_desc.size, indice_desc.dtype); + unique_count_handle->device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT); + IndexT* unique_indices = static_cast(unique_indices_handle->pointer()); + int* unique_counts = static_cast(unique_count_handle->pointer()); + temp_memory_handle number_runs_handle(p_env_fns); + number_runs_handle.device_malloc(1, WHOLEMEMORY_DT_INT); + int* number_runs = static_cast(number_runs_handle.pointer()); + void* cub_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + cub::DeviceRunLengthEncode::Encode(cub_temp_storage, + temp_storage_bytes, + sorted_indices, + unique_indices, + unique_counts, + number_runs, + indice_desc.size, + stream); + cub_temp_storage = allocator.allocate(temp_storage_bytes); + cub::DeviceRunLengthEncode::Encode(cub_temp_storage, + temp_storage_bytes, + sorted_indices, + unique_indices, + unique_counts, + number_runs, + indice_desc.size, + stream); + WM_CUDA_CHECK_NO_THROW( + cudaMemcpyAsync(num_runs, number_runs, sizeof(int), cudaMemcpyDeviceToHost, stream)); +} + +REGISTER_DISPATCH_ONE_TYPE(SortUniqueIndicesTempFunc, SortUniqueIndicesTempFunc, SINT3264) + +wholememory_error_code_t sort_unique_indices_func(const void* indices, + wholememory_array_description_t indice_desc, + void* sort_raw_indices, + int* num_runs, + temp_memory_handle* unique_indices_handle, + temp_memory_handle* unique_count_handle, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream) +{ + try { + DISPATCH_ONE_TYPE(indice_desc.dtype, + SortUniqueIndicesTempFunc, + indices, + indice_desc, + sort_raw_indices, + num_runs, + unique_indices_handle, + unique_count_handle, + p_thrust_allocator, + p_env_fns, + stream); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("sort_unique_indices_func CUDA LOGIC Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } catch (wholememory::logic_error& wle) { + WHOLEMEMORY_ERROR("sort_unique_indices_func LOGIC Error %s\n", wle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) 
{ + return WHOLEMEMORY_UNKNOW_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/sort_unique_indices_func.h b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.h new file mode 100644 index 0000000..2ff697c --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_unique_indices_func.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include + +namespace wholememory_ops { + +wholememory_error_code_t sort_unique_indices_func(const void* indices, + wholememory_array_description_t indice_desc, + void* sort_raw_indices, + int* num_runs, + temp_memory_handle* unique_indices_handle, + temp_memory_handle* unique_count_handle, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/gather_op.cpp b/cpp/src/wholememory_ops/gather_op.cpp index 98d41d2..9444071 100644 --- a/cpp/src/wholememory_ops/gather_op.cpp +++ b/cpp/src/wholememory_ops/gather_op.cpp @@ -93,6 +93,19 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten gather_sms); } + if (has_handle && memory_type == WHOLEMEMORY_MT_HIERARCHY) { + return wholememory_ops::wholememory_gather_hierarchy( + wholememory_tensor_get_memory_handle(wholememory_tensor), + matrix_description, + indices, + indices_desc, + output, + output_desc, + p_env_fns, + static_cast(stream), + gather_sms); + } + WHOLEMEMORY_EXPECTS_NOTHROW(!has_handle || memory_type == WHOLEMEMORY_MT_CHUNKED || memory_type == WHOLEMEMORY_MT_CONTINUOUS, "Memory type not supported."); diff --git a/cpp/src/wholememory_ops/gather_op_impl.h b/cpp/src/wholememory_ops/gather_op_impl.h index 21896ff..19f3c08 100644 --- a/cpp/src/wholememory_ops/gather_op_impl.h +++ b/cpp/src/wholememory_ops/gather_op_impl.h @@ -42,6 +42,17 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor cudaStream_t stream, int gather_sms); +wholememory_error_code_t wholememory_gather_hierarchy( + wholememory_handle_t wholememory_handle, + wholememory_matrix_description_t wholememory_desc, + void* indices, + wholememory_array_description_t indice_desc, + void* output, + wholememory_matrix_description_t output_desc, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream, + int gather_sms); + wholememory_error_code_t wholememory_gather_distributed( wholememory_handle_t wholememory_handle, wholememory_matrix_description_t wholememory_desc, diff --git a/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu new file mode 100644 index 0000000..543bdfd --- /dev/null +++ b/cpp/src/wholememory_ops/gather_op_impl_hierarchy.cu @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include + +#include "logger.hpp" +#include "wholememory/communicator.hpp" +#include "wholememory/memory_handle.hpp" +#include "wholememory_ops/functions/bucket_ids_for_hierarchy_func.h" +#include "wholememory_ops/functions/exchange_embeddings_nccl_func.h" +#include "wholememory_ops/functions/gather_scatter_func.h" +#include "wholememory_ops/functions/sort_unique_ids_for_hierarchy_func.h" +#include "wholememory_ops/gather_op_impl.h" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" + +namespace wholememory_ops { + +static wholememory_error_code_t wholememory_cross_gather( + wholememory_handle_t wholememory_handle, + wholememory_matrix_description_t wholememory_desc, + void* indices, + wholememory_array_description_t indice_desc, + void* output, + wholememory_matrix_description_t output_desc, + int64_t* host_bucket_id_count_ptr, + wholememory_comm_t wm_local_comm, + wholememory_comm_t wm_cross_comm, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream, + int gather_sms) +{ + int cross_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); + std::vector host_bucket_id_offset(cross_size); + std::vector host_recv_id_count(cross_size, 0); + std::vector host_recv_id_offset(cross_size); + // exchange node count + wm_cross_comm->host_alltoall( + host_bucket_id_count_ptr, host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64); + host_bucket_id_offset[0] = 0; + for (int i = 1; i < cross_size; i++) + host_bucket_id_offset[i] = host_bucket_id_offset[i - 1] + host_bucket_id_count_ptr[i - 1]; + wm_cross_comm->sync_stream(); + // exchange indices + int64_t total_recv_count = 0; + for (int i = 0; i < cross_size; i++) { + host_recv_id_offset[i] = total_recv_count; + total_recv_count += host_recv_id_count[i]; + } + temp_memory_handle dev_recv_bucket_indices_handle(p_env_fns); + void* dev_recv_bucket_indices_ptr = + dev_recv_bucket_indices_handle.device_malloc(total_recv_count, indice_desc.dtype); + wm_cross_comm->alltoallv(indices, + dev_recv_bucket_indices_ptr, + reinterpret_cast(host_bucket_id_count_ptr), + reinterpret_cast(host_bucket_id_offset.data()), + reinterpret_cast(host_recv_id_count.data()), + reinterpret_cast(host_recv_id_offset.data()), + indice_desc.dtype, + stream); + wm_cross_comm->sync_stream(stream); + // local gather + temp_memory_handle dev_local_gather_buffer_handle(p_env_fns); + void* dev_local_gather_buffer_ptr = dev_local_gather_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); + int64_t local_gather_buffer_size[2] = {total_recv_count, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t local_gather_buffer_desc = wholememory_create_matrix_desc( + local_gather_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + void* local_fake_ptr = nullptr; + size_t local_mem_offset, local_mem_size; + 
WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_local_memory( + &local_fake_ptr, &local_mem_size, &local_mem_offset, wholememory_handle)); + local_fake_ptr = static_cast(local_fake_ptr) - local_mem_offset; + wholememory_gref_t local_fake_gref = + wholememory_create_continuous_global_reference(local_fake_ptr); + auto local_gather_indice_desc = + wholememory_create_array_desc(total_recv_count, 0, indice_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(local_fake_gref, + wholememory_desc, + dev_recv_bucket_indices_ptr, + local_gather_indice_desc, + dev_local_gather_buffer_ptr, + local_gather_buffer_desc, + stream, + gather_sms)); + // exchange embeddings + size_t output_embedding_size = + wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_local_gather_buffer_ptr, + host_recv_id_count.data(), + host_bucket_id_count_ptr, + output, + output_embedding_size, + wm_cross_comm, + stream)); + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t wholememory_gather_hierarchy( + wholememory_handle_t wholememory_handle, + wholememory_matrix_description_t wholememory_desc, + void* indices, + wholememory_array_description_t indice_desc, + void* output, + wholememory_matrix_description_t output_desc, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream, + int gather_sms) +{ + try { + if (wholememory_desc.storage_offset < 0 || + wholememory_desc.storage_offset + wholememory_desc.sizes[1] > wholememory_desc.stride) { + return WHOLEMEMORY_INVALID_INPUT; + } + bool sort_unique_indices = true; + + wm_thrust_allocator thrust_allocator(p_env_fns); + + wholememory_comm_t wm_global_comm; + int world_size, world_rank; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_global_comm, wholememory_handle)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_global_comm)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&world_rank, wm_global_comm)); + + wholememory_comm_t wm_local_comm; + int local_size, local_rank; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_local_communicator(&wm_local_comm, wholememory_handle)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&local_size, wm_local_comm)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&local_rank, wm_local_comm)); + + wholememory_comm_t wm_cross_comm; + int cross_size; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_cross_communicator(&wm_cross_comm, wholememory_handle)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&cross_size, wm_cross_comm)); + WHOLEMEMORY_CHECK_NOTHROW(world_size == local_size * cross_size); + + size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); + size_t embedding_entry_size = element_size * wholememory_desc.stride; + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + std::vector host_embedding_entry_offsets(world_size + 1); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_rank_partition_offsets( + host_embedding_entry_offsets.data(), wholememory_handle)); + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + 
+        wholememory_desc.stride);
+      host_embedding_entry_offsets[i] /= embedding_entry_size;
+    }
+
+    WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr,
+                                  host_embedding_entry_offsets.data(),
+                                  (world_size + 1) * sizeof(size_t),
+                                  cudaMemcpyHostToDevice,
+                                  stream));
+
+    temp_memory_handle dev_bucket_indices_handle(p_env_fns);
+    void* dev_bucket_indices_ptr =
+      dev_bucket_indices_handle.device_malloc(indice_desc.size, indice_desc.dtype);
+    temp_memory_handle dev_bucket_ids_map_handle(p_env_fns);
+    void* dev_bucket_ids_map_ptr =
+      dev_bucket_ids_map_handle.device_malloc(indice_desc.size, indice_desc.dtype);
+
+    std::vector<int64_t> host_bucket_id_count(local_size, 0);
+    std::vector<int64_t> host_bucket_id_offset(local_size);
+    std::vector<int64_t> host_recv_id_count(local_size, 0);
+    std::vector<int64_t> host_recv_id_offset(local_size);
+
+    // bucket indices
+    WHOLEMEMORY_RETURN_ON_FAIL(
+      bucket_and_reorder_ids_for_hierarchy_func(indices,
+                                                indice_desc,
+                                                dev_bucket_indices_ptr,
+                                                dev_bucket_ids_map_ptr,
+                                                host_bucket_id_count.data(),
+                                                dev_embedding_entry_offsets_ptr,
+                                                wm_global_comm,
+                                                wm_local_comm,
+                                                0,
+                                                &thrust_allocator,
+                                                p_env_fns,
+                                                stream));
+    WM_CUDA_CHECK(cudaStreamSynchronize(stream));
+    // exchange node count
+    wm_local_comm->host_alltoall(
+      host_bucket_id_count.data(), host_recv_id_count.data(), 1, WHOLEMEMORY_DT_INT64);
+    host_bucket_id_offset[0] = 0;
+    for (int i = 1; i < local_size; i++)
+      host_bucket_id_offset[i] = host_bucket_id_offset[i - 1] + host_bucket_id_count[i - 1];
+    wm_local_comm->sync_stream();
+    // exchange indices
+    int64_t total_recv_count = 0;
+    for (int i = 0; i < local_size; i++) {
+      host_recv_id_offset[i] = total_recv_count;
+      total_recv_count += host_recv_id_count[i];
+    }
+    temp_memory_handle dev_recv_bucket_indices_handle(p_env_fns);
+    void* dev_recv_bucket_indices_ptr =
+      dev_recv_bucket_indices_handle.device_malloc(total_recv_count, indice_desc.dtype);
+    auto recv_bucket_indices_desc =
+      wholememory_create_array_desc(total_recv_count, 0, indice_desc.dtype);
+    wm_local_comm->alltoallv(dev_bucket_indices_ptr,
+                             dev_recv_bucket_indices_ptr,
+                             reinterpret_cast<size_t*>(host_bucket_id_count.data()),
+                             reinterpret_cast<size_t*>(host_bucket_id_offset.data()),
+                             reinterpret_cast<size_t*>(host_recv_id_count.data()),
+                             reinterpret_cast<size_t*>(host_recv_id_offset.data()),
+                             indice_desc.dtype,
+                             stream);
+    wm_local_comm->sync_stream(stream);
+    WM_CUDA_CHECK(cudaGetLastError());
+    // sort unique / bucket recv indices
+    temp_memory_handle cross_gather_indices_handle(p_env_fns);
+    wholememory_array_description_t cross_gather_indices_desc;
+    temp_memory_handle dev_cross_gather_id_map_handle(p_env_fns);
+    std::vector<int64_t> host_cross_bucket_id_count(cross_size, 0);
+    if (sort_unique_indices) {
+      sort_unique_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr,
+                                         recv_bucket_indices_desc,
+                                         &cross_gather_indices_handle,
+                                         &cross_gather_indices_desc,
+                                         &dev_cross_gather_id_map_handle,
+                                         &thrust_allocator,
+                                         p_env_fns,
+                                         stream);
+      bucket_local_ids_func(cross_gather_indices_handle.pointer(),
+                            cross_gather_indices_desc,
+                            host_cross_bucket_id_count.data(),
+                            dev_embedding_entry_offsets_ptr,
+                            wm_local_comm,
+                            wm_cross_comm,
+                            &thrust_allocator,
+                            p_env_fns,
+                            stream);
+    } else {
+      void* cross_gather_indices_ptr = cross_gather_indices_handle.device_malloc(
+        recv_bucket_indices_desc.size, recv_bucket_indices_desc.dtype);
+      void* dev_cross_gather_id_map_ptr = dev_cross_gather_id_map_handle.device_malloc(
+        recv_bucket_indices_desc.size, recv_bucket_indices_desc.dtype);
+      cross_gather_indices_desc = recv_bucket_indices_desc;
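+      // No sort/unique pass: bucket and reorder the received indices as-is for the cross-node gather.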
WHOLEMEMORY_RETURN_ON_FAIL( + bucket_and_reorder_ids_for_hierarchy_func(dev_recv_bucket_indices_ptr, + recv_bucket_indices_desc, + cross_gather_indices_ptr, + dev_cross_gather_id_map_ptr, + host_cross_bucket_id_count.data(), + dev_embedding_entry_offsets_ptr, + wm_global_comm, + wm_local_comm, + 1, + &thrust_allocator, + p_env_fns, + stream)); + } + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); + // cross gather + temp_memory_handle dev_cross_gather_buffer_handle(p_env_fns); + void* dev_cross_gather_buffer_ptr = dev_cross_gather_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * cross_gather_indices_desc.size, output_desc.dtype); + int64_t cross_gather_buffer_size[2] = {cross_gather_indices_desc.size, + wholememory_desc.sizes[1]}; + wholememory_matrix_description_t cross_gather_buffer_desc = wholememory_create_matrix_desc( + cross_gather_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + wholememory_cross_gather(wholememory_handle, + wholememory_desc, + cross_gather_indices_handle.pointer(), + cross_gather_indices_desc, + dev_cross_gather_buffer_ptr, + cross_gather_buffer_desc, + host_cross_bucket_id_count.data(), + wm_local_comm, + wm_cross_comm, + &thrust_allocator, + p_env_fns, + stream, + gather_sms); + // cross gather reorder + temp_memory_handle dev_embedding_map_buffer_handle(p_env_fns); + void* dev_embedding_map_buffer_ptr = dev_embedding_map_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * total_recv_count, output_desc.dtype); + int64_t embedding_map_buffer_size[2] = {total_recv_count, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t embedding_map_buffer_desc = wholememory_create_matrix_desc( + embedding_map_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + wholememory_gref_t cross_gather_fake_gref = + wholememory_create_continuous_global_reference(dev_cross_gather_buffer_ptr); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(cross_gather_fake_gref, + cross_gather_buffer_desc, + dev_cross_gather_id_map_handle.pointer(), + recv_bucket_indices_desc, + dev_embedding_map_buffer_ptr, + embedding_map_buffer_desc, + stream, + gather_sms)); + // exchange embeddings + size_t output_embedding_size = + wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype); + temp_memory_handle dev_recv_embedding_buffer_handle(p_env_fns); + void* dev_recv_embedding_buffer_ptr = dev_recv_embedding_buffer_handle.device_malloc( + wholememory_desc.sizes[1] * indice_desc.size, output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(exchange_embeddings_nccl_func(dev_embedding_map_buffer_ptr, + host_recv_id_count.data(), + host_bucket_id_count.data(), + dev_recv_embedding_buffer_ptr, + output_embedding_size, + wm_local_comm, + stream)); + // bucket reorder + wholememory_gref_t recv_embedding_buffer_fake_gref = + wholememory_create_continuous_global_reference(dev_recv_embedding_buffer_ptr); + int64_t recv_embedding_buffer_size[2] = {indice_desc.size, wholememory_desc.sizes[1]}; + wholememory_matrix_description_t recv_embedding_buffer_desc = wholememory_create_matrix_desc( + recv_embedding_buffer_size, wholememory_desc.sizes[1], 0, output_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(recv_embedding_buffer_fake_gref, + recv_embedding_buffer_desc, + dev_bucket_ids_map_ptr, + indice_desc, + output, + output_desc, + stream, + gather_sms)); + WM_CUDA_CHECK(cudaGetLastError()); + } catch (wholememory::cuda_error& wce) { + WHOLEMEMORY_ERROR("CUDA logic Error %s\n", wce.what()); + return WHOLEMEMORY_CUDA_ERROR; + } catch 
(wholememory::logic_error& wle) { + WHOLEMEMORY_ERROR("LOGIC Error %s\n", wle.what()); + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_UNKNOW_ERROR; + } + + return WHOLEMEMORY_SUCCESS; +} + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/gather_op_impl_nccl.cu b/cpp/src/wholememory_ops/gather_op_impl_nccl.cu index ddcd0e9..1b8abaa 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_nccl.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_nccl.cu @@ -49,22 +49,9 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor wm_thrust_allocator thrust_allocator(p_env_fns); - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); size_t embedding_entry_size = element_size * wholememory_desc.stride; - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); - - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_comm; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); @@ -83,18 +70,44 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor static_cast(dev_raw_indice.device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT64)); int64_t total_recv_count = 0; + + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_rank_partition_offsets(host_embedding_entry_offsets_ptr, wholememory_handle)); + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets_ptr[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets_ptr[i] /= embedding_entry_size; + } + + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice, + stream)); WHOLEMEMORY_RETURN_ON_FAIL(bucket_and_exchange_ids_func(indices, indice_desc, host_recv_rank_id_count_ptr, host_rank_id_count_ptr, &dev_recv_indice_buffer, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_comm, &thrust_allocator, p_env_fns, stream)); - // Local Gather for (int i = 0; i < world_size; i++) { total_recv_count += host_recv_rank_id_count_ptr[i]; diff --git a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu index 36e4efb..789dcd4 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu @@ -53,7 +53,7 @@ wholememory_error_code_t nvshmem_gather_floating_int32_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - 
size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); @@ -66,7 +66,7 @@ wholememory_error_code_t nvshmem_gather_floating_int64_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); @@ -80,7 +80,7 @@ wholememory_error_code_t nvshmem_gather_integer_int64_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); @@ -93,7 +93,7 @@ wholememory_error_code_t nvshmem_gather_integer_int32_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); @@ -119,24 +119,41 @@ wholememory_error_code_t wholememory_gather_nvshmem( embedding_is_float == output_is_float, "embedding and output should be same number type, e.g. floating number or integer number."); if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); - size_t embedding_entry_size = element_size * wholememory_desc.stride; + wholememory_comm_t wm_comm; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); + int world_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_comm)); - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); - wholememory_comm_t wm_comm; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_rank_partition_offsets(host_embedding_entry_offsets_ptr, wholememory_handle)); + + size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); + size_t embedding_entry_size = element_size * wholememory_desc.stride; + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets_ptr[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets_ptr[i] /= embedding_entry_size; + } + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + 
cudaMemcpyHostToDevice, + stream)); wholememory_nvshmem_ref_t embedding_nvshmem_ref; WHOLEMEMORY_RETURN_ON_FAIL( @@ -161,7 +178,7 @@ wholememory_error_code_t wholememory_gather_nvshmem( void*, void*, wholememory_matrix_description_t, - size_t, + size_t*, wholememory_env_func_t*, cudaStream_t, int) = nullptr; @@ -187,7 +204,7 @@ wholememory_error_code_t wholememory_gather_nvshmem( output, temp_output_ptr, output_desc, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, p_env_fns, stream, gather_sms); diff --git a/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu b/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu index 0eac3c2..77926df 100644 --- a/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu +++ b/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu @@ -49,7 +49,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int32_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms); @@ -63,7 +63,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int64_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms); @@ -77,7 +77,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int32_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms); @@ -91,7 +91,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms); @@ -122,29 +122,41 @@ wholememory_error_code_t wholememory_scatter_nvshmem( return WHOLEMEMORY_INVALID_INPUT; } - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); - size_t embedding_entry_size = element_size * wholememory_desc.stride; - - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); - - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_comm; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); int world_size; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_comm)); - int world_rank; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&world_rank, wm_comm)); + + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, 
WHOLEMEMORY_DT_INT64)); + temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_rank_partition_offsets(host_embedding_entry_offsets_ptr, wholememory_handle)); + + size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); + size_t embedding_entry_size = element_size * wholememory_desc.stride; + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets_ptr[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets_ptr[i] /= embedding_entry_size; + } + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice, + stream)); + wholememory_nvshmem_ref_t embedding_nvshmem_ref; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_nvshmem_reference(&embedding_nvshmem_ref, wholememory_handle)); @@ -168,7 +180,7 @@ wholememory_error_code_t wholememory_scatter_nvshmem( wholememory_array_description_t, wholememory_nvshmem_ref_t, wholememory_matrix_description_t, - size_t, + size_t*, wholememory_env_func_t*, cudaStream_t, int); @@ -195,7 +207,7 @@ wholememory_error_code_t wholememory_scatter_nvshmem( indices_desc, embedding_nvshmem_ref, wholememory_desc, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu b/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu index dd3bd29..95e2fe6 100644 --- a/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu +++ b/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu @@ -53,22 +53,9 @@ wholememory_error_code_t wholememory_scatter_nccl(void* input, wm_thrust_allocator thrust_allocator(p_env_fns); - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); size_t embedding_entry_size = element_size * wholememory_desc.stride; - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); - - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_comm; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); @@ -87,13 +74,39 @@ wholememory_error_code_t wholememory_scatter_nccl(void* input, static_cast(dev_raw_indice.device_malloc(indices_desc.size, WHOLEMEMORY_DT_INT64)); int64_t total_recv_count = 0; + + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + 
wholememory_get_rank_partition_offsets(host_embedding_entry_offsets_ptr, wholememory_handle)); + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets_ptr[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets_ptr[i] /= embedding_entry_size; + } + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice, + stream)); WHOLEMEMORY_RETURN_ON_FAIL(bucket_and_exchange_ids_func(indices, indices_desc, host_recv_rank_id_count_ptr, host_rank_id_count_ptr, &dev_recv_indice_buffer, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_comm, &thrust_allocator, p_env_fns, diff --git a/cpp/tests/wholememory_ops/embedding_test_utils.cu b/cpp/tests/wholememory_ops/embedding_test_utils.cu index 2da4d26..8cd286f 100644 --- a/cpp/tests/wholememory_ops/embedding_test_utils.cu +++ b/cpp/tests/wholememory_ops/embedding_test_utils.cu @@ -528,5 +528,22 @@ void host_random_init_float(float* data, int64_t len, float max_value, float min } } +void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count) +{ + std::default_random_engine random_engine(0); + std::uniform_int_distribution uniform(90, 100); + size_t acc_size = 0; + size_t random_sum = 0; + for (int i = 0; i < partition_count; i++) { + partition_sizes[i] = (size_t)uniform(random_engine); + random_sum += partition_sizes[i]; + } + for (int i = 0; i < partition_count; i++) { + partition_sizes[i] = (size_t)((partition_sizes[i] / (double)random_sum) * total_size); + acc_size += partition_sizes[i]; + } + partition_sizes[0] += total_size - acc_size; +} + } // namespace testing } // namespace wholememory_ops diff --git a/cpp/tests/wholememory_ops/embedding_test_utils.hpp b/cpp/tests/wholememory_ops/embedding_test_utils.hpp index cd8edcd..62a02ce 100644 --- a/cpp/tests/wholememory_ops/embedding_test_utils.hpp +++ b/cpp/tests/wholememory_ops/embedding_test_utils.hpp @@ -63,5 +63,7 @@ void host_check_embedding_same(void* host_embedding, void host_random_init_float(float* data, int64_t len, float max_value, float min_value); +void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count); + } // namespace testing } // namespace wholememory_ops diff --git a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu index 83e57e1..92ab240 100644 --- a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu @@ -114,6 +114,7 @@ struct EmbeddingBackwardTestParams { WHOLEMEMORY_SUCCESS); return cache_policy; } + int get_rank_partition_method() const { return rank_partition_method; } EmbeddingBackwardTestParams& set_use_cache() { use_cache = true; @@ -139,6 +140,11 @@ struct EmbeddingBackwardTestParams { lr_ = lr; return *this; } + EmbeddingBackwardTestParams& use_random_partition() + { + rank_partition_method = 1; + return *this; + } wholememory_array_description_t indice_description; wholememory_matrix_description_t embedding_description; wholememory_matrix_description_t grad_description; @@ -154,6 +160,7 @@ struct EmbeddingBackwardTestParams { float lr_ = 0.1; 
std::map optimizer_params; + int rank_partition_method = 0; // 0-default, 1-random }; class WholeMemoryEmbeddingBackwardParameterTests @@ -586,13 +593,18 @@ TEST_P(WholeMemoryEmbeddingBackwardParameterTests, EmbeddingGatherGradientApplyT optimizer, param_name_value.first.c_str(), ¶m_name_value.second), WHOLEMEMORY_SUCCESS); } - + std::vector rank_partition(world_size); + wholememory_ops::testing::host_random_partition( + rank_partition.data(), embedding_tensor_description.sizes[0], world_size); + size_t* rank_partition_ptr = nullptr; + if (params.get_rank_partition_method() == 1) { rank_partition_ptr = rank_partition.data(); } EXPECT_EQ(wholememory_create_embedding(&wm_embedding, &embedding_tensor_description, wm_comm, params.memory_type, params.memory_location, - cache_policy), + cache_policy, + rank_partition_ptr), WHOLEMEMORY_SUCCESS); EXPECT_EQ(wholememory_embedding_set_optimizer(wm_embedding, optimizer), WHOLEMEMORY_SUCCESS); wholememory_tensor_t embedding_tensor = @@ -602,19 +614,18 @@ TEST_P(WholeMemoryEmbeddingBackwardParameterTests, EmbeddingGatherGradientApplyT WHOLEMEMORY_SUCCESS); wholememory_handle_t embedding_handle = wholememory_tensor_get_memory_handle(embedding_tensor); - auto entry_per_partition = wholememory_tensor_get_entry_per_partition(embedding_tensor); - int64_t total_entry_count = params.embedding_description.sizes[0]; - int64_t rank_start_entry = - std::min(world_rank * entry_per_partition, total_entry_count); - int64_t rank_end_entry = - std::min((world_rank + 1) * entry_per_partition, total_entry_count); - int64_t rank_entry_count = rank_end_entry - rank_start_entry; - + size_t rank_entry_count = 0; + size_t rank_start_entry = 0; + EXPECT_EQ(wholememory_tensor_get_local_entry_count(&rank_entry_count, embedding_tensor), + WHOLEMEMORY_SUCCESS); + EXPECT_EQ(wholememory_tensor_get_local_entry_start(&rank_start_entry, embedding_tensor), + WHOLEMEMORY_SUCCESS); + rank_entry_count = std::min( + rank_entry_count, params.embedding_description.sizes[0] - rank_start_entry); auto* dst_base_ptr = static_cast(wholememory_tensor_get_data_pointer(local_embed_tensor)); size_t dst_stride = wholememory_tensor_get_tensor_description(local_embed_tensor)->strides[0]; size_t embedding_copy_size = embedding_dim * sizeof(float); - for (int64_t i = 0; i < rank_entry_count; i++) { WM_CUDA_CHECK_NO_THROW(cudaMemcpy(dst_base_ptr + i * dst_stride, start_embedding_table[rank_start_entry + i].data(), @@ -738,6 +749,30 @@ INSTANTIATE_TEST_SUITE_P( #endif EmbeddingBackwardTestParams().set_entry_count(500).set_indice_count(400).set_embedding_dim(4), EmbeddingBackwardTestParams().set_embedding_dim(3), + EmbeddingBackwardTestParams() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .set_optimizer_type(WHOLEMEMORY_OPT_RMSPROP) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .set_optimizer_type(WHOLEMEMORY_OPT_ADAGRAD) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .set_optimizer_type(WHOLEMEMORY_OPT_LAZY_ADAM) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_optimizer_type(WHOLEMEMORY_OPT_RMSPROP) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_optimizer_type(WHOLEMEMORY_OPT_ADAGRAD) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_optimizer_type(WHOLEMEMORY_OPT_LAZY_ADAM) + 
.use_random_partition(), EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131), EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131).set_optimizer_type( WHOLEMEMORY_OPT_RMSPROP), diff --git a/cpp/tests/wholememory_ops/wholememory_embedding_tests.cu b/cpp/tests/wholememory_ops/wholememory_embedding_tests.cu index 03f7987..152b2de 100644 --- a/cpp/tests/wholememory_ops/wholememory_embedding_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_embedding_tests.cu @@ -127,6 +127,7 @@ struct EmbeddingTestParams { WHOLEMEMORY_SUCCESS); return cache_policy; } + int get_rank_partition_method() const { return rank_partition_method; } EmbeddingTestParams& non_cache() { cache_type = 0; @@ -147,6 +148,11 @@ struct EmbeddingTestParams { cache_group_count = count; return *this; } + EmbeddingTestParams& use_random_partition() + { + rank_partition_method = 1; + return *this; + } wholememory_array_description_t indice_description; wholememory_matrix_description_t embedding_description; wholememory_matrix_description_t output_description; @@ -155,8 +161,9 @@ struct EmbeddingTestParams { wholememory_memory_type_t cache_memory_type = WHOLEMEMORY_MT_CHUNKED; wholememory_memory_location_t cache_memory_location = WHOLEMEMORY_ML_DEVICE; float cache_ratio = 0.2; - int cache_type = 0; // 0: no cache, 1: device cache, 2: local cache - int cache_group_count = 1; + int cache_type = 0; // 0: no cache, 1: device cache, 2: local cache + int cache_group_count = 1; + int rank_partition_method = 0; // 0-default, 1-random }; class WholeMemoryEmbeddingParameterTests : public ::testing::TestWithParam {}; @@ -238,13 +245,18 @@ TEST_P(WholeMemoryEmbeddingParameterTests, EmbeddingGatherTest) wholememory_tensor_description_t embedding_tensor_description; wholememory_copy_matrix_desc_to_tensor(&embedding_tensor_description, ¶ms.embedding_description); - + std::vector rank_partition(world_size); + wholememory_ops::testing::host_random_partition( + rank_partition.data(), embedding_tensor_description.sizes[0], world_size); + size_t* rank_partition_ptr = nullptr; + if (params.get_rank_partition_method() == 1) { rank_partition_ptr = rank_partition.data(); } EXPECT_EQ(wholememory_create_embedding(&wm_embedding, &embedding_tensor_description, wm_comm, params.memory_type, params.memory_location, - cache_policy), + cache_policy, + rank_partition_ptr), WHOLEMEMORY_SUCCESS); wholememory_tensor_t embedding_tensor = @@ -353,6 +365,19 @@ INSTANTIATE_TEST_SUITE_P( #if 1 EmbeddingTestParams().non_cache(), EmbeddingTestParams().non_cache().set_memory_location(WHOLEMEMORY_ML_DEVICE), + EmbeddingTestParams() + .non_cache() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .use_random_partition(), + EmbeddingTestParams() + .non_cache() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .use_random_partition(), + EmbeddingTestParams() + .non_cache() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .use_random_partition(), EmbeddingTestParams().device_cache(), EmbeddingTestParams().device_cache().set_cache_memory_type(WHOLEMEMORY_MT_DISTRIBUTED), EmbeddingTestParams().local_cache(), diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu index ada9c87..506e21c 100644 --- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu @@ -50,7 +50,7 @@ typedef struct WholeMemoryGatherTestParam { { return embedding_stride * 
wholememory_dtype_get_element_size(embedding_type); } - + int get_rank_partition_method() const { return rank_partition_method; } WholeMemoryGatherTestParam& set_memory_type(wholememory_memory_type_t new_memory_type) { memory_type = new_memory_type; @@ -109,6 +109,11 @@ typedef struct WholeMemoryGatherTestParam { distributed_backend = new_distributed_backend; return *this; } + WholeMemoryGatherTestParam& use_random_partition() + { + rank_partition_method = 1; + return *this; + } wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_CHUNKED; wholememory_memory_location_t memory_location = WHOLEMEMORY_ML_DEVICE; int64_t embedding_entry_count = 1000000LL; @@ -123,6 +128,7 @@ typedef struct WholeMemoryGatherTestParam { int64_t indices_storage_offset = 0; int64_t output_storage_offset = 0; wholememory_distributed_backend_t distributed_backend = WHOLEMEMORY_DB_NCCL; + int rank_partition_method = 0; // 0-default, 1-random } WholeMemoryGatherTestParam; class WholeMemoryGatherParameterTests @@ -164,14 +170,19 @@ TEST_P(WholeMemoryGatherParameterTests, GatherTest) auto indices_desc = params.get_indices_desc(); auto output_desc = params.get_output_desc(); size_t embedding_entry_size = params.get_embedding_granularity(); + std::vector rank_partition(world_size); + wholememory_ops::testing::host_random_partition( + rank_partition.data(), embedding_desc.sizes[0], world_size); + size_t* rank_partition_ptr = nullptr; + if (params.get_rank_partition_method() == 1) { rank_partition_ptr = rank_partition.data(); } EXPECT_EQ(wholememory_malloc(&embedding_handle, wholememory_get_memory_size_from_matrix(&embedding_desc), wm_comm, params.memory_type, params.memory_location, - embedding_entry_size), + embedding_entry_size, + rank_partition_ptr), WHOLEMEMORY_SUCCESS); - cudaStream_t stream; EXPECT_EQ(cudaStreamCreate(&stream), cudaSuccess); @@ -289,9 +300,11 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_indices_count(0), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_indices_count(0), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_indices_count(0), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_indices_count(0), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_memory_location(WHOLEMEMORY_ML_HOST), @@ -301,6 +314,20 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).use_random_partition(), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).use_random_partition(), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).use_random_partition(), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .use_random_partition(), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .use_random_partition(), 
WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_memory_location(WHOLEMEMORY_ML_HOST) @@ -336,18 +363,27 @@ INSTANTIATE_TEST_SUITE_P( .set_embedding_dim(11) .set_embedding_stride(12) .set_indices_count(100005), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_embedding_dim(11) + .set_embedding_stride(12) + .set_indices_count(100005), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(128), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(128), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(128), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(128), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(127), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(127), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(127), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(127), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(129), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(129), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(129), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(129), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(513), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(513), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_embedding_dim(513), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_dim(513), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_type(WHOLEMEMORY_DT_HALF), @@ -366,6 +402,9 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_output_type(WHOLEMEMORY_DT_HALF), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_output_type(WHOLEMEMORY_DT_HALF), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_type(WHOLEMEMORY_DT_HALF) @@ -378,6 +417,10 @@ INSTANTIATE_TEST_SUITE_P( .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_type(WHOLEMEMORY_DT_HALF) .set_output_type(WHOLEMEMORY_DT_HALF), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_embedding_type(WHOLEMEMORY_DT_HALF) + .set_output_type(WHOLEMEMORY_DT_HALF), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_indices_type(WHOLEMEMORY_DT_INT64), @@ -387,6 +430,9 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_indices_type(WHOLEMEMORY_DT_INT64), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_indices_type(WHOLEMEMORY_DT_INT64), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_stride(33), @@ -394,9 +440,11 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_stride(33), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_embedding_stride(33), 
WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_output_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_output_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).set_output_stride(33), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_HIERARCHY).set_output_stride(33), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_type(WHOLEMEMORY_DT_HALF) @@ -409,12 +457,20 @@ INSTANTIATE_TEST_SUITE_P( .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_type(WHOLEMEMORY_DT_HALF) .set_embedding_stride(33), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_HIERARCHY) + .set_embedding_type(WHOLEMEMORY_DT_HALF) + .set_embedding_stride(33), WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) #ifdef WITH_NVSHMEM_SUPPORT , WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_distributed_backend(WHOLEMEMORY_DB_NVSHMEM), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_distributed_backend(WHOLEMEMORY_DB_NVSHMEM) + .use_random_partition(), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_dim(11) diff --git a/cpp/tests/wholememory_ops/wholememory_scatter_tests.cu b/cpp/tests/wholememory_ops/wholememory_scatter_tests.cu index 991bc0c..656d608 100644 --- a/cpp/tests/wholememory_ops/wholememory_scatter_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_scatter_tests.cu @@ -47,6 +47,7 @@ typedef struct WholeMemoryScatterTestParam { { return embedding_stride * wholememory_dtype_get_element_size(embedding_type); } + int get_rank_partition_method() const { return rank_partition_method; } WholeMemoryScatterTestParam& set_memory_type(wholememory_memory_type_t new_memory_type) { @@ -107,6 +108,11 @@ typedef struct WholeMemoryScatterTestParam { distributed_backend = new_distributed_backend; return *this; } + WholeMemoryScatterTestParam& use_random_partition() + { + rank_partition_method = 1; + return *this; + } wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_CHUNKED; wholememory_memory_location_t memory_location = WHOLEMEMORY_ML_DEVICE; int64_t embedding_entry_count = 1000000LL; @@ -121,6 +127,7 @@ typedef struct WholeMemoryScatterTestParam { int64_t indices_storage_offset = 0; int64_t input_storage_offset = 0; wholememory_distributed_backend_t distributed_backend = WHOLEMEMORY_DB_NCCL; + int rank_partition_method = 0; // 0-default, 1-random } WholeMemoryScatterTestParam; class WholeMemoryScatterParameterTests @@ -161,12 +168,18 @@ TEST_P(WholeMemoryScatterParameterTests, ScatterTest) auto indices_desc = params.get_indices_desc(); auto input_desc = params.get_input_desc(); size_t embedding_entry_size = params.get_embedding_granularity(); + std::vector rank_partition(world_size); + wholememory_ops::testing::host_random_partition( + rank_partition.data(), embedding_desc.sizes[0], world_size); + size_t* rank_partition_ptr = nullptr; + if (params.get_rank_partition_method() == 1) { rank_partition_ptr = rank_partition.data(); } EXPECT_EQ(wholememory_malloc(&embedding_handle, wholememory_get_memory_size_from_matrix(&embedding_desc), wm_comm, params.memory_type, params.memory_location, - embedding_entry_size), + embedding_entry_size, + rank_partition_ptr), WHOLEMEMORY_SUCCESS); cudaStream_t stream; @@ -304,6 +317,14 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryScatterTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) 
.set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryScatterTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).use_random_partition(), + WholeMemoryScatterTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .use_random_partition(), + WholeMemoryScatterTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .use_random_partition(), WholeMemoryScatterTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(128), WholeMemoryScatterTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(128), WholeMemoryScatterTestParam() @@ -404,6 +425,10 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryScatterTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_distributed_backend(WHOLEMEMORY_DB_NVSHMEM), + WholeMemoryScatterTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_distributed_backend(WHOLEMEMORY_DB_NVSHMEM) + .use_random_partition(), WholeMemoryScatterTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_indices_count(0) diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index 44728f4..1e1298f 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -60,6 +60,7 @@ cdef extern from "wholememory/wholememory.h": WHOLEMEMORY_MT_CONTINUOUS "WHOLEMEMORY_MT_CONTINUOUS" WHOLEMEMORY_MT_CHUNKED "WHOLEMEMORY_MT_CHUNKED" WHOLEMEMORY_MT_DISTRIBUTED "WHOLEMEMORY_MT_DISTRIBUTED" + WHOLEMEMORY_MT_HIERARCHY "WHOLEMEMORY_MT_HIERARCHY" ctypedef enum wholememory_memory_location_t: WHOLEMEMORY_ML_NONE "WHOLEMEMORY_ML_NONE" @@ -121,13 +122,20 @@ cdef extern from "wholememory/wholememory.h": wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) + size_t data_granularity, + size_t * rank_entry_partition) cdef wholememory_error_code_t wholememory_free(wholememory_handle_t wholememory_handle) cdef wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t * comm, wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_get_local_communicator(wholememory_comm_t * comm, + wholememory_handle_t wholememory_handle) + + cdef wholememory_error_code_t wholememory_get_cross_communicator(wholememory_comm_t * comm, + wholememory_handle_t wholememory_handle) + cdef wholememory_memory_type_t wholememory_get_memory_type(wholememory_handle_t wholememory_handle) cdef wholememory_memory_location_t wholememory_get_memory_location(wholememory_handle_t wholememory_handle) @@ -139,26 +147,30 @@ cdef extern from "wholememory/wholememory.h": size_t * local_offset, wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_get_local_size(size_t * local_size, + wholememory_handle_t wholememory_handle) + + cdef wholememory_error_code_t wholememory_get_local_offset(size_t * local_offset, + wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, size_t * rank_memory_size, size_t * rank_memory_offset, int rank, wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_equal_entry_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size) + cdef wholememory_error_code_t wholememory_get_global_pointer(void** global_ptr, wholememory_handle_t wholememory_handle) - 
cdef wholememory_error_code_t wholememory_determine_partition_plan(size_t * size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size) - - cdef wholememory_error_code_t wholememory_determine_entry_partition_plan(size_t * entry_per_rank, - size_t total_entry_count, - int world_size) + cdef wholememory_error_code_t wholememory_get_rank_partition_sizes(size_t * rank_mem_sizes, + wholememory_handle_t wholememory_handle) - cdef wholememory_error_code_t wholememory_get_partition_plan(size_t * size_per_rank, - wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_get_rank_partition_offsets(size_t * rank_mem_offsets, + wholememory_handle_t wholememory_handle) cdef int fork_get_device_count() @@ -220,6 +232,7 @@ cpdef enum WholeMemoryMemoryType: MtContinuous = WHOLEMEMORY_MT_CONTINUOUS MtChunked = WHOLEMEMORY_MT_CHUNKED MtDistributed = WHOLEMEMORY_MT_DISTRIBUTED + MtHierarchy = WHOLEMEMORY_MT_HIERARCHY cpdef enum WholeMemoryMemoryLocation: MlNone = WHOLEMEMORY_ML_NONE @@ -548,7 +561,8 @@ cdef extern from "wholememory/wholememory_tensor.h": wholememory_tensor_description_t *tensor_description, wholememory_comm_t comm, wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location) + wholememory_memory_location_t memory_location, + size_t * tensor_entry_partition) cdef wholememory_error_code_t wholememory_destroy_tensor(wholememory_tensor_t wholememory_tensor) @@ -567,6 +581,18 @@ cdef extern from "wholememory/wholememory_tensor.h": cdef wholememory_tensor_description_t * wholememory_tensor_get_tensor_description( wholememory_tensor_t wholememory_tensor) + cdef wholememory_error_code_t wholememory_tensor_get_entry_offsets( + size_t * entry_offsets, wholememory_tensor_t wholememory_tensor); + + cdef wholememory_error_code_t wholememory_tensor_get_entry_partition_sizes( + size_t * entry_partition, wholememory_tensor_t wholememory_tensor); + + cdef wholememory_error_code_t wholememory_tensor_get_local_entry_count( + size_t * local_entry_count, wholememory_tensor_t wholememory_tensor); + + cdef wholememory_error_code_t wholememory_tensor_get_local_entry_start( + size_t * local_entry_start, wholememory_tensor_t wholememory_tensor); + cdef wholememory_error_code_t wholememory_tensor_get_subtensor(wholememory_tensor_t wholememory_tensor, int64_t *starts, int64_t *ends, @@ -642,6 +668,7 @@ cdef extern from "wholememory/embedding.h": wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, wholememory_embedding_cache_policy_t cache_policy, + size_t * embedding_entry_partition, int user_defined_sms, int round_robin_size) @@ -814,16 +841,21 @@ cdef class PyWholeMemoryEmbedding: WholeMemoryMemoryType memory_type, WholeMemoryMemoryLocation memory_location, WholeMemoryCachePolicy cache_policy, + cython.size_t[:] embedding_entry_partition, int user_defined_sms, int round_robin_size): self.memory_type = memory_type self.memory_location = memory_location + cdef size_t* partition_ptr = NULL + if embedding_entry_partition is not None and embedding_entry_partition.size > 0: + partition_ptr = &embedding_entry_partition[0] check_wholememory_error_code(wholememory_create_embedding(&self.wm_embedding, &tensor_desc.tensor_description, comm.comm_id, self.memory_type, self.memory_location, cache_policy.cache_policy, + partition_ptr, user_defined_sms, round_robin_size)) @@ -871,6 +903,7 @@ def create_embedding(PyWholeMemoryTensorDescription tensor_desc, WholeMemoryMemoryType memory_type, WholeMemoryMemoryLocation 
memory_location, WholeMemoryCachePolicy cache_policy, + cython.size_t[:] embedding_entry_partition, int user_defined_sms, int round_robin_size): wm_embedding = PyWholeMemoryEmbedding() @@ -879,6 +912,7 @@ def create_embedding(PyWholeMemoryTensorDescription tensor_desc, memory_type, memory_location, cache_policy, + embedding_entry_partition, user_defined_sms, round_robin_size) return wm_embedding @@ -1315,17 +1349,22 @@ cdef class PyWholeMemoryHandle: check_wholememory_error_code(wholememory_get_communicator(&py_comm.comm_id, self.wholememory_handle)) return py_comm + def get_local_communicator(self): + py_comm = PyWholeMemoryComm() + check_wholememory_error_code(wholememory_get_local_communicator(&py_comm.comm_id, self.wholememory_handle)) + return py_comm + + def get_cross_communicator(self): + py_comm = PyWholeMemoryComm() + check_wholememory_error_code(wholememory_get_cross_communicator(&py_comm.comm_id, self.wholememory_handle)) + return py_comm + def get_memory_type(self): return WholeMemoryMemoryType(wholememory_get_memory_type(self.wholememory_handle)) def get_memory_location(self): return WholeMemoryMemoryLocation(wholememory_get_memory_location(self.wholememory_handle)) - def get_partition_plan(self): - cdef size_t size_per_rank - check_wholememory_error_code(wholememory_get_partition_plan(&size_per_rank, self.wholememory_handle)) - return size_per_rank - def get_global_flatten_tensor(self, object import_dlpack_fn, WholeMemoryDataType data_type, @@ -1513,12 +1552,15 @@ cdef class PyWholeMemoryTensor: def storage_offset(self): return self.tensor_description.storage_offset - def get_partition_plan(self): - mem_size_per_rank = self.get_wholememory_handle().get_partition_plan() - element_size = wholememory_dtype_get_element_size(self.tensor_description.dtype) - vector_size = element_size * self.stride()[0] - assert mem_size_per_rank % vector_size == 0 - return mem_size_per_rank // vector_size + def get_local_entry_count(self): + cdef size_t local_entry_count = 0 + check_wholememory_error_code(wholememory_tensor_get_local_entry_count(&local_entry_count, self.wholememory_tensor)) + return local_entry_count + + def get_local_entry_start(self): + cdef size_t local_entry_start = 0 + check_wholememory_error_code(wholememory_tensor_get_local_entry_start(&local_entry_start, self.wholememory_tensor)) + return local_entry_start def get_sub_tensor(self, starts, ends): cdef int64_t start_array[2] @@ -1661,10 +1703,10 @@ def split_communicator(PyWholeMemoryComm comm,int color,int key): def communicator_set_distributed_backend(PyWholeMemoryComm py_comm,WholeMemoryDistributedBackend distributed_backend): check_wholememory_error_code(wholememory_communicator_set_distributed_backend(py_comm.comm_id,int(distributed_backend))) -def determine_partition_plan(int64_t entry_count, +def equal_partition_plan(int64_t entry_count, int world_size): cdef size_t per_rank_count - check_wholememory_error_code(wholememory_determine_entry_partition_plan(&per_rank_count, + check_wholememory_error_code(wholememory_equal_entry_partition_plan(&per_rank_count, entry_count, world_size)) return per_rank_count @@ -1673,11 +1715,15 @@ def malloc(cython.size_t total_size, PyWholeMemoryComm py_comm, WholeMemoryMemoryType memory_type, WholeMemoryMemoryLocation memory_location, - cython.size_t data_granularity): + cython.size_t data_granularity, + cython.size_t[:] rank_entry_partition=None): handle = PyWholeMemoryHandle() + cdef size_t* partition_ptr = NULL + if rank_entry_partition is not None and rank_entry_partition.size > 0: + 
partition_ptr = &rank_entry_partition[0] check_wholememory_error_code(wholememory_malloc(&handle.wholememory_handle, total_size, py_comm.comm_id, int(memory_type), int(memory_location), - data_granularity)) + data_granularity, partition_ptr)) return handle def free(PyWholeMemoryHandle handle): @@ -1687,18 +1733,23 @@ def create_wholememory_array(WholeMemoryDataType dtype, int64_t size, PyWholeMemoryComm comm, WholeMemoryMemoryType mem_type, - WholeMemoryMemoryLocation mem_location): + WholeMemoryMemoryLocation mem_location, + cython.size_t[:] tensor_entry_partition=None): wholememory_tensor = PyWholeMemoryTensor() wholememory_tensor.tensor_description.dtype = int(dtype) wholememory_tensor.tensor_description.storage_offset = 0 wholememory_tensor.tensor_description.dim = 1 wholememory_tensor.tensor_description.strides[0] = 1 wholememory_tensor.tensor_description.sizes[0] = size + cdef size_t* partition_ptr = NULL + if tensor_entry_partition is not None and tensor_entry_partition.size > 0: + partition_ptr = &tensor_entry_partition[0] check_wholememory_error_code(wholememory_create_tensor(&wholememory_tensor.wholememory_tensor, &wholememory_tensor.tensor_description, comm.comm_id, int(mem_type), - int(mem_location))) + int(mem_location), + partition_ptr)) return wholememory_tensor def create_wholememory_matrix(WholeMemoryDataType dtype, @@ -1707,7 +1758,8 @@ def create_wholememory_matrix(WholeMemoryDataType dtype, int64_t stride, PyWholeMemoryComm comm, WholeMemoryMemoryType mem_type, - WholeMemoryMemoryLocation mem_location): + WholeMemoryMemoryLocation mem_location, + cython.size_t[:] tensor_entry_partition=None): wholememory_tensor = PyWholeMemoryTensor() wholememory_tensor.tensor_description.dtype = int(dtype) wholememory_tensor.tensor_description.storage_offset = 0 @@ -1718,17 +1770,22 @@ def create_wholememory_matrix(WholeMemoryDataType dtype, wholememory_tensor.tensor_description.strides[1] = 1 wholememory_tensor.tensor_description.sizes[0] = row wholememory_tensor.tensor_description.sizes[1] = column + cdef size_t* partition_ptr = NULL + if tensor_entry_partition is not None and tensor_entry_partition.size > 0: + partition_ptr = &tensor_entry_partition[0] check_wholememory_error_code(wholememory_create_tensor(&wholememory_tensor.wholememory_tensor, &wholememory_tensor.tensor_description, comm.comm_id, int(mem_type), - int(mem_location))) + int(mem_location), + partition_ptr)) return wholememory_tensor def create_wholememory_tensor(PyWholeMemoryTensorDescription tensor_description, PyWholeMemoryComm comm, WholeMemoryMemoryType mem_type, - WholeMemoryMemoryLocation mem_location): + WholeMemoryMemoryLocation mem_location, + cython.size_t[:] tensor_entry_partition=None): if tensor_description.dim() != 1 and tensor_description.dim() != 2: raise NotImplementedError('WholeMemory currently only support 1D or 2D tensor') if tensor_description.stride()[tensor_description.dim() - 1] != 1: @@ -1737,11 +1794,15 @@ def create_wholememory_tensor(PyWholeMemoryTensorDescription tensor_description, raise ValueError('storage_offset be 0 when created') wholememory_tensor = PyWholeMemoryTensor() wholememory_tensor.tensor_description = tensor_description.tensor_description + cdef size_t* partition_ptr = NULL + if tensor_entry_partition is not None and tensor_entry_partition.size > 0: + partition_ptr = &tensor_entry_partition[0] check_wholememory_error_code(wholememory_create_tensor(&wholememory_tensor.wholememory_tensor, &wholememory_tensor.tensor_description, comm.comm_id, int(mem_type), - 
int(mem_location))) + int(mem_location), + partition_ptr)) return wholememory_tensor def make_tensor_as_wholememory(PyWholeMemoryTensorDescription tensor_description, diff --git a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py index 5ef072e..0f16dbe 100644 --- a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py +++ b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py @@ -12,6 +12,7 @@ # limitations under the License. import torch +import numpy as np import pylibwholegraph.binding.wholememory_binding as wmb from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack from packaging import version @@ -134,21 +135,11 @@ def copy_host_1D_tensor_to_wholememory( torch_import_from_dlpack, wmb.WholeMemoryMemoryLocation.MlDevice, world_rank ) assert local_tensor_cuda.dim() == 1 - wm_array_size = wm_array.shape[0] - - local_start_ref = min( - wmb.determine_partition_plan(wm_array_size, world_size) * world_rank, - wm_array_size, - ) - local_end = min( - wmb.determine_partition_plan(wm_array_size, world_size) * (world_rank + 1), - wm_array_size, - ) - local_count = local_end - local_start - + local_count = wm_array.get_local_entry_count() + local_start_ref = wm_array.get_local_entry_start() assert local_start == local_start_ref assert local_tensor_cuda.shape[0] == local_count - local_tensor_cuda.copy_(host_tensor[local_start:local_end]) + local_tensor_cuda.copy_(host_tensor[local_start : local_start + local_count]) wm_comm.barrier() @@ -199,5 +190,17 @@ def int_to_wholememory_type(value: int): return wmb.WholeMemoryMemoryType.MtChunked if value == 2: return wmb.WholeMemoryMemoryType.MtDistributed + if value == 3: + return wmb.WholeMemoryMemoryType.MtHierarchy else: raise ValueError("invalid int_to_wholememory_type value") + + +def random_partition(total_entry_count: int, world_size: int) -> np.array: + np.random.seed(42) + random_array = np.random.uniform(90, 100, size=world_size) + random_sum = np.sum(random_array) + partition = ((random_array / random_sum) * total_entry_count).astype(np.uintp) + diff = total_entry_count - np.sum(partition) + partition[0] += diff + return partition diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py index fc9320c..e7633b0 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py @@ -16,6 +16,7 @@ from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack +from pylibwholegraph.test_utils.test_comm import random_partition import torch import numpy as np import os @@ -56,6 +57,7 @@ def load_routine_func( embedding_stride, storage_offset, round_robin_size=0, + entry_partition=None, ): wm_comm, _ = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size @@ -78,12 +80,18 @@ def load_routine_func( + first_rank_extra_embedding_entry_count * world_size ) - per_rank_entry = wmb.determine_partition_plan(extra_embedding_count, world_size) - rank_start_entry = min(per_rank_entry * world_rank, extra_embedding_count) - rank_end_entry = min(per_rank_entry * (world_rank + 1), extra_embedding_count) - rank_entry_count = rank_end_entry - 
rank_start_entry + if entry_partition is None: + per_rank_entry = wmb.equal_partition_plan(extra_embedding_count, world_size) + rank_start_entry = min(per_rank_entry * world_rank, extra_embedding_count) + rank_end_entry = min(per_rank_entry * (world_rank + 1), extra_embedding_count) + rank_entry_count = rank_end_entry - rank_start_entry + else: + rank_start_entry = np.sum(entry_partition[:world_rank]) + rank_entry_count = entry_partition[world_rank] + rank_end_entry = rank_start_entry + rank_entry_count if round_robin_size != 0: + per_rank_entry = wmb.equal_partition_plan(extra_embedding_count, world_size) first_rank_extra_embedding_entry_count = embedding_entry_count % ( world_size * round_robin_size ) @@ -137,6 +145,7 @@ def load_routine_func( wm_comm, mt, ml, + entry_partition, ) wholememory_tensor = wholememory_root_tensor.get_sub_tensor( @@ -173,6 +182,7 @@ def load_routine_func( @pytest.mark.parametrize("embedding_stride", [16, 32, 64]) @pytest.mark.parametrize("storage_offset", [0, 3]) @pytest.mark.parametrize("round_robin_size", [256, 1024, 0]) +@pytest.mark.parametrize("partition_method", ["random", "default"]) def test_wholememory_load( file_part_count, embedding_entry_count, @@ -180,6 +190,7 @@ def test_wholememory_load( embedding_stride, storage_offset, round_robin_size, + partition_method, ): if embedding_stride < storage_offset + embedding_dim: pytest.skip( @@ -191,9 +202,19 @@ def test_wholememory_load( "Skipping due to round_robin_size!=0 and storage offset !=0 , " "the configuration is not valid." ) + if partition_method != "default" and round_robin_size != 0: + pytest.skip( + "Skipping due to round_robin_size!=0 and partition method != 'default', " + "the configuration is not valid." + ) global gpu_count if not gpu_count: gpu_count = 1 + + entry_partition = None + if partition_method == "random": + entry_partition = random_partition(embedding_entry_count, gpu_count) + extra_embedding_count = embedding_entry_count if round_robin_size != 0: first_rank_extra_embedding_entry_count = embedding_entry_count % ( @@ -231,7 +252,7 @@ def test_wholememory_load( ) if round_robin_size != 0: - entry_per_rank = wmb.determine_partition_plan(extra_embedding_count, gpu_count) + entry_per_rank = wmb.equal_partition_plan(extra_embedding_count, gpu_count) cpu_embedding_tensor_base_extra = torch.empty( (extra_embedding_count, embedding_dim), dtype=torch.int, device="cpu" @@ -262,6 +283,7 @@ def test_wholememory_load( embedding_stride=embedding_stride, storage_offset=storage_offset, round_robin_size=round_robin_size, + entry_partition=entry_partition, ) multiprocess_run(gpu_count, load_routine_func_partial) @@ -280,6 +302,7 @@ def store_routine_func( embedding_dim, embedding_stride, storage_offset, + entry_partition, ): (wm_comm, _) = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size @@ -299,6 +322,7 @@ def store_routine_func( wm_comm, mt, ml, + entry_partition, ) local_root_tensor, local_root_offset = wholememory_root_tensor.get_local_tensor( torch_import_from_dlpack, wmb.WholeMemoryMemoryLocation.MlHost, world_rank @@ -326,14 +350,24 @@ def store_routine_func( @pytest.mark.parametrize("embedding_dim", [16, 31, 33]) @pytest.mark.parametrize("embedding_stride", [16, 32, 64]) @pytest.mark.parametrize("storage_offset", [0, 3]) +@pytest.mark.parametrize("partition_method", ["random"]) def test_wholememory_store( - embedding_entry_count, embedding_dim, embedding_stride, storage_offset + embedding_entry_count, + embedding_dim, + embedding_stride, + storage_offset, + 
partition_method, ): if embedding_stride < storage_offset + embedding_dim: pytest.skip( "Skipping due to embedding_stride, " "embedding_dim and storage_offset configuration not valid." ) + + global gpu_count + entry_partition = None + if partition_method == "random": + entry_partition = random_partition(embedding_entry_count, gpu_count) file_name_prefix = "pytest_store_temp_file" store_routine_func_partial = partial( store_routine_func, @@ -342,9 +376,9 @@ def test_wholememory_store( embedding_dim=embedding_dim, embedding_stride=embedding_stride, storage_offset=storage_offset, + entry_partition=entry_partition, ) - global gpu_count multiprocess_run(gpu_count, store_routine_func_partial) embedding_entry_offset = 0 file_part_count = gpu_count diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py index c0ef740..979e090 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py @@ -14,19 +14,20 @@ import pylibwholegraph.binding.wholememory_binding as wmb from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm +from pylibwholegraph.test_utils.test_comm import random_partition # Run with: # python3 -m pytest ../tests/pylibwholegraph/test_wholememory_tensor.py -s -def array_test_case(wm_comm, dt, mt, ml, size): +def array_test_case(wm_comm, dt, mt, ml, size, entry_partition): world_rank = wm_comm.get_rank() print( "Rank=%d testing array size=%d dt=%s, mt=%s, ml=%s" % (world_rank, size, dt, mt, ml) ) - wm_array = wmb.create_wholememory_array(dt, size, wm_comm, mt, ml) + wm_array = wmb.create_wholememory_array(dt, size, wm_comm, mt, ml, entry_partition) assert wm_array.dtype == dt assert wm_array.dim() == 1 assert len(wm_array.shape) == 1 @@ -47,14 +48,14 @@ def array_test_case(wm_comm, dt, mt, ml, size): wmb.destroy_wholememory_tensor(wm_array) -def matrix_test_case(wm_comm, dt, mt, ml, mat_size): +def matrix_test_case(wm_comm, dt, mt, ml, mat_size, entry_partition): world_rank = wm_comm.get_rank() print( "Rank=%d testing matrix size=%s dt=%s, mt=%s, ml=%s" % (world_rank, mat_size, dt, mt, ml) ) wm_matrix = wmb.create_wholememory_matrix( - dt, mat_size[0], mat_size[1], -1, wm_comm, mt, ml + dt, mat_size[0], mat_size[1], -1, wm_comm, mt, ml, entry_partition ) assert wm_matrix.dtype == dt @@ -93,7 +94,8 @@ def routine_func(world_rank: int, world_size: int): single_array_size = 128 * 1024 * 1024 * world_size single_matrix_size = (1024 * 1024 * world_size, 128) dt = wmb.WholeMemoryDataType.DtFloat - + array_entry_partition = random_partition(single_array_size, world_size) + matrix_entry_partition = random_partition(single_matrix_size[0], world_size) print("") for mt in [ @@ -106,8 +108,12 @@ def routine_func(world_rank: int, world_size: int): wmb.WholeMemoryMemoryLocation.MlDevice, ]: if wm_comm.support_type_location(mt, ml): - array_test_case(wm_comm, dt, mt, ml, single_array_size) - matrix_test_case(wm_comm, dt, mt, ml, single_matrix_size) + array_test_case( + wm_comm, dt, mt, ml, single_array_size, array_entry_partition + ) + matrix_test_case( + wm_comm, dt, mt, ml, single_matrix_size, matrix_entry_partition + ) wmb.finalize() diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py 
b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py index c4c3ce5..361ae4f 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py @@ -15,12 +15,13 @@ from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack +from pylibwholegraph.test_utils.test_comm import random_partition import torch import pylibwholegraph.torch.wholememory_ops as wm_ops # PYTHONPATH=../:$PYTHONPATH python3 -m pytest \ -# ../tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py -s +# ../tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py -s def gen_int_embedding(indice_tensor, embedding_dim, output_type): @@ -46,6 +47,7 @@ def scatter_gather_test_cast( embedding_dim, indice_count, use_python_binding=True, + entry_partition=None, ): world_rank = wm_comm.get_rank() world_size = wm_comm.get_size() @@ -56,7 +58,7 @@ def scatter_gather_test_cast( f"indice_count={indice_count}, dt={dt}, mt={mt}, ml={ml}" ) wm_embedding = wmb.create_wholememory_matrix( - dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml + dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition ) scatter_indice = torch.arange( @@ -88,22 +90,17 @@ def scatter_gather_test_cast( torch_import_from_dlpack, wmb.WholeMemoryMemoryLocation.MlDevice, world_rank ) - local_ref_start = min( - wmb.determine_partition_plan(embedding_count, world_size) * world_rank, - embedding_count, - ) - local_ref_end = min( - wmb.determine_partition_plan(embedding_count, world_size) * (world_rank + 1), - embedding_count, - ) - local_ref_count = local_ref_end - local_ref_start + local_ref_start = wm_embedding.get_local_entry_start() + local_ref_count = wm_embedding.get_local_entry_count() assert local_start == local_ref_start assert local_tensor_cuda.dim() == 2 assert local_tensor_cuda.shape[0] == local_ref_count assert local_tensor_cuda.shape[1] == embedding_dim local_tensor = local_tensor_cuda.cpu() - local_indices = torch.arange(local_ref_start, local_ref_end, dtype=torch.int64) + local_indices = torch.arange( + local_ref_start, local_ref_start + local_ref_count, dtype=torch.int64 + ) local_tensor_ref = gen_int_embedding(local_indices, embedding_dim, torch.float) # print('\nlocal_tensor %s =%s\nlocal_tensor_ref %s =%s' % ( # local_tensor.shape, local_tensor, local_tensor_ref.shape, local_tensor_ref)) @@ -144,6 +141,7 @@ def routine_func(world_rank: int, world_size: int): embedding_dim = 256 indice_count = 100001 dt = wmb.WholeMemoryDataType.DtFloat + entry_partition = random_partition(embedding_count, world_size) print("") @@ -166,8 +164,8 @@ def routine_func(world_rank: int, world_size: int): embedding_dim, indice_count, True, + entry_partition, ) - wmb.finalize() diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py index 49e549d..7f3bd18 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py @@ -33,7 +33,10 @@ def add_training_options(argparser: ArgumentParser): "--embedding-memory-type", dest="embedding_memory_type", default="chunked", - help="Embedding memory type, should be: continuous, chunked or 
distributed", + help=( + "Embedding memory type, should be: " + "continuous, chunked, distributed, hierarchy" + ), ) argparser.add_argument( "--cache-type", diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py index e3ef9b7..6e992a7 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py @@ -149,6 +149,7 @@ def create_builtin_cache_policy( embedding_memory_type != "continuous" and embedding_memory_type != "chunked" and embedding_memory_type != "distributed" + and embedding_memory_type != "hierarchy" ): raise ValueError(f"embedding_memory_type={embedding_memory_type} is not valid") @@ -320,7 +321,6 @@ def add_gradients(self, indice: torch.Tensor, grad_outputs: torch.Tensor): def apply_gradients(self, lr: float): sparse_indices = torch.cat(self.sparse_indices) sparse_grads = torch.cat(self.sparse_grads) - wmb.EmbeddingGatherGradientApply( self.wmb_embedding, wrap_torch_tensor(sparse_indices), @@ -387,6 +387,7 @@ def create_embedding( sizes: List[int], *, cache_policy: Union[WholeMemoryCachePolicy, None] = None, + embedding_entry_partition: Union[List[int], None] = None, random_init: bool = False, gather_sms: int = -1, round_robin_size: int = 0, @@ -399,6 +400,10 @@ def create_embedding( :param dtype: data type :param sizes: size of the embedding, must be 2D :param cache_policy: cache policy + :param embedding_entry_partition: rank partition based on entry; + embedding_entry_partition[i] determines the entry count of rank + i and shoud be a positive integer; the sum of embedding_entry_partition + should equal to total entry count; entries will be equally partitioned if None :param gather_sms: the number of SMs used in gather process :param round_robin_size: continuous embedding size of a rank using round robin shard strategy @@ -421,6 +426,23 @@ def create_embedding( "The caching feature is not supported yet when using NVSHMEM." "Please consider disable it by passing cache_policy = None." 
         )

+    if embedding_entry_partition is not None and cache_policy is not None:
+        print("embedding_entry_partition is ignored because cache_policy is specified")
+        embedding_entry_partition = None
+    if embedding_entry_partition is not None and round_robin_size != 0:
+        print(
+            "round_robin_size is ignored because embedding_entry_partition is specified"
+        )
+        round_robin_size = 0
+    if memory_type == "hierarchy":
+        comm_backend = comm.distributed_backend
+        if comm_backend == "nvshmem":
+            raise AssertionError(
+                "Hierarchy embedding is not supported yet when using NVSHMEM.")
+        if cache_policy is not None:
+            raise AssertionError(
+                "Hierarchy embedding is not supported yet when using cache.")
+        comm_backend = "nccl"
     wm_embedding = WholeMemoryEmbedding(
         wmb.create_embedding(
@@ -429,6 +451,7 @@
             str_to_wmb_wholememory_memory_type(memory_type),
             str_to_wmb_wholememory_location(memory_location),
             wmb_cache_policy,
+            embedding_entry_partition=embedding_entry_partition,
             user_defined_sms=gather_sms,
             round_robin_size=round_robin_size,
         ),
@@ -453,6 +476,7 @@
     last_dim_size: int,
     *,
     cache_policy: Union[WholeMemoryCachePolicy, None] = None,
+    embedding_entry_partition: Union[List[int], None] = None,
     gather_sms: int = -1,
     round_robin_size: int = 0,
 ):
@@ -465,6 +489,10 @@
     :param dtype: data type
     :param last_dim_size: size of last dim
     :param cache_policy: cache policy
+    :param embedding_entry_partition: rank partition based on entry;
+        embedding_entry_partition[i] determines the entry count of rank
+        i and should be a positive integer; the sum of embedding_entry_partition
+        should equal the total entry count; entries will be equally partitioned if None
     :param gather_sms: the number of SMs used in gather process
     :param round_robin_size: continuous embedding size of a rank using round robin
         shard strategy
@@ -473,6 +501,14 @@
     if isinstance(filelist, str):
         filelist = [filelist]
     assert last_dim_size > 0
+    if embedding_entry_partition is not None and cache_policy is not None:
+        print("embedding_entry_partition is ignored because cache_policy is specified")
+        embedding_entry_partition = None
+    if embedding_entry_partition is not None and round_robin_size != 0:
+        print(
+            "round_robin_size is ignored because embedding_entry_partition is specified"
+        )
+        round_robin_size = 0
     element_size = torch.tensor([], dtype=dtype).element_size()
     file_entry_size = element_size * last_dim_size
     total_file_size = 0
@@ -492,6 +528,7 @@
         dtype,
         [total_entry_count, last_dim_size],
         cache_policy=cache_policy,
+        embedding_entry_partition=embedding_entry_partition,
         gather_sms=gather_sms,
         round_robin_size=round_robin_size,
     )
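The embedding_entry_partition argument added above takes one entry count per rank; the counts must be positive and sum to the total entry count, and passing None keeps the previous equal split. A minimal usage sketch, not part of the diff, assuming `comm` is an existing WholeMemoryCommunicator (for example from init_torch_env_and_create_wm_comm) and using placeholder sizes and memory settings:

# Sketch only: explicit per-rank partition passed to create_embedding.
import torch
from pylibwholegraph.torch.embedding import create_embedding

total_entry_count = 1_000_000
embedding_dim = 128
world_size = comm.get_size()

# One positive count per rank, summing to total_entry_count.
base = total_entry_count // world_size
partition = [base] * world_size
partition[0] += total_entry_count - base * world_size

embedding = create_embedding(
    comm,
    "distributed",   # memory type (placeholder choice)
    "cpu",           # memory location (placeholder choice)
    torch.float32,
    [total_entry_count, embedding_dim],
    embedding_entry_partition=partition,
)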
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
index f21ef87..e46ffa2 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
@@ -209,6 +209,7 @@ def create_wholememory_tensor(
     sizes: List[int],
     dtype: torch.dtype,
     strides: List[int],
+    tensor_entry_partition: Union[List[int], None] = None,
 ):
     """
     Create empty WholeMemory Tensor. Now only support dim = 1 or 2
@@ -218,6 +219,10 @@
     :param sizes: size of the tensor
     :param dtype: data type of the tensor
     :param strides: strides of the tensor
+    :param tensor_entry_partition: rank partition based on entry;
+        tensor_entry_partition[i] determines the entry count of rank
+        i and should be a positive integer; the sum of tensor_entry_partition
+        should equal the total entry count; entries will be equally partitioned if None
     :return: Allocated WholeMemoryTensor
     """
     dim = len(sizes)
@@ -240,7 +245,9 @@
     wm_location = str_to_wmb_wholememory_location(memory_location)

     return WholeMemoryTensor(
-        wmb.create_wholememory_tensor(td, comm.wmb_comm, wm_memory_type, wm_location)
+        wmb.create_wholememory_tensor(
+            td, comm.wmb_comm, wm_memory_type, wm_location, tensor_entry_partition
+        )
     )

@@ -252,6 +259,7 @@
     dtype: torch.dtype,
     last_dim_size: int = 0,
     last_dim_strides: int = -1,
+    tensor_entry_partition: Union[List[int], None] = None,
 ):
     """
     Create WholeMemory Tensor from list of binary files.
@@ -263,6 +271,10 @@
     :param last_dim_size: 0 for create 1-D array, positive value for create matrix
         column size
     :param last_dim_strides: stride of last_dim, -1 for same as size of last dim.
+    :param tensor_entry_partition: rank partition based on entry;
+        tensor_entry_partition[i] determines the entry count of rank
+        i and should be a positive integer; the sum of tensor_entry_partition
+        should equal the total entry count; entries will be equally partitioned if None
     :return: WholeMemoryTensor
     """
     if isinstance(filelist, str):
@@ -290,7 +302,13 @@
     sizes = [total_entry_count, last_dim_size]
     strides = [last_dim_strides, 1]
     wm_tensor = create_wholememory_tensor(
-        comm, memory_type, memory_location, sizes, dtype, strides
+        comm,
+        memory_type,
+        memory_location,
+        sizes,
+        dtype,
+        strides,
+        tensor_entry_partition,
     )
     wm_tensor.from_filelist(filelist)
     return wm_tensor
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/utils.py b/python/pylibwholegraph/pylibwholegraph/torch/utils.py
index bed94c8..102ae12 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/utils.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/utils.py
@@ -92,10 +92,12 @@ def str_to_wmb_wholememory_memory_type(str_wmb_type: str):
         return wmb.WholeMemoryMemoryType.MtChunked
     elif str_wmb_type == "distributed":
         return wmb.WholeMemoryMemoryType.MtDistributed
+    elif str_wmb_type == "hierarchy":
+        return wmb.WholeMemoryMemoryType.MtHierarchy
     else:
         raise ValueError(
             f"WholeMemory type {str_wmb_type} not supported,"
-            " should be (continuous, chunked, distributed)"
+            " should be (continuous, chunked, distributed, hierarchy)"
         )
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py
index 1964404..fd44f04 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py
@@ -100,10 +100,8 @@ def torch_malloc_env_fn(
     memory_context: TorchMemoryContext,
     global_context: TorchEmptyGlobalContext,
 ) -> int:
-    pinned = False
     device = None
-
     if malloc_type.get_type() == wmb.WholeMemoryMemoryAllocType.MatDevice:
         device = torch.device("cuda")
     elif malloc_type.get_type() == wmb.WholeMemoryMemoryAllocType.MatHost:
@@ -112,14 +110,10 @@
         assert malloc_type.get_type() == wmb.WholeMemoryMemoryAllocType.MatPinned
         device = torch.device("cpu")
         pinned = True
-
     shape = tensor_desc.shape
-
     dtype = wholememory_dtype_to_torch_dtype(tensor_desc.dtype)
-
     t = torch.empty(shape, dtype=dtype, device=device, pin_memory=pinned)
     memory_context.set_tensor(t)
-
     return t.data_ptr()
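The same convention applies to the tensor API above: tensor_entry_partition splits the first dimension of the tensor across ranks. A short sketch, not part of the diff, reusing the random_partition helper added to test_comm.py in this patch and assuming an existing communicator `comm`; sizes and memory settings are placeholders:

# Sketch only: custom first-dimension partition for a WholeMemory tensor.
import torch
from pylibwholegraph.test_utils.test_comm import random_partition
from pylibwholegraph.torch.tensor import create_wholememory_tensor

sizes = [1_000_000, 64]
entry_partition = [int(c) for c in random_partition(sizes[0], comm.get_size())]

wm_tensor = create_wholememory_tensor(
    comm,
    "chunked",        # memory type (placeholder choice)
    "cuda",           # memory location (placeholder choice)
    sizes,
    torch.float32,
    [sizes[1], 1],    # row-major strides: [last_dim_size, 1]
    tensor_entry_partition=entry_partition,
)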