From 5fd1a92dc2cbd96fc07f2cbc83f20f2beb50a10d Mon Sep 17 00:00:00 2001 From: Ray Douglass Date: Fri, 19 Jul 2024 15:07:52 -0400 Subject: [PATCH 01/17] DOC v24.10 Updates [skip ci] --- .github/workflows/build.yaml | 12 ++++++------ .github/workflows/pr.yaml | 18 +++++++++--------- .github/workflows/test.yaml | 6 +++--- VERSION | 2 +- ci/build_docs.sh | 2 +- .../environments/all_cuda-118_arch-x86_64.yaml | 4 ++-- .../environments/all_cuda-125_arch-x86_64.yaml | 4 ++-- cpp/CMakeLists.txt | 2 +- cpp/Doxyfile | 2 +- dependencies.yaml | 4 ++-- fetch_rapids.cmake | 2 +- python/pylibwholegraph/CMakeLists.txt | 2 +- 12 files changed, 30 insertions(+), 30 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index d09ba5a4d..454185f4a 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -38,7 +38,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -49,7 +49,7 @@ jobs: if: github.ref_type == 'branch' needs: [python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 with: arch: "amd64" branch: ${{ inputs.branch }} @@ -62,7 +62,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -70,7 +70,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -80,7 +80,7 @@ jobs: wheel-publish-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index b48246626..6e5c86c54 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -21,41 +21,41 @@ jobs: - wheel-build-pylibwholegraph - wheel-test-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.10 checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.10 with: enable_check_generated_files: false conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.10 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.10 with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.10 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 with: build_type: pull-request arch: "amd64" @@ -64,14 +64,14 @@ jobs: wheel-build-pylibwholegraph: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.10 with: build_type: pull-request script: ci/build_wheel.sh wheel-test-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.10 with: build_type: pull-request script: ci/test_wheel.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index eb258f5ae..f2d7e1cc7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.10 with: build_type: nightly branch: ${{ inputs.branch }} @@ -24,7 +24,7 @@ jobs: sha: ${{ inputs.sha }} conda-pytorch-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 with: build_type: nightly branch: ${{ inputs.branch }} @@ -32,7 +32,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.08 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.10 with: build_type: nightly branch: ${{ inputs.branch }} diff --git a/VERSION b/VERSION index ec8489fda..7c7ba0443 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -24.08.00 +24.10.00 diff --git a/ci/build_docs.sh b/ci/build_docs.sh index f4ad95f30..7f1074feb 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -22,7 +22,7 @@ rapids-print-env rapids-logger "Downloading artifacts from previous jobs" CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) -export RAPIDS_VERSION_NUMBER="24.08" +export RAPIDS_VERSION_NUMBER="24.10" export RAPIDS_DOCS_DIR="$(mktemp -d)" rapids-mamba-retry install \ diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 9bb060599..100d086bc 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -24,8 +24,8 @@ dependencies: - graphviz - ipykernel - ipython -- libraft-headers==24.8.*,>=0.0.0a0 -- librmm==24.8.*,>=0.0.0a0 +- libraft-headers==24.10.*,>=0.0.0a0 +- librmm==24.10.*,>=0.0.0a0 - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 41807e89b..d8e7993c4 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -25,8 +25,8 @@ dependencies: - graphviz - ipykernel - ipython -- libraft-headers==24.8.*,>=0.0.0a0 -- librmm==24.8.*,>=0.0.0a0 +- libraft-headers==24.10.*,>=0.0.0a0 +- librmm==24.10.*,>=0.0.0a0 - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9c364b0f6..5931b8ca2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -14,7 +14,7 @@ # limitations under the License. #============================================================================= -set(RAPIDS_VERSION "24.08") +set(RAPIDS_VERSION "24.10") set(WHOLEGRAPH_VERSION "${RAPIDS_VERSION}.00") cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) diff --git a/cpp/Doxyfile b/cpp/Doxyfile index 58cef9e82..b14c55207 100644 --- a/cpp/Doxyfile +++ b/cpp/Doxyfile @@ -38,7 +38,7 @@ PROJECT_NAME = "WholeGraph C API" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 24.08 +PROJECT_NUMBER = 24.10 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/dependencies.yaml b/dependencies.yaml index 6bb940eac..c6f5b93b4 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -84,8 +84,8 @@ dependencies: - cxx-compiler - cython>=3.0.0 - &doxygen doxygen==1.9.1 - - libraft-headers==24.8.*,>=0.0.0a0 - - librmm==24.8.*,>=0.0.0a0 + - libraft-headers==24.10.*,>=0.0.0a0 + - librmm==24.10.*,>=0.0.0a0 - nanobind>=0.2.0 - nccl specific: diff --git a/fetch_rapids.cmake b/fetch_rapids.cmake index bfaae09f8..669811c53 100644 --- a/fetch_rapids.cmake +++ b/fetch_rapids.cmake @@ -12,7 +12,7 @@ # the License. # ============================================================================= if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/CUGRAPH_RAPIDS.cmake) - file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-24.08/RAPIDS.cmake + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-24.10/RAPIDS.cmake ${CMAKE_CURRENT_BINARY_DIR}/CUGRAPH_RAPIDS.cmake ) endif() diff --git a/python/pylibwholegraph/CMakeLists.txt b/python/pylibwholegraph/CMakeLists.txt index d6e6df9da..c49180b1c 100644 --- a/python/pylibwholegraph/CMakeLists.txt +++ b/python/pylibwholegraph/CMakeLists.txt @@ -16,7 +16,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -set(RAPIDS_VERSION "24.08") +set(RAPIDS_VERSION "24.10") set(WHOLEGRAPH_VERSION "${RAPIDS_VERSION}.00") include(FetchContent) From 9c47d4e91dcfda3f27fea9c7345c2e2e85f03c74 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Thu, 1 Aug 2024 07:12:13 -0500 Subject: [PATCH 02/17] Remove Dockerfile (#184) It looks like the `Dockerfile` in this repo is fairly old (PyTorch 22.10). I don't know if it is useful -- we have largely deleted Dockerfiles in each RAPIDS repo now that we have devcontainers. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/184 --- Dockerfile | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 8a9046567..000000000 --- a/Dockerfile +++ /dev/null @@ -1,16 +0,0 @@ -FROM nvcr.io/nvidia/pytorch:22.10-py3 - -RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y lsb-core software-properties-common wget libspdlog-dev - -#RUN remove old cmake to update -RUN conda remove --force -y cmake -RUN rm -rf /usr/local/bin/cmake && rm -rf /usr/local/lib/cmake && rm -rf /usr/lib/cmake - -RUN apt-key adv --fetch-keys https://apt.kitware.com/keys/kitware-archive-latest.asc && \ - export LSB_CODENAME=$(lsb_release -cs) && \ - apt-add-repository -y "deb https://apt.kitware.com/ubuntu/ ${LSB_CODENAME} main" && \ - apt update && apt install -y cmake - -# update py for pytest -RUN pip3 install -U py -RUN pip3 install Cython setuputils3 scikit-build nanobind pytest-forked pytest From 91b7dcd92fcbfd2d64fad073a3c654aa6281ab45 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 2 Aug 2024 22:07:02 +0800 Subject: [PATCH 03/17] support different entry size for different ranks (#194) Allow users to specify the entry size on each rank. node_feat_wm_embedding = wgth.create_embedding( ... embedding_entry_partition=[283071, 401722, 356680, 329221, 238065, 238060, 217897, 384313] ) 1. embedding_entry_partition[i] indicates the number of embedding entries stored on the rank i. 2. If embedding_entry_partition is None, embedding will be partitioned equally. 3. Only chunked device and distributed host/device are supported. Authors: - https://github.com/zhuofan1123 Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/194 --- cpp/bench/CMakeLists.txt | 6 +- cpp/bench/common/wholegraph_benchmark.cpp | 19 +- cpp/bench/common/wholegraph_benchmark.hpp | 4 +- .../wholememory_ops/gather_scatter_bench.cu | 75 ++- cpp/include/wholememory/device_reference.cuh | 35 +- cpp/include/wholememory/embedding.h | 7 +- cpp/include/wholememory/global_reference.h | 10 +- cpp/include/wholememory/wholememory.h | 67 ++- cpp/include/wholememory/wholememory_tensor.h | 41 +- cpp/src/wholememory/embedding.cpp | 105 ++++- cpp/src/wholememory/embedding.hpp | 3 +- cpp/src/wholememory/memory_handle.cpp | 443 +++++++++++++----- cpp/src/wholememory/memory_handle.hpp | 25 +- cpp/src/wholememory/wholememory.cpp | 49 +- cpp/src/wholememory/wholememory_tensor.cpp | 126 ++++- .../functions/bucket_ids_func.cu | 52 +- .../functions/bucket_ids_func.h | 4 +- .../functions/embedding_cache_func.cu | 29 +- .../functions/embedding_cache_func.h | 6 +- .../functions/exchange_ids_nccl_func.cu | 4 +- .../functions/exchange_ids_nccl_func.h | 6 +- .../functions/gather_scatter_func.cuh | 35 +- .../functions/map_indices_func.cu | 14 +- .../functions/nvshmem_device_reference.cuh | 88 +++- ...r_func_impl_floating_data_int32_indices.cu | 33 +- ...r_func_impl_floating_data_int64_indices.cu | 33 +- ...er_func_impl_integer_data_int32_indices.cu | 33 +- ...er_func_impl_integer_data_int64_indices.cu | 35 +- .../functions/nvshmem_gather_scatter_func.cuh | 82 ++-- ...r_func_impl_floating_data_int32_indices.cu | 33 +- ...r_func_impl_floating_data_int64_indices.cu | 33 +- ...er_func_impl_integer_data_int32_indices.cu | 33 +- ...er_func_impl_integer_data_int64_indices.cu | 33 +- .../wholememory_ops/gather_op_impl_nccl.cu | 45 +- .../wholememory_ops/gather_op_impl_nvshmem.cu | 59 ++- .../scatter_op_impl.nvshmem.cu | 62 ++- .../wholememory_ops/scatter_op_impl_nccl.cu | 43 +- .../wholememory_ops/embedding_test_utils.cu | 19 +- .../wholememory_ops/embedding_test_utils.hpp | 4 +- ...lememory_embedding_gradient_apply_tests.cu | 57 ++- .../wholememory_embedding_tests.cu | 33 +- .../wholememory_gather_tests.cu | 27 +- .../wholememory_scatter_tests.cu | 29 +- .../binding/wholememory_binding.pyx | 109 +++-- .../pylibwholegraph/test_utils/test_comm.py | 27 +- .../pylibwholegraph/test_wholememory_io.py | 43 +- .../test_wholememory_tensor.py | 18 +- .../ops/test_wholegraph_gather_scatter.py | 22 +- .../pylibwholegraph/torch/embedding.py | 31 +- .../pylibwholegraph/torch/tensor.py | 12 +- 50 files changed, 1555 insertions(+), 686 deletions(-) diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index ee9936adf..7736c04be 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,7 +35,9 @@ function(ConfigureBench) rmm::rmm pthread ) - + if(BUILD_WITH_NVSHMEM) + target_compile_definitions(${BENCH_NAME} PRIVATE WITH_NVSHMEM_SUPPORT) + endif() set_target_properties( ${BENCH_NAME} PROPERTIES # set target compile options diff --git a/cpp/bench/common/wholegraph_benchmark.cpp b/cpp/bench/common/wholegraph_benchmark.cpp index 4471626c5..0be685529 100644 --- a/cpp/bench/common/wholegraph_benchmark.cpp +++ b/cpp/bench/common/wholegraph_benchmark.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,6 +52,23 @@ void host_random_init_integer_indices(void* indices, } } +void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count) +{ + std::default_random_engine random_engine(0); + std::uniform_int_distribution uniform(90, 100); + size_t acc_size = 0; + size_t random_sum = 0; + for (int i = 0; i < partition_count; i++) { + partition_sizes[i] = (size_t)uniform(random_engine); + random_sum += partition_sizes[i]; + } + for (int i = 0; i < partition_count; i++) { + partition_sizes[i] = (size_t)((partition_sizes[i] / (double)random_sum) * total_size); + acc_size += partition_sizes[i]; + } + partition_sizes[0] += total_size - acc_size; +} + void MultiProcessMeasurePerformance(std::function run_fn, wholememory_comm_t& wm_comm, const PerformanceMeter& meter, diff --git a/cpp/bench/common/wholegraph_benchmark.hpp b/cpp/bench/common/wholegraph_benchmark.hpp index 7ac85ba59..a3af9b1c7 100644 --- a/cpp/bench/common/wholegraph_benchmark.hpp +++ b/cpp/bench/common/wholegraph_benchmark.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,8 @@ void host_random_init_integer_indices(void* indices, wholememory_array_description_t indices_desc, int64_t max_indices); +void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count); + struct Metric { Metric(const std::string& metrics_name, const std::string& metrics_unit, diff --git a/cpp/bench/wholememory_ops/gather_scatter_bench.cu b/cpp/bench/wholememory_ops/gather_scatter_bench.cu index 09da2740a..b081048fb 100644 --- a/cpp/bench/wholememory_ops/gather_scatter_bench.cu +++ b/cpp/bench/wholememory_ops/gather_scatter_bench.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -77,6 +77,8 @@ typedef struct GatherScatterBenchParam { int64_t get_embedding_dim() const { return embedding_dim; } wholememory_dtype_t get_embedding_type() const { return embedding_type; } + int get_partition_method() const { return partition_method; } + std::string get_distributed_backend() const { return distributed_backend; } GatherScatterBenchParam& set_memory_type(wholememory_memory_type_t new_memory_type) { @@ -153,6 +155,18 @@ typedef struct GatherScatterBenchParam { return *this; } + GatherScatterBenchParam& set_partition_method(int new_partition_method) + { + partition_method = new_partition_method; + return *this; + } + + GatherScatterBenchParam& set_distributed_backend(std::string new_distributed_backend) + { + distributed_backend = new_distributed_backend; + return *this; + } + private: int64_t get_embedding_entry_count() const { @@ -196,6 +210,8 @@ typedef struct GatherScatterBenchParam { int64_t embedding_dim = 32; int loop_count = 20; std::string test_type = "gather"; // gather or scatter + int partition_method = 0; + std::string distributed_backend = "nccl"; // nccl or nvshmem std::string server_addr = "localhost"; int server_port = 24987; @@ -256,7 +272,15 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) wholememory_comm_t wm_comm = create_communicator_by_socket(side_band_communicator, world_rank, world_size); - + std::string distributed_backend = params.get_distributed_backend(); +#ifdef WITH_NVSHMEM_SUPPORT + if (distributed_backend.compare("nvshmem") == 0) + WHOLEMEMORY_CHECK_NOTHROW(wholememory_communicator_set_distributed_backend( + wm_comm, WHOLEMEMORY_DB_NVSHMEM) == WHOLEMEMORY_SUCCESS); +#else + distributed_backend = "nccl"; + params.set_distributed_backend("nccl"); +#endif ShutDownSidebandCommunicator(side_band_communicator); auto embedding_desc = params.get_embedding_desc(); @@ -268,12 +292,17 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) wholememory_tensor_t embedding_tensor; wholememory_tensor_description_t embedding_tensor_desc; wholememory_copy_matrix_desc_to_tensor(&embedding_tensor_desc, &embedding_desc); + std::vector rank_partition(world_size); + wholegraph::bench::host_random_partition( + rank_partition.data(), embedding_tensor_desc.sizes[0], world_size); WHOLEMEMORY_CHECK_NOTHROW(wholememory_create_tensor(&embedding_tensor, &embedding_tensor_desc, wm_comm, params.get_memory_type(), - params.get_memory_location()) == - WHOLEMEMORY_SUCCESS); + params.get_memory_location(), + params.get_partition_method() == 1 + ? rank_partition.data() + : nullptr) == WHOLEMEMORY_SUCCESS); cudaStream_t stream; WM_CUDA_CHECK_NO_THROW(cudaStreamCreate(&stream)); @@ -318,8 +347,8 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) double gather_size_mb = (double)params.get_gather_size() / 1024.0 / 1024.0; if (local_rank == 0) { printf( - "%s, world_size=%d, memoryType=%s, memoryLocation=%s, elt_size=%ld, embeddingDim=%ld, " - "embeddingTableSize=%.2lf MB, gatherSize=%.2lf MB\n", + "%s, worldSize=%d, memoryType=%s, memoryLocation=%s, eltSize=%ld, embeddingDim=%ld, " + "embeddingTableSize=%.2lf MB, gatherSize=%.2lf MB, distributedBackend=%s\n", test_type.c_str(), world_size, get_memory_type_string(params.get_memory_type()).c_str(), @@ -327,7 +356,8 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) wholememory_dtype_get_element_size(params.get_embedding_type()), params.get_embedding_dim(), emb_size_mb, - gather_size_mb); + gather_size_mb, + distributed_backend.c_str()); } PerformanceMeter meter; @@ -388,7 +418,7 @@ void gather_scatter_benchmark(GatherScatterBenchParam& params) int main(int argc, char** argv) { wholegraph::bench::gather_scatter::GatherScatterBenchParam params; - const char* optstr = "ht:l:e:g:d:c:f:a:p:r:s:n:"; + const char* optstr = "ht:l:e:g:d:c:f:a:p:r:s:n:m:b:"; struct option opts[] = { {"help", no_argument, NULL, 'h'}, {"memory_type", @@ -405,8 +435,9 @@ int main(int argc, char** argv) {"node_size", required_argument, NULL, 's'}, // node_size {"num_gpu", required_argument, NULL, 'n'}, // num gpu per node {"server_addr", required_argument, NULL, 'a'}, // server_addr - {"server_port", required_argument, NULL, 'p'} // server_port - }; + {"server_port", required_argument, NULL, 'p'}, // server_port + {"partition_method", required_argument, NULL, 'm'}, + {"distributed_backend", required_argument, NULL, 'b'}}; const char* usage = "Usage: %s [options]\n" @@ -424,7 +455,9 @@ int main(int argc, char** argv) " -s, --node_size node_size or process count\n" " -n, --num_gpu num_gpu per process\n" " -a, --server_addr specify sideband server address\n" - " -p, --server_port specify sideband server port\n"; + " -p, --server_port specify sideband server port\n" + " -m, --partition_method specify rank partition method, 0: Default, 1: Random\n" + " -b, --distributed_backend specify distributed backend: nccl or nvshmem\n"; int c; bool has_option = false; @@ -536,6 +569,26 @@ int main(int argc, char** argv) } params.set_num_gpu(val); break; + case 'm': + val = std::atoi(optarg); + if (val != 0 && val != 1) { + printf("Invalid argument for option -m\n"); + printf(usage, argv[0]); + exit(EXIT_FAILURE); + } + params.set_partition_method(val); + break; + case 'b': + if (strcmp(optarg, "nccl") == 0) { + params.set_distributed_backend("nccl"); + } else if (strcmp(optarg, "nvshmem") == 0) { + params.set_distributed_backend("nvshmem"); + } else { + printf("Invalid argument for option -b\n"); + printf(usage, argv[0]); + exit(EXIT_FAILURE); + } + break; default: printf("Invalid or unrecognized option\n"); printf(usage, argv[0]); diff --git a/cpp/include/wholememory/device_reference.cuh b/cpp/include/wholememory/device_reference.cuh index 87f8021ed..8f2146ae9 100644 --- a/cpp/include/wholememory/device_reference.cuh +++ b/cpp/include/wholememory/device_reference.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,48 @@ class device_reference { public: __device__ __forceinline__ explicit device_reference(const wholememory_gref_t& gref) : pointer_(static_cast(gref.pointer)), - typed_stride_(gref.stride / sizeof(DataTypeT)) + typed_stride_(gref.stride / sizeof(DataTypeT)), + world_size_(gref.world_size), + same_chunk_(gref.same_chunk) { assert(gref.stride % sizeof(DataTypeT) == 0); + if (typed_stride_ != 0 && !same_chunk_) { + assert(world_size_ <= 8); // intra-node WHOLEMEMORY_MT_CHUNKED + for (int i = 0; i < world_size_ + 1; i++) { + assert(gref.rank_memory_offsets[i] % sizeof(DataTypeT) == 0); + typed_rank_mem_offsets_[i] = gref.rank_memory_offsets[i] / sizeof(DataTypeT); + } + } } __device__ device_reference() = delete; __device__ __forceinline__ DataTypeT& operator[](size_t index) { if (typed_stride_ == 0) { return pointer_[index]; } - size_t rank = index / typed_stride_; - return static_cast( - static_cast(pointer_))[rank][index - rank * typed_stride_]; + if (same_chunk_) { + size_t rank = index / typed_stride_; + return static_cast( + static_cast(pointer_))[rank][index - rank * typed_stride_]; + } else { + size_t rank = 0; + for (int i = 1; i < world_size_ + 1; i++) { + if (index < typed_rank_mem_offsets_[i]) { + rank = i - 1; + break; + } + } + return static_cast( + static_cast(pointer_))[rank][index - typed_rank_mem_offsets_[rank]]; + } } private: DataTypeT* pointer_; + int world_size_; size_t typed_stride_; + + bool same_chunk_; + size_t typed_rank_mem_offsets_[8 + 1]; }; } // namespace wholememory diff --git a/cpp/include/wholememory/embedding.h b/cpp/include/wholememory/embedding.h index 08cd73e84..1853742f5 100644 --- a/cpp/include/wholememory/embedding.h +++ b/cpp/include/wholememory/embedding.h @@ -129,6 +129,8 @@ wholememory_error_code_t wholememory_destroy_embedding_cache_policy( * @param memory_type : Memory Type of the underlying WholeMemory * @param memory_location : Memory Location of the underlying WholeMemory * @param cache_policy : Cache policy for this embedding, if don't use cache, use nullptr + * @param embedding_entry_partition: Embedding entry count of each rank, the length must be + * world_size * @param user_defined_sms : User-defined sms number for raw embedding gather/scatter * @param round_robin_size : continuous embedding size in each rank under round-robin shard mode * @return : wholememory_error_code_t @@ -140,8 +142,9 @@ wholememory_error_code_t wholememory_create_embedding( wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, wholememory_embedding_cache_policy_t cache_policy, - int user_defined_sms = -1, - int round_robin_size = 0); + size_t* embedding_entry_partition = nullptr, + int user_defined_sms = -1, + int round_robin_size = 0); /** * Destroy WholeMemory Embedding diff --git a/cpp/include/wholememory/global_reference.h b/cpp/include/wholememory/global_reference.h index 531ad25e7..7f174756e 100644 --- a/cpp/include/wholememory/global_reference.h +++ b/cpp/include/wholememory/global_reference.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,14 +24,18 @@ extern "C" { /** * @brief Global reference of a WholeMemory object * - * A global reference is for Continuous of Chunked WholeMemory Type, in these types, each rank can + * A global reference is for Continuous or Chunked WholeMemory Type, in these types, each rank can * directly access all memory from all ranks. The global reference is used to do this direct access. */ struct wholememory_gref_t { void* pointer; /*!< pointer to data for CONTINUOUS WholeMemory or pointer to data pointer array for CHUNKED WholeMemory */ + size_t* + rank_memory_offsets; /*!< memory offset of each rank, and the length must be world_size+1 */ + int world_size; size_t stride; /*!< must be 0 for CONTINUOUS WholeMemory or memory size in byte for each pointer */ + bool same_chunk; /*!< if true, rank can be got by offset/stride */ }; /** @@ -43,9 +47,11 @@ wholememory_gref_t wholememory_create_continuous_global_reference(void* ptr); struct wholememory_nvshmem_ref_t { void* pointer; + size_t* rank_memory_offsets; size_t stride; int world_rank; int world_size; + bool same_chunk; }; #ifdef __cplusplus diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index f6bacccb3..a1678ee8b 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -238,6 +238,7 @@ typedef struct wholememory_handle_* wholememory_handle_t; * @param memory_type : WholeMemory type * @param memory_location : memory location, host or device * @param data_granularity : granularity size of data, which is guaranteed not to be partitioned. + * @param rank_entry_partition : entry count of each rank (size of entry equal to data_granularity) * @return : wholememory_error_code_t */ wholememory_error_code_t wholememory_malloc(wholememory_handle_t* wholememory_handle_ptr, @@ -245,7 +246,8 @@ wholememory_error_code_t wholememory_malloc(wholememory_handle_t* wholememory_ha wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity); + size_t data_granularity, + size_t* rank_entry_partition = nullptr); /** * Free allocated WholeMemory Handle @@ -309,6 +311,24 @@ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr, size_t* local_offset, wholememory_handle_t wholememory_handle); +/** + * Get local memory size from WholeMemory Handle of current rank + * @param local_size : returned local memory size + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_get_local_size(size_t* local_size, + wholememory_handle_t wholememory_handle); + +/** + * Get local memory offset from WholeMemory Handle of current rank + * @param local_offset : returned local memory offset + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_get_local_offset(size_t* local_offset, + wholememory_handle_t wholememory_handle); + /** * Get local memory of specified rank from WholeMemory Handle * @param rank_memory_ptr : returned local memory pointer of specified rank @@ -324,6 +344,17 @@ wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, int rank, wholememory_handle_t wholememory_handle); +/** + * Get the equal partition plan WholeMemory uses by default + * @param entry_per_rank : returned entry count per rank + * @param total_entry_count : total entry count + * @param world_size : communicator world size + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_equal_entry_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size); + /** * Get global memory pointer from WholeMemory Handle. * Only Continuous memory type or Chunked Host memory has global pointer. @@ -345,38 +376,22 @@ wholememory_error_code_t wholememory_get_global_reference(wholememory_gref_t* wh wholememory_handle_t wholememory_handle); /** - * Get the partition plan WholeMemory will use - * @param size_per_rank : returned size per rank - * @param total_size : total size - * @param data_granularity : data granularity - * @param world_size : communicator world size - * @return : wholememory_error_code_t - */ -wholememory_error_code_t wholememory_determine_partition_plan(size_t* size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size); - -/** - * Get the partition plan WholeMemory will use based on entry count. - * Entry is number of data granularity - * @param entry_per_rank : returned entry count per rank - * @param total_entry_count : total entry count - * @param world_size : communicator world size + * Get memory size of each rank from WholeMemory Handle + * @param rank_mem_sizes : returned memory size of each rank + * @param wholememory_handle : WholeMemory Handle * @return : wholememory_error_code_t */ -wholememory_error_code_t wholememory_determine_entry_partition_plan(size_t* entry_per_rank, - size_t total_entry_count, - int world_size); +wholememory_error_code_t wholememory_get_rank_partition_sizes( + size_t* rank_mem_sizes, wholememory_handle_t wholememory_handle); /** - * Get the partition plan used in WholeMemory Handle - * @param size_per_rank : returned size per rank + * Get memory offset of each rank from WholeMemory Handle + * @param rank_mem_offsets : returned memory offset of each rank * @param wholememory_handle : WholeMemory Handle * @return : wholememory_error_code_t */ -wholememory_error_code_t wholememory_get_partition_plan(size_t* size_per_rank, - wholememory_handle_t wholememory_handle); +wholememory_error_code_t wholememory_get_rank_partition_offsets( + size_t* rank_mem_offsets, wholememory_handle_t wholememory_handle); /** * Fork a new process and get device count. Should be called before other CUDA call diff --git a/cpp/include/wholememory/wholememory_tensor.h b/cpp/include/wholememory/wholememory_tensor.h index ff7f65c28..9acd3e5bb 100644 --- a/cpp/include/wholememory/wholememory_tensor.h +++ b/cpp/include/wholememory/wholememory_tensor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ typedef struct wholememory_tensor_* wholememory_tensor_t; * @param comm : WholeMemory Communicator * @param memory_type : Memory Type of the underlying WholeMemory * @param memory_location : Memory Location of the underlying WholeMemory + * @param tensor_entry_partition : Tensor entry count of each rank, the length must be world_size. * @return : wholememory_error_code_t */ wholememory_error_code_t wholememory_create_tensor( @@ -44,7 +45,8 @@ wholememory_error_code_t wholememory_create_tensor( wholememory_tensor_description_t* tensor_description, wholememory_comm_t comm, wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location); + wholememory_memory_location_t memory_location, + size_t* tensor_entry_partition = nullptr); /** * Destroy WholeMemory Tensor @@ -131,11 +133,40 @@ wholememory_error_code_t wholememory_tensor_map_local_tensor( void* wholememory_tensor_get_data_pointer(wholememory_tensor_t wholememory_tensor); /** - * Get entry count per rank of a WholeMemory Tensor + * Get entry offset of each rank from WholeMemory Tensor + * @param entry_offsets : returned entry offset of each rank * @param wholememory_tensor : WholeMemory Tensor - * @return : entry count per rank + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_tensor_get_entry_offsets( + size_t* entry_offsets, wholememory_tensor_t wholememory_tensor); + +/** + * Get entry count of each rank from WholeMemory Tensor + * @param entry_partition : returned entry count of each rank + * @param wholememory_tensor : WholeMemory Tensor + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_tensor_get_entry_partition_sizes( + size_t* entry_partition, wholememory_tensor_t wholememory_tensor); + +/** + * Get entry count of current rank from WholeMemory Tensor + * @param local_entry_count : returned entry count of current rank + * @param wholememory_tensor : WholeMemory Tensor + * @return : wholememory_error_code_t + */ +wholememory_error_code_t wholememory_tensor_get_local_entry_count( + size_t* local_entry_count, wholememory_tensor_t wholememory_tensor); + +/** + * Get entry start of current rank from WholeMemory Tensor + * @param local_entry_start : returned entry start id of current rank + * @param wholememory_tensor : WholeMemory Tensor + * @return : wholememory_error_code_t */ -size_t wholememory_tensor_get_entry_per_partition(wholememory_tensor_t wholememory_tensor); +wholememory_error_code_t wholememory_tensor_get_local_entry_start( + size_t* local_entry_start, wholememory_tensor_t wholememory_tensor); /** * Get sub tensor of a WholeMemory Tensor diff --git a/cpp/src/wholememory/embedding.cpp b/cpp/src/wholememory/embedding.cpp index f1a868a84..23e9ccb53 100644 --- a/cpp/src/wholememory/embedding.cpp +++ b/cpp/src/wholememory/embedding.cpp @@ -89,7 +89,8 @@ wholememory_error_code_t embedding_base::allocate( wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - wholememory_embedding_cache_policy_t policy) noexcept + wholememory_embedding_cache_policy_t policy, + size_t* embedding_entry_partition) noexcept { cache_policy = policy; raw_embedding_comm_ = comm; @@ -109,6 +110,7 @@ wholememory_error_code_t embedding_base::allocate( comm, memory_type, memory_location)); + embedding_entry_partition = nullptr; } else { wholememory_copy_matrix_desc_to_tensor(&padded_embedding_tensor_description, embedding_description); @@ -123,7 +125,8 @@ wholememory_error_code_t embedding_base::allocate( &padded_embedding_tensor_description, comm, memory_type, - memory_location)); + memory_location, + embedding_entry_partition)); int64_t starts[2] = {0, 0}; int64_t ends[2] = {embedding_description->sizes[0], embedding_description->sizes[1]}; WHOLEMEMORY_RETURN_ON_FAIL( @@ -155,8 +158,6 @@ wholememory_error_code_t embedding_base::gather_gradient_apply(wholememory_tenso host_rank_id_count_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_recv_indices_buffer_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_raw_indice_handle(p_env_fns); - size_t const embedding_entry_count_per_rank = - wholememory_tensor_get_entry_per_partition(allocated_embedding); wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); int world_size = -1, world_rank = -1; int64_t* host_recv_rank_id_count_ptr = nullptr; @@ -174,6 +175,21 @@ wholememory_error_code_t embedding_base::gather_gradient_apply(wholememory_tenso wholememory_array_description_t indice_array_desc; WHOLEMEMORY_CHECK_NOTHROW( wholememory_convert_tensor_desc_to_array(&indice_array_desc, indice_desc)); + + wholememory_ops::temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + wholememory_ops::temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_entry_offsets(host_embedding_entry_offsets_ptr, allocated_embedding)); + WM_CUDA_CHECK_NO_THROW(cudaMemcpy(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice)); + WHOLEMEMORY_RETURN_ON_FAIL( wholememory_ops::bucket_and_exchange_ids_func(wholememory_tensor_get_data_pointer(indices), indice_array_desc, @@ -181,7 +197,7 @@ wholememory_error_code_t embedding_base::gather_gradient_apply(wholememory_tenso host_rank_id_count_ptr, &dev_recv_indices_buffer_handle, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, raw_embedding_comm_, &thrust_allocator, p_env_fns, @@ -308,9 +324,9 @@ wholememory_error_code_t embedding_base::gather_gradient_apply(wholememory_tenso wholememory_error_code_t embedding_base::create_optimizer_states() noexcept { + wholememory_handle_t wm_handle = wholememory_tensor_get_memory_handle(allocated_embedding); wholememory_comm_t wm_raw_comm; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator( - &wm_raw_comm, wholememory_tensor_get_memory_handle(allocated_embedding))); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_raw_comm, wm_handle)); int world_rank, world_size; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&world_rank, wm_raw_comm)); @@ -321,12 +337,16 @@ wholememory_error_code_t embedding_base::create_optimizer_states() noexcept int64_t start[2] = {0, 0}; int64_t end[2] = {user_tensor_desc->sizes[1], -1}; - size_t entry_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_determine_entry_partition_plan( - &entry_per_rank, allocated_tensor_desc->sizes[0], world_size)); + std::vector allocated_tensor_entry_partition(world_size); + std::vector user_tensor_entry_partition(world_size); + wholememory_tensor_get_entry_partition_sizes(allocated_tensor_entry_partition.data(), + allocated_embedding); + wholememory_tensor_get_entry_partition_sizes(user_tensor_entry_partition.data(), user_embedding); optimizer_state_ = std::make_unique(); - optimizer_state_->local_start_index = entry_per_rank * world_rank; + optimizer_state_->local_start_index = 0; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_local_entry_start( + (size_t*)&optimizer_state_->local_start_index, allocated_embedding)); optimizer_impl_base_->create_optimizer_states(optimizer_state_.get(), user_tensor_desc->sizes[1]); bool const need_cachable_states = !optimizer_state_->cachable_states.empty(); wholememory_tensor_description_t cachable_state_desc; @@ -363,7 +383,8 @@ wholememory_error_code_t embedding_base::create_optimizer_states() noexcept raw_embedding_comm_, memory_type, memory_location, - cache_policy)); + cache_policy, + user_tensor_entry_partition.data())); optimizer_state_->global_cachable_raw_user_tensor = wholememory_embedding_get_embedding_tensor(optimizer_state_->cachable_state_embedding); @@ -391,7 +412,8 @@ wholememory_error_code_t embedding_base::create_optimizer_states() noexcept &uc_desc, wm_raw_comm, WHOLEMEMORY_MT_DISTRIBUTED, - WHOLEMEMORY_ML_DEVICE)); + WHOLEMEMORY_ML_DEVICE, + allocated_tensor_entry_partition.data())); start[0] = 0; start[1] = 0; end[0] = user_tensor_desc->sizes[0]; @@ -564,8 +586,6 @@ wholememory_error_code_t device_cached_host_embedding::gather(wholememory_tensor host_rank_id_count_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_recv_indices_buffer_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_raw_indice_handle(p_env_fns); - size_t const embedding_entry_count_per_rank = - wholememory_tensor_get_entry_per_partition(allocated_embedding); wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); int world_size = -1, world_rank = -1; int64_t* host_recv_rank_id_count_ptr = nullptr; @@ -584,6 +604,20 @@ wholememory_error_code_t device_cached_host_embedding::gather(wholememory_tensor wholememory_array_description_t indice_array_desc; WHOLEMEMORY_CHECK_NOTHROW( wholememory_convert_tensor_desc_to_array(&indice_array_desc, indice_desc)); + + wholememory_ops::temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + wholememory_ops::temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_entry_offsets(host_embedding_entry_offsets_ptr, allocated_embedding)); + WM_CUDA_CHECK_NO_THROW(cudaMemcpy(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice)); WHOLEMEMORY_RETURN_ON_FAIL( wholememory_ops::bucket_and_exchange_ids_func(wholememory_tensor_get_data_pointer(indices), indice_array_desc, @@ -591,7 +625,7 @@ wholememory_error_code_t device_cached_host_embedding::gather(wholememory_tensor host_rank_id_count_ptr, &dev_recv_indices_buffer_handle, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, raw_embedding_comm_, &thrust_allocator, p_env_fns, @@ -642,8 +676,10 @@ wholememory_error_code_t device_cached_host_embedding::gather(wholememory_tensor wholememory_gref_t cache_line_tag_gref; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_global_reference( cache_ptr_->get_cache_local_data()->cache_line_tag_, &cache_line_tag_gref)); - int64_t const rank_start_gid = - wholememory_tensor_get_entry_per_partition(allocated_embedding) * world_rank; + + size_t rank_start_gid = 0; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_local_entry_start(&rank_start_gid, allocated_embedding)); wholememory_tensor_description_t recv_indices_desc; auto recv_indices_array_desc = wholememory_create_array_desc(total_recv_count, 0, indice_desc->dtype); @@ -756,8 +792,6 @@ wholememory_error_code_t local_cached_global_readonly_embedding::gather( host_rank_id_count_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_recv_indices_buffer_handle(p_env_fns); wholememory_ops::temp_memory_handle dev_raw_indice_handle(p_env_fns); - size_t const embedding_entry_count_per_rank = - wholememory_tensor_get_entry_per_partition(cache_ptr_->access_count_wm_tensor_); wholememory_ops::wm_thrust_allocator thrust_allocator(p_env_fns); int cache_world_size = -1, cache_world_rank = -1; int64_t* host_recv_rank_id_count_ptr = nullptr; @@ -779,6 +813,19 @@ wholememory_error_code_t local_cached_global_readonly_embedding::gather( wholememory_array_description_t indice_array_desc; WHOLEMEMORY_CHECK_NOTHROW( wholememory_convert_tensor_desc_to_array(&indice_array_desc, indice_desc)); + + wholememory_ops::temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(cache_world_size + 1, WHOLEMEMORY_DT_INT64)); + wholememory_ops::temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(cache_world_size + 1, WHOLEMEMORY_DT_INT64)); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_entry_offsets( + host_embedding_entry_offsets_ptr, cache_ptr_->access_count_wm_tensor_)); + WM_CUDA_CHECK_NO_THROW(cudaMemcpy(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (cache_world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice)); WHOLEMEMORY_RETURN_ON_FAIL( wholememory_ops::bucket_and_exchange_ids_func(wholememory_tensor_get_data_pointer(indices), indice_array_desc, @@ -786,7 +833,7 @@ wholememory_error_code_t local_cached_global_readonly_embedding::gather( host_rank_id_count_ptr, &dev_recv_indices_buffer_handle, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, cache_policy->cache_comm, &thrust_allocator, p_env_fns, @@ -804,7 +851,7 @@ wholememory_error_code_t local_cached_global_readonly_embedding::gather( update_indice_desc, allocated_embedding, cache_policy->cache_comm, - embedding_entry_count_per_rank, + host_embedding_entry_offsets_ptr, cache_ptr_->get_cache_local_data(), cache_ptr_->get_cache_set_coverage(), p_env_fns, @@ -903,6 +950,7 @@ wholememory_error_code_t wholememory_create_embedding( wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, wholememory_embedding_cache_policy_t cache_policy, + size_t* embedding_entry_partition, int user_defined_sms, int round_robin_size) { @@ -961,15 +1009,24 @@ wholememory_error_code_t wholememory_create_embedding( } embedding_impl_ptr = new wholememory::local_cached_global_readonly_embedding(); } + embedding_entry_partition = nullptr; } else { embedding_impl_ptr = new wholememory::noncached_embedding(); } + if (embedding_entry_partition) { + if (round_robin_size != 0) { WHOLEMEMORY_WARN("Parameter 'round_robin_size' is ignored."); } + round_robin_size = 0; + } WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&embedding_world_size, comm)); embedding_impl_ptr->set_shard_method( &embedding_matrix_description, embedding_world_size, round_robin_size); embedding_impl_ptr->set_gather_sms(user_defined_sms); - WHOLEMEMORY_RETURN_ON_FAIL(embedding_impl_ptr->allocate( - &embedding_matrix_description, comm, memory_type, memory_location, cache_policy)); + WHOLEMEMORY_RETURN_ON_FAIL(embedding_impl_ptr->allocate(&embedding_matrix_description, + comm, + memory_type, + memory_location, + cache_policy, + embedding_entry_partition)); *wholememory_embedding = static_cast(embedding_impl_ptr); return WHOLEMEMORY_SUCCESS; } diff --git a/cpp/src/wholememory/embedding.hpp b/cpp/src/wholememory/embedding.hpp index f593c36ab..616667c36 100644 --- a/cpp/src/wholememory/embedding.hpp +++ b/cpp/src/wholememory/embedding.hpp @@ -45,7 +45,8 @@ class embedding_base : public wholememory_embedding_ { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - wholememory_embedding_cache_policy_t policy) noexcept; + wholememory_embedding_cache_policy_t policy, + size_t* embedding_entry_partition) noexcept; void deallocate() noexcept; virtual wholememory_error_code_t gather(wholememory_tensor_t indices, wholememory_tensor_t output, diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp index 16ed43760..c8f1644e3 100644 --- a/cpp/src/wholememory/memory_handle.cpp +++ b/cpp/src/wholememory/memory_handle.cpp @@ -57,7 +57,8 @@ class wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) : handle_(wholememory_handle), comm_(comm), type_(memory_type), @@ -65,6 +66,16 @@ class wholememory_impl { total_size_(total_size), data_granularity_(data_granularity) { + if (rank_entry_partition != nullptr) { + rank_partition_strategy_.partition_sizes_.resize(comm_->world_size, 0); + rank_partition_strategy_.partition_offsets_.resize(comm_->world_size + 1, 0); + for (int i = 0; i < comm_->world_size; i++) { + rank_partition_strategy_.partition_sizes_[i] = rank_entry_partition[i] * data_granularity_; + rank_partition_strategy_.partition_offsets_[i + 1] = + rank_partition_strategy_.partition_offsets_[i] + + rank_entry_partition[i] * data_granularity_; + } + } distrubuted_backend_ = WHOLEMEMORY_DB_NCCL; } wholememory_impl() = delete; @@ -89,16 +100,17 @@ class wholememory_impl { [[nodiscard]] virtual wholememory_gref_t get_global_reference() const noexcept { wholememory_gref_t gref{}; - gref.pointer = nullptr; - gref.stride = 0; + gref.pointer = nullptr; + gref.stride = 0; + gref.world_size = comm_->world_size; return gref; } virtual bool contains_pointer(const void* ptr) const = 0; void get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset) const { if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_; - if (local_size != nullptr) *local_size = rank_partition_strategy_.local_mem_size; - if (local_offset != nullptr) *local_offset = rank_partition_strategy_.local_mem_offset; + if (local_size != nullptr) *local_size = get_local_size(); + if (local_offset != nullptr) *local_offset = get_local_offset(); if (location_ == WHOLEMEMORY_ML_HOST && (type_ == WHOLEMEMORY_MT_CONTINUOUS) && (!(comm_->is_intranode()))) { WHOLEMEMORY_WARN( @@ -120,6 +132,19 @@ class wholememory_impl { { return rank_partition_strategy_.partition_mem_stride; } + [[nodiscard]] size_t get_local_size() const + { + return rank_partition_strategy_.partition_sizes_[comm_->world_rank]; + } + [[nodiscard]] size_t get_local_offset() const + { + return rank_partition_strategy_.partition_offsets_[comm_->world_rank]; + } + std::vector get_rank_sizes() const { return rank_partition_strategy_.partition_sizes_; } + std::vector get_rank_offsets() const + { + return rank_partition_strategy_.partition_offsets_; + } protected: // In WholeMemory, memory is first allocated by one or all ranks, and then partition the whole @@ -136,18 +161,19 @@ class wholememory_impl { // first rank responsible for all memory allocation, continuous or chunked host shared memory may // use this mode. void first_rank_allocate_all_strategy(); - // each rank allocate exactly the same size, chunked device memory or nccl memory may use this - // mode. - void each_rank_same_chunk_strategy(); + // each rank allocate different size, chunked device memory or nccl memory may use this + // mode. If rank_entry_partition isn't set, each rank allocate exactly the same size. + void each_rank_different_chunk_strategy(); // each rank allocate a multiple of pages, and map the whole memory by page, continuous device // memory use this mode. void each_rank_multiple_page_strategy(); // For now, memory rank partitioning strategy is the same for all WholeMemory types. - // Each rank is response for memory of size local_mem_size_ starting from local_mem_offset_. - // And local_mem_offset_ can also be got by rank_mem_stride_ * rank for ranks with local_mem_size_ - // != 0 That means for a valid memory offset position, offset / rank_mem_stride_ can be used to - // get the rank which is responsible for it. + // Each rank is response for memory of size local_mem_size starting from local_mem_offset. + // Local_mem_size can be got by calling get_local_size(), and local_mem_offset can be got + // by calling get_local_offset(). rank_partition_strategy_.partition_sizes_ and + // rank_partition_strategy_.partition_offsets_ record the memory size and memory offset of + // all ranks. void generate_rank_partition_strategy(); /* @@ -182,12 +208,12 @@ class wholememory_impl { } alloc_strategy_; struct partition_strategy { - // size of memory this rank is responsible for - size_t local_mem_size = 0; - // start location of the memory this rank is responsible for - size_t local_mem_offset = 0; + std::vector partition_sizes_; + std::vector partition_offsets_; size_t partition_mem_stride = 0; + bool same_chunk; } rank_partition_strategy_; + void* local_partition_memory_pointer_ = nullptr; void get_rank_partition_info(size_t* rank_mem_size, @@ -195,12 +221,9 @@ class wholememory_impl { int rank) const noexcept { WHOLEMEMORY_CHECK_NOTHROW(rank >= 0 && rank <= comm_->world_size); - size_t rank_mem_part_start = - std::min(rank_partition_strategy_.partition_mem_stride * rank, total_size_); - size_t rank_mem_part_end = - std::min(rank_partition_strategy_.partition_mem_stride * (rank + 1), total_size_); - if (rank_mem_size != nullptr) *rank_mem_size = rank_mem_part_end - rank_mem_part_start; - if (rank_mem_start != nullptr) *rank_mem_start = rank_mem_part_start; + if (rank_mem_size != nullptr) *rank_mem_size = rank_partition_strategy_.partition_sizes_[rank]; + if (rank_mem_start != nullptr) + *rank_mem_start = rank_partition_strategy_.partition_offsets_[rank]; } static constexpr size_t HUGE_PAGE_THRESHOLD = 16UL * 1024UL * 1024UL * 1024UL; @@ -293,16 +316,22 @@ class distributed_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED); } void create_memory() override { - each_rank_same_chunk_strategy(); generate_rank_partition_strategy(); + each_rank_different_chunk_strategy(); create_local_cuda_runtime_memory(); register_private_memory(); } @@ -387,17 +416,23 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS || type_ == WHOLEMEMORY_MT_CHUNKED); WHOLEMEMORY_CHECK(location_ == WHOLEMEMORY_ML_HOST); } void create_memory() override { - first_rank_allocate_all_strategy(); generate_rank_partition_strategy(); + first_rank_allocate_all_strategy(); create_and_map_shared_host_memory(); register_host_memory(); } @@ -413,8 +448,9 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { [[nodiscard]] wholememory_gref_t get_global_reference() const noexcept override { wholememory_gref_t gref{}; - gref.pointer = get_continuous_mapping_pointer(); - gref.stride = 0; + gref.pointer = get_continuous_mapping_pointer(); + gref.stride = 0; + gref.world_size = comm_->world_size; return gref; } bool contains_pointer(const void* ptr) const override @@ -423,6 +459,7 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { uint64_t int_start_ptr = reinterpret_cast(shared_host_handle_.shared_host_memory_ptr); return int_ptr >= int_start_ptr && int_ptr < int_start_ptr + total_size_; } + bool get_rank_memory(void** rank_memory_ptr, size_t* rank_memory_size, size_t* rank_memory_offset, @@ -531,9 +568,7 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { nullptr, alloc_strategy_.total_alloc_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); WHOLEMEMORY_CHECK(mmap_ptr != (void*)-1); } - memset(static_cast(mmap_ptr) + rank_partition_strategy_.local_mem_offset, - 0, - rank_partition_strategy_.local_mem_size); + memset(static_cast(mmap_ptr) + get_local_offset(), 0, get_local_size()); WM_CUDA_CHECK_NO_THROW( cudaHostRegister(mmap_ptr, alloc_strategy_.total_alloc_size, cudaHostRegisterDefault)); if (!use_systemv_shm_) WHOLEMEMORY_CHECK(close(shm_fd) == 0); @@ -541,8 +576,7 @@ class global_mapped_host_wholememory_impl : public wholememory_impl { WM_CUDA_CHECK_NO_THROW(cudaHostGetDevicePointer(&dev_ptr, mmap_ptr, 0)); WHOLEMEMORY_CHECK(dev_ptr == mmap_ptr); shared_host_handle_.shared_host_memory_ptr = dev_ptr; - local_partition_memory_pointer_ = - static_cast(dev_ptr) + rank_partition_strategy_.local_mem_offset; + local_partition_memory_pointer_ = static_cast(dev_ptr) + get_local_offset(); } void unmap_and_destroy_shared_host_memory() noexcept @@ -603,17 +637,23 @@ class continuous_device_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CONTINUOUS); } void create_memory() override { WHOLEMEMORY_CHECK(location_ == WHOLEMEMORY_ML_DEVICE); - each_rank_multiple_page_strategy(); generate_rank_partition_strategy(); + each_rank_multiple_page_strategy(); create_and_map_driver_device_memory(); register_continuous_device_memory(); } @@ -629,8 +669,9 @@ class continuous_device_wholememory_impl : public wholememory_impl { [[nodiscard]] wholememory_gref_t get_global_reference() const noexcept override { wholememory_gref_t gref{}; - gref.pointer = get_continuous_mapping_pointer(); - gref.stride = 0; + gref.pointer = get_continuous_mapping_pointer(); + gref.stride = 0; + gref.world_size = comm_->world_size; return gref; } bool contains_pointer(const void* ptr) const override @@ -960,8 +1001,8 @@ class continuous_device_wholememory_impl : public wholememory_impl { close_unix_domain_sockets(); map_driver_device_memory_handles(&recv_ipc_sharable_cu_handles); communicator_barrier(comm_); - local_partition_memory_pointer_ = static_cast(cu_alloc_handle_.mapped_whole_memory) + - rank_partition_strategy_.local_mem_offset; + local_partition_memory_pointer_ = + static_cast(cu_alloc_handle_.mapped_whole_memory) + get_local_offset(); } void unmap_and_destroy_driver_device_memory() noexcept { @@ -1017,17 +1058,23 @@ class chunked_device_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_CHUNKED); WHOLEMEMORY_CHECK(location_ == WHOLEMEMORY_ML_DEVICE); } void create_memory() override { - each_rank_same_chunk_strategy(); generate_rank_partition_strategy(); + each_rank_different_chunk_strategy(); create_and_map_runtime_device_memory(); register_chunked_device_memory(); } @@ -1044,7 +1091,7 @@ class chunked_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); uint64_t int_start_ptr = reinterpret_cast(cuda_ipc_handle_.mapped_ptrs[i]); if (int_ptr >= int_start_ptr && int_ptr < int_start_ptr + mem_size_for_current_rank) { return true; @@ -1074,7 +1121,7 @@ class chunked_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); if (mem_size_for_current_rank > 0) { register_wholememory_vma_range_locked( cuda_ipc_handle_.mapped_ptrs[i], mem_size_for_current_rank, handle_); @@ -1089,7 +1136,7 @@ class chunked_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); if (mem_size_for_current_rank > 0) { unregister_wholememory_vma_range_locked( cuda_ipc_handle_.mapped_ptrs[i], mem_size_for_current_rank, handle_); @@ -1124,12 +1171,20 @@ class chunked_device_wholememory_impl : public wholememory_impl { cuda_ipc_handle_.mapped_ptrs.data(), sizeof(void*) * comm_->world_size, cudaMemcpyHostToDevice)); - gref_.stride = rank_partition_strategy_.partition_mem_stride; + WM_CUDA_CHECK(cudaMalloc(&gref_.rank_memory_offsets, sizeof(size_t) * (comm_->world_size + 1))); + WM_CUDA_CHECK(cudaMemcpy(gref_.rank_memory_offsets, + get_rank_offsets().data(), + sizeof(size_t) * (comm_->world_size + 1), + cudaMemcpyHostToDevice)); + gref_.world_size = comm_->world_size; + gref_.stride = rank_partition_strategy_.partition_mem_stride; + gref_.same_chunk = rank_partition_strategy_.same_chunk; } void unmap_and_destroy_runtime_device_memory() noexcept { try { WM_CUDA_CHECK(cudaFree(gref_.pointer)); + WM_CUDA_CHECK(cudaFree(gref_.rank_memory_offsets)); gref_.pointer = nullptr; for (int i = 0; i < comm_->world_size; i++) { if (i != comm_->world_rank) { @@ -1164,9 +1219,15 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_CHECK(type_ == WHOLEMEMORY_MT_DISTRIBUTED); WHOLEMEMORY_CHECK(location_ == WHOLEMEMORY_ML_DEVICE); @@ -1179,8 +1240,8 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { void create_memory() override { - each_rank_same_chunk_strategy(); generate_rank_partition_strategy(); + each_rank_different_chunk_strategy(); nvshmem_malloc_device_memory(); register_nvshmem_device_memory(); } @@ -1198,7 +1259,7 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); acc_size += mem_size_for_current_rank; uint64_t int_start_ptr = reinterpret_cast(nvshmem_ptr(nvshmem_memory_handle_.local_alloc_mem_ptr, i)); @@ -1226,6 +1287,8 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { return true; } + [[nodiscard]] wholememory_gref_t get_global_reference() const noexcept override { return gref_; } + protected: void register_nvshmem_device_memory() { @@ -1234,7 +1297,7 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); if (mem_size_for_current_rank > 0) { void* ptr = nvshmem_ptr(nvshmem_memory_handle_.local_alloc_mem_ptr, i); if (ptr != nullptr) { @@ -1251,7 +1314,7 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { for (int i = 0; i < comm_->world_size; i++) { size_t mem_size_of_this_rank_and_after = total_size_ - acc_size; size_t mem_size_for_current_rank = - std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_mem_stride); + std::min(mem_size_of_this_rank_and_after, rank_partition_strategy_.partition_sizes_[i]); if (mem_size_for_current_rank > 0) { void* ptr = nvshmem_ptr(nvshmem_memory_handle_.local_alloc_mem_ptr, i); if (ptr != nullptr) { @@ -1270,10 +1333,22 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { nvshmem_memory_handle_.local_alloc_mem_ptr = nvshmem_malloc(alloc_size); local_partition_memory_pointer_ = nvshmem_memory_handle_.local_alloc_mem_ptr; distrubuted_backend_ = WHOLEMEMORY_DB_NVSHMEM; + + WM_CUDA_CHECK(cudaMalloc(&gref_.rank_memory_offsets, sizeof(size_t) * (comm_->world_size + 1))); + WM_CUDA_CHECK(cudaMemcpy(gref_.rank_memory_offsets, + get_rank_offsets().data(), + sizeof(size_t) * (comm_->world_size + 1), + cudaMemcpyHostToDevice)); + gref_.pointer = local_partition_memory_pointer_; + gref_.world_size = comm_->world_size; + gref_.stride = rank_partition_strategy_.partition_mem_stride; + gref_.same_chunk = rank_partition_strategy_.same_chunk; } void nvshmem_free_device_memory() { + WM_CUDA_CHECK(cudaFree(gref_.rank_memory_offsets)); + gref_.pointer = nullptr; if (nvshmem_memory_handle_.local_alloc_mem_ptr) { nvshmem_free(nvshmem_memory_handle_.local_alloc_mem_ptr); @@ -1307,6 +1382,8 @@ class nvshmem_device_wholememory_impl : public wholememory_impl { void* local_alloc_mem_ptr = nullptr; } nvshmem_memory_handle_; inline static bool has_set_nvshmem_heap = false; + + wholememory_gref_t gref_; }; #endif // Implementation for MNNVL wholememory that use cuda driver api. @@ -1320,9 +1397,15 @@ class continuous_mnnvl_wholememory_impl : public continuous_device_wholememory_i wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) - : continuous_device_wholememory_impl( - wholememory_handle, total_size, comm, memory_type, memory_location, data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) + : continuous_device_wholememory_impl(wholememory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition) { WHOLEMEMORY_INFO("Using continuous_mnnvl_wholememory_impl"); WHOLEMEMORY_CHECK_NOTHROW(type_ == WHOLEMEMORY_MT_CONTINUOUS); @@ -1335,8 +1418,8 @@ class continuous_mnnvl_wholememory_impl : public continuous_device_wholememory_i void create_memory() override { check_valid(); - each_rank_multiple_page_strategy(); generate_rank_partition_strategy(); + each_rank_multiple_page_strategy(); create_and_map_driver_memory(); register_continuous_mnnvl_memory(); } @@ -1474,8 +1557,8 @@ class continuous_mnnvl_wholememory_impl : public continuous_device_wholememory_i &cu_alloc_handle_.local_ipc_fabric_handle); map_driver_memory_handles(&recv_ipc_sharable_cu_fabric_handles); - local_partition_memory_pointer_ = static_cast(cu_alloc_handle_.mapped_whole_memory) + - rank_partition_strategy_.local_mem_offset; + local_partition_memory_pointer_ = + static_cast(cu_alloc_handle_.mapped_whole_memory) + get_local_offset(); } void unmap_and_destroy_driver_host_memory() noexcept { @@ -1513,16 +1596,37 @@ class continuous_mnnvl_wholememory_impl : public continuous_device_wholememory_i void wholememory_impl::generate_rank_partition_strategy() { - size_t data_slot_count = total_size_ / data_granularity_; - size_t data_slot_per_rank = determine_entry_partition_plan(data_slot_count, comm_->world_size); - size_t rank_data_slot_start = std::min(comm_->world_rank * data_slot_per_rank, data_slot_count); - size_t rank_data_slot_end = - std::min((comm_->world_rank + 1) * data_slot_per_rank, data_slot_count); - size_t rank_data_slot_count = rank_data_slot_end - rank_data_slot_start; - - rank_partition_strategy_.local_mem_size = rank_data_slot_count * data_granularity_; - rank_partition_strategy_.local_mem_offset = rank_data_slot_start * data_granularity_; + if (!rank_partition_strategy_.partition_sizes_.empty()) { + rank_partition_strategy_.partition_mem_stride = total_size_ / comm_->world_size; + bool check_same = true; + for (int i = 0; i < comm_->world_size - 2; i++) { // ignore the last rank + if (rank_partition_strategy_.partition_sizes_[i] != + rank_partition_strategy_.partition_sizes_[i + 1]) { + check_same = false; + break; + } + } + rank_partition_strategy_.same_chunk = check_same; + return; + } + size_t data_slot_count = total_size_ / data_granularity_; + + size_t data_slot_per_rank = 0; + equal_partition_plan(&data_slot_per_rank, data_slot_count, comm_->world_size); + + rank_partition_strategy_.partition_sizes_.resize(comm_->world_size, 0); + rank_partition_strategy_.partition_offsets_.resize(comm_->world_size + 1, 0); + for (int i = 0; i < comm_->world_size; i++) { + size_t tmp_slot_start = std::min(i * data_slot_per_rank, data_slot_count); + size_t tmp_slot_end = std::min((i + 1) * data_slot_per_rank, data_slot_count); + rank_partition_strategy_.partition_sizes_[i] = + (tmp_slot_end - tmp_slot_start) * data_granularity_; + rank_partition_strategy_.partition_offsets_[i] = tmp_slot_start * data_granularity_; + } + rank_partition_strategy_.partition_offsets_[comm_->world_size] = + data_slot_count * data_granularity_; rank_partition_strategy_.partition_mem_stride = data_slot_per_rank * data_granularity_; + rank_partition_strategy_.same_chunk = true; } void wholememory_impl::first_rank_allocate_all_strategy() @@ -1542,28 +1646,33 @@ void wholememory_impl::first_rank_allocate_all_strategy() alloc_strategy_.alloc_sizes[0] = alloc_strategy_.total_alloc_size; } -void wholememory_impl::each_rank_same_chunk_strategy() +void wholememory_impl::each_rank_different_chunk_strategy() { - size_t data_slot_count = total_size_ / data_granularity_; - size_t data_slot_per_rank = determine_entry_partition_plan(data_slot_count, comm_->world_size); - // each rank allocate same size - alloc_strategy_.local_alloc_size = data_slot_per_rank * data_granularity_; - alloc_strategy_.alignment = comm_->alloc_granularity; - if (total_size_ > HUGE_PAGE_THRESHOLD) { - alloc_strategy_.local_alloc_size = - round_up_unsafe(alloc_strategy_.local_alloc_size, HUGE_PAGE_SIZE); - alloc_strategy_.alignment = HUGE_PAGE_SIZE; - } - alloc_strategy_.total_alloc_size = alloc_strategy_.local_alloc_size * comm_->world_size; - alloc_strategy_.alloc_offsets.clear(); alloc_strategy_.alloc_offsets.resize(comm_->world_size, 0); + alloc_strategy_.alloc_sizes.clear(); + alloc_strategy_.alloc_sizes.resize(comm_->world_size, 0); + + size_t rank_local_alloc_offset = 0; for (int i = 0; i < comm_->world_size; i++) { - alloc_strategy_.alloc_offsets[i] = alloc_strategy_.local_alloc_size * i; + size_t rank_local_alloc_size = rank_partition_strategy_.partition_sizes_[i]; + size_t rank_alignment; + if (total_size_ > HUGE_PAGE_THRESHOLD) { + rank_local_alloc_size = round_up_unsafe(rank_local_alloc_size, HUGE_PAGE_SIZE); + rank_alignment = HUGE_PAGE_SIZE; + } else { + rank_local_alloc_size = round_up_unsafe(rank_local_alloc_size, comm_->alloc_granularity); + rank_alignment = comm_->alloc_granularity; + } + if (i == comm_->world_rank) { + alloc_strategy_.local_alloc_size = rank_local_alloc_size; + alloc_strategy_.alignment = rank_alignment; + } + alloc_strategy_.alloc_offsets[i] = rank_local_alloc_offset; + alloc_strategy_.alloc_sizes[i] = rank_local_alloc_size; + rank_local_alloc_offset += rank_local_alloc_size; } - - alloc_strategy_.alloc_sizes.clear(); - alloc_strategy_.alloc_sizes.resize(comm_->world_size, alloc_strategy_.local_alloc_size); + alloc_strategy_.total_alloc_size = rank_local_alloc_offset; } void wholememory_impl::each_rank_multiple_page_strategy() @@ -1643,11 +1752,26 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) noexcept + size_t data_granularity, + size_t* rank_entry_partition) noexcept { try { if (total_size % data_granularity != 0) return WHOLEMEMORY_INVALID_VALUE; - + if (rank_entry_partition != nullptr) { + int64_t total_slot_count = 0; + for (int i = 0; i < comm->world_size; i++) { + WM_COMM_CHECK_ALL_SAME(comm, rank_entry_partition[i]); + if (rank_entry_partition[i] <= 0) { return WHOLEMEMORY_INVALID_VALUE; } + total_slot_count += rank_entry_partition[i]; + } + if (total_slot_count * data_granularity != total_size) { + WHOLEMEMORY_ERROR("total slot count * data granularity (%ld*%ld) != total size (%ld)", + total_slot_count, + data_granularity, + total_size); + return WHOLEMEMORY_INVALID_VALUE; + } + } *wholememory_handle_ptr = nullptr; std::unique_lock mlock(comm->mu); auto* whole_memory_handle = new wholememory_handle_(); @@ -1660,27 +1784,52 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha if (memory_type == WHOLEMEMORY_MT_DISTRIBUTED) { #ifdef WITH_NVSHMEM_SUPPORT if (comm->bind_to_nvshmem) { - whole_memory_handle->impl = new nvshmem_device_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new nvshmem_device_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } else #endif { - whole_memory_handle->impl = new distributed_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new distributed_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } } else if (memory_type == WHOLEMEMORY_MT_CONTINUOUS) { if (is_intranode_communicator(comm) || !SupportEGM()) { if (memory_location == WHOLEMEMORY_ML_HOST) { - whole_memory_handle->impl = new global_mapped_host_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new global_mapped_host_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } else { - whole_memory_handle->impl = new continuous_device_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new continuous_device_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } } else { #if CUDA_VERSION >= 12030 - whole_memory_handle->impl = new continuous_mnnvl_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new continuous_mnnvl_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); #else WHOLEMEMORY_FAIL_NOTHROW("Multinode CONTINUOUS is only supported on CUDA Version >= 12.3"); #endif @@ -1688,11 +1837,21 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha } else if (memory_type == WHOLEMEMORY_MT_CHUNKED) { WHOLEMEMORY_CHECK_NOTHROW(is_intranode_communicator(comm)); if (memory_location == WHOLEMEMORY_ML_HOST) { - whole_memory_handle->impl = new global_mapped_host_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new global_mapped_host_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } else { - whole_memory_handle->impl = new chunked_device_wholememory_impl( - whole_memory_handle, total_size, comm, memory_type, memory_location, data_granularity); + whole_memory_handle->impl = new chunked_device_wholememory_impl(whole_memory_handle, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } } else { WHOLEMEMORY_FATAL("Unsupported memory_type (%d) and memory_location (%d).", @@ -1858,45 +2017,69 @@ wholememory_error_code_t get_nvshmem_reference_frome_handle( (wholememory_handle->impl->get_distributed_backend() != WHOLEMEMORY_DB_NVSHMEM)) { return WHOLEMEMORY_INVALID_INPUT; } - *wholememory_nvshmem_ref = wholememory_nvshmem_ref_t{}; - size_t local_size, local_offset; - void* pointer; - - wholememory_handle->impl->get_local_memory(&pointer, &local_size, &local_offset); - wholememory_nvshmem_ref->pointer = pointer; - wholememory_nvshmem_ref->stride = wholememory_handle->impl->get_partition_stride(); - wholememory_nvshmem_ref->world_rank = wholememory_handle->impl->get_comm()->world_rank; - wholememory_nvshmem_ref->world_size = wholememory_handle->impl->get_comm()->world_size; + wholememory_gref_t wholememory_gref_tmp = wholememory_handle->impl->get_global_reference(); + *wholememory_nvshmem_ref = wholememory_nvshmem_ref_t{}; + wholememory_nvshmem_ref->pointer = wholememory_gref_tmp.pointer; + wholememory_nvshmem_ref->rank_memory_offsets = wholememory_gref_tmp.rank_memory_offsets; + wholememory_nvshmem_ref->world_size = wholememory_gref_tmp.world_size; + wholememory_nvshmem_ref->world_rank = wholememory_handle->impl->get_comm()->world_rank; + wholememory_nvshmem_ref->stride = wholememory_gref_tmp.stride; + wholememory_nvshmem_ref->same_chunk = wholememory_gref_tmp.same_chunk; return (wholememory_nvshmem_ref->pointer == nullptr) ? WHOLEMEMORY_INVALID_INPUT : WHOLEMEMORY_SUCCESS; } #endif -wholememory_error_code_t determine_partition_plan(size_t* size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size) noexcept +wholememory_error_code_t equal_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size) noexcept +{ + *entry_per_rank = div_rounding_up_safe(total_entry_count, world_size); + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t get_rank_partition_sizes_from_handle( + size_t* rank_sizes, wholememory_handle_t wholememory_handle) noexcept +{ + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + std::vector rank_sizes_ = wholememory_handle->impl->get_rank_sizes(); + for (int i = 0; i < rank_sizes_.size(); i++) + rank_sizes[i] = rank_sizes_[i]; + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t get_rank_partition_offsets_from_handle( + size_t* rank_offsets, wholememory_handle_t wholememory_handle) noexcept { - if (total_size % data_granularity != 0) { return WHOLEMEMORY_INVALID_VALUE; } - if (size_per_rank == nullptr) { return WHOLEMEMORY_INVALID_INPUT; } - size_t entry_per_rank = 0; - *size_per_rank = determine_entry_partition_plan(total_size / data_granularity, world_size); + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + std::vector rank_offsets_ = wholememory_handle->impl->get_rank_offsets(); + for (int i = 0; i < rank_offsets_.size(); i++) + rank_offsets[i] = rank_offsets_[i]; return WHOLEMEMORY_SUCCESS; } -size_t determine_entry_partition_plan(size_t total_entry_count, int world_size) noexcept +wholememory_error_code_t get_local_size_from_handle( + size_t* rank_size, wholememory_handle_t wholememory_handle) noexcept { - return div_rounding_up_safe(total_entry_count, world_size); + if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { + return WHOLEMEMORY_INVALID_INPUT; + } + *rank_size = wholememory_handle->impl->get_local_size(); + return WHOLEMEMORY_SUCCESS; } -wholememory_error_code_t get_partition_plan_from_handle( - size_t* size_per_rank, wholememory_handle_t wholememory_handle) noexcept +wholememory_error_code_t get_local_offset_from_handle( + size_t* local_offset, wholememory_handle_t wholememory_handle) noexcept { if (wholememory_handle == nullptr || wholememory_handle->impl == nullptr) { return WHOLEMEMORY_INVALID_INPUT; } - *size_per_rank = wholememory_handle->impl->get_partition_stride(); + *local_offset = wholememory_handle->impl->get_local_offset(); return WHOLEMEMORY_SUCCESS; } diff --git a/cpp/src/wholememory/memory_handle.hpp b/cpp/src/wholememory/memory_handle.hpp index d261445cc..c16e5bc03 100644 --- a/cpp/src/wholememory/memory_handle.hpp +++ b/cpp/src/wholememory/memory_handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,7 +40,8 @@ wholememory_error_code_t create_wholememory(wholememory_handle_t* wholememory_ha wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) noexcept; + size_t data_granularity, + size_t* rank_entry_partition = nullptr) noexcept; wholememory_error_code_t destroy_wholememory_with_comm_locked( wholememory_handle_t wholememory_handle) noexcept; @@ -71,21 +72,27 @@ wholememory_error_code_t get_rank_memory_from_handle( int rank, wholememory_handle_t wholememory_handle) noexcept; +wholememory_error_code_t get_local_size_from_handle( + size_t* size, wholememory_handle_t wholememory_handle) noexcept; + +wholememory_error_code_t get_local_offset_from_handle( + size_t* offset, wholememory_handle_t wholememory_handle) noexcept; + wholememory_error_code_t get_global_pointer_from_handle( void** global_ptr, wholememory_handle_t wholememory_handle) noexcept; wholememory_error_code_t get_global_reference_from_handle( wholememory_gref_t* wholememory_gref, wholememory_handle_t wholememory_handle) noexcept; -wholememory_error_code_t determine_partition_plan(size_t* size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size) noexcept; +wholememory_error_code_t equal_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size) noexcept; -size_t determine_entry_partition_plan(size_t total_entry_count, int world_size) noexcept; +wholememory_error_code_t get_rank_partition_sizes_from_handle( + size_t* rank_sizes, wholememory_handle_t wholememory_handle) noexcept; -wholememory_error_code_t get_partition_plan_from_handle( - size_t* size_per_rank, wholememory_handle_t wholememory_handle) noexcept; +wholememory_error_code_t get_rank_partition_offsets_from_handle( + size_t* rank_offsets, wholememory_handle_t wholememory_handle) noexcept; wholememory_distributed_backend_t get_distributed_backend_t( wholememory_handle_t wholememory_handle) noexcept; diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index 814e90087..600906889 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -107,10 +107,16 @@ wholememory_error_code_t wholememory_malloc(wholememory_handle_t* wholememory_ha wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) + size_t data_granularity, + size_t* rank_entry_partition) { - return wholememory::create_wholememory( - wholememory_handle_ptr, total_size, comm, memory_type, memory_location, data_granularity); + return wholememory::create_wholememory(wholememory_handle_ptr, + total_size, + comm, + memory_type, + memory_location, + data_granularity, + rank_entry_partition); } wholememory_error_code_t wholememory_free(wholememory_handle_t wholememory_handle) @@ -170,6 +176,13 @@ wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, rank_memory_ptr, rank_memory_size, rank_memory_offset, rank, wholememory_handle); } +wholememory_error_code_t wholememory_equal_entry_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size) +{ + return wholememory::equal_partition_plan(entry_per_rank, total_entry_count, world_size); +} + wholememory_error_code_t wholememory_get_global_pointer(void** global_ptr, wholememory_handle_t wholememory_handle) { @@ -193,28 +206,28 @@ wholememory_error_code_t wholememory_get_nvshmem_reference( #endif -wholememory_error_code_t wholememory_determine_partition_plan(size_t* size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size) +wholememory_error_code_t wholememory_get_rank_partition_sizes( + size_t* rank_sizes, wholememory_handle_t wholememory_handle) { - return wholememory::determine_partition_plan( - size_per_rank, total_size, data_granularity, world_size); + return wholememory::get_rank_partition_sizes_from_handle(rank_sizes, wholememory_handle); } -wholememory_error_code_t wholememory_determine_entry_partition_plan(size_t* entry_per_rank, - size_t total_entry_count, - int world_size) +wholememory_error_code_t wholememory_get_rank_partition_offsets( + size_t* rank_offsets, wholememory_handle_t wholememory_handle) { - if (entry_per_rank == nullptr) { return WHOLEMEMORY_INVALID_INPUT; } - *entry_per_rank = wholememory::determine_entry_partition_plan(total_entry_count, world_size); - return WHOLEMEMORY_SUCCESS; + return wholememory::get_rank_partition_offsets_from_handle(rank_offsets, wholememory_handle); } -wholememory_error_code_t wholememory_get_partition_plan(size_t* size_per_rank, - wholememory_handle_t wholememory_handle) +wholememory_error_code_t wholememory_get_local_size(size_t* local_size, + wholememory_handle_t wholememory_handle) +{ + return wholememory::get_local_size_from_handle(local_size, wholememory_handle); +} + +wholememory_error_code_t wholememory_get_local_offset(size_t* local_size, + wholememory_handle_t wholememory_handle) { - return wholememory::get_partition_plan_from_handle(size_per_rank, wholememory_handle); + return wholememory::get_local_offset_from_handle(local_size, wholememory_handle); } int fork_get_device_count() diff --git a/cpp/src/wholememory/wholememory_tensor.cpp b/cpp/src/wholememory/wholememory_tensor.cpp index 110657183..41ba10937 100644 --- a/cpp/src/wholememory/wholememory_tensor.cpp +++ b/cpp/src/wholememory/wholememory_tensor.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,7 +53,8 @@ wholememory_error_code_t wholememory_create_tensor( wholememory_tensor_description_t* tensor_description, wholememory_comm_t comm, wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location) + wholememory_memory_location_t memory_location, + size_t* tensor_entry_partition) { if (p_wholememory_tensor == nullptr) { WHOLEMEMORY_ERROR("p_wholememory_tensor is nullptr"); @@ -98,7 +99,8 @@ wholememory_error_code_t wholememory_create_tensor( comm, memory_type, memory_location, - granularity); + granularity, + tensor_entry_partition); inc_tensor_count(); if (ret_code != WHOLEMEMORY_SUCCESS) { free(wholememory_tensor); } return ret_code; @@ -259,16 +261,10 @@ wholememory_error_code_t wholememory_tensor_map_local_tensor( wholememory_get_local_memory(&local_ptr, &local_size, &local_offset, handle)); size_t const element_size = wholememory_dtype_get_element_size(wm_desc->dtype); size_t const gran_size = wm_desc->dim == 1 ? element_size : element_size * wm_desc->strides[0]; - size_t size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_partition_plan(&size_per_rank, handle)); - WHOLEMEMORY_CHECK_NOTHROW(size_per_rank % gran_size == 0); - size_t entry_per_rank = size_per_rank / gran_size; - int64_t local_start = std::min(entry_per_rank * world_rank, wm_desc->sizes[0]); - int64_t local_end = std::min(entry_per_rank * (world_rank + 1), wm_desc->sizes[0]); + local_size = std::min(local_size, wm_desc->sizes[0] * gran_size - local_offset); if (local_size % gran_size != 0) return WHOLEMEMORY_LOGIC_ERROR; wholememory_tensor_description_t local_desc = *wm_desc; - // local_desc.sizes[0] = local_size / gran_size; - local_desc.sizes[0] = (local_end - local_start); + local_desc.sizes[0] = local_size / gran_size; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_make_tensor_from_pointer(local_tensor, local_ptr, &local_desc)); @@ -297,36 +293,120 @@ void* wholememory_tensor_get_data_pointer(wholememory_tensor_t wholememory_tenso wholememory_tensor->tensor_description.storage_offset; } -size_t wholememory_tensor_get_entry_per_partition(wholememory_tensor_t wholememory_tensor) +wholememory_error_code_t wholememory_tensor_get_entry_offsets( + size_t* entry_offsets, wholememory_tensor_t wholememory_tensor) { wholememory_tensor_t root_tensor = wholememory_tensor_get_root(wholememory_tensor); WHOLEMEMORY_CHECK_NOTHROW( (root_tensor->tensor_description.dim == 1 || root_tensor->tensor_description.dim == 2)); if (wholememory_tensor->is_wholememory) { - size_t size_per_rank; - wholememory_get_partition_plan(&size_per_rank, - wholememory_tensor_get_memory_handle(root_tensor)); size_t embedding_stride = 1; size_t const element_size = wholememory_dtype_get_element_size(wholememory_tensor->tensor_description.dtype); if (root_tensor->tensor_description.dim == 2) { embedding_stride = root_tensor->tensor_description.strides[0]; } - WHOLEMEMORY_CHECK_NOTHROW(size_per_rank % (embedding_stride * element_size) == 0); - size_t det_entry_per_rank; int world_size; wholememory_comm_t comm; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator( &comm, wholememory_tensor_get_memory_handle(wholememory_tensor))); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, comm)); - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_determine_entry_partition_plan( - &det_entry_per_rank, root_tensor->tensor_description.sizes[0], world_size)); - WHOLEMEMORY_CHECK_NOTHROW(det_entry_per_rank == - size_per_rank / (embedding_stride * element_size)); - return det_entry_per_rank; + + wholememory_get_rank_partition_offsets( + entry_offsets, wholememory_tensor_get_memory_handle(wholememory_tensor)); + for (int i = 0; i < world_size + 1; i++) { + WHOLEMEMORY_CHECK_NOTHROW(entry_offsets[i] % (embedding_stride * element_size) == 0); + entry_offsets[i] /= (embedding_stride * element_size); + } + return WHOLEMEMORY_SUCCESS; } - return root_tensor->tensor_description.sizes[0]; + entry_offsets[0] = 0; + entry_offsets[1] = root_tensor->tensor_description.sizes[0]; + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t wholememory_tensor_get_entry_partition_sizes( + size_t* entry_partition, wholememory_tensor_t wholememory_tensor) +{ + wholememory_tensor_t root_tensor = wholememory_tensor_get_root(wholememory_tensor); + WHOLEMEMORY_CHECK_NOTHROW( + (root_tensor->tensor_description.dim == 1 || root_tensor->tensor_description.dim == 2)); + if (wholememory_tensor->is_wholememory) { + size_t embedding_stride = 1; + size_t const element_size = + wholememory_dtype_get_element_size(wholememory_tensor->tensor_description.dtype); + if (root_tensor->tensor_description.dim == 2) { + embedding_stride = root_tensor->tensor_description.strides[0]; + } + + int world_size; + wholememory_comm_t comm; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator( + &comm, wholememory_tensor_get_memory_handle(wholememory_tensor))); + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, comm)); + + wholememory_get_rank_partition_sizes(entry_partition, + wholememory_tensor_get_memory_handle(wholememory_tensor)); + for (int i = 0; i < world_size; i++) { + WHOLEMEMORY_CHECK_NOTHROW(entry_partition[i] % (embedding_stride * element_size) == 0); + entry_partition[i] /= (embedding_stride * element_size); + } + return WHOLEMEMORY_SUCCESS; + } + entry_partition[0] = root_tensor->tensor_description.sizes[0]; + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t wholememory_tensor_get_local_entry_count( + size_t* local_entry_count, wholememory_tensor_t wholememory_tensor) +{ + wholememory_tensor_t root_tensor = wholememory_tensor_get_root(wholememory_tensor); + WHOLEMEMORY_CHECK_NOTHROW( + (root_tensor->tensor_description.dim == 1 || root_tensor->tensor_description.dim == 2)); + if (wholememory_tensor->is_wholememory) { + size_t embedding_stride = 1; + size_t const element_size = + wholememory_dtype_get_element_size(wholememory_tensor->tensor_description.dtype); + if (root_tensor->tensor_description.dim == 2) { + embedding_stride = root_tensor->tensor_description.strides[0]; + } + + size_t entry_cnt; + wholememory_get_local_size(&entry_cnt, + wholememory_tensor_get_memory_handle(wholememory_tensor)); + WHOLEMEMORY_CHECK_NOTHROW(entry_cnt % (embedding_stride * element_size) == 0); + entry_cnt /= (embedding_stride * element_size); + *local_entry_count = entry_cnt; + return WHOLEMEMORY_SUCCESS; + } + *local_entry_count = root_tensor->tensor_description.sizes[0]; + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t wholememory_tensor_get_local_entry_start( + size_t* local_entry_start, wholememory_tensor_t wholememory_tensor) +{ + wholememory_tensor_t root_tensor = wholememory_tensor_get_root(wholememory_tensor); + WHOLEMEMORY_CHECK_NOTHROW( + (root_tensor->tensor_description.dim == 1 || root_tensor->tensor_description.dim == 2)); + if (wholememory_tensor->is_wholememory) { + size_t embedding_stride = 1; + size_t const element_size = + wholememory_dtype_get_element_size(wholememory_tensor->tensor_description.dtype); + if (root_tensor->tensor_description.dim == 2) { + embedding_stride = root_tensor->tensor_description.strides[0]; + } + size_t entry_start; + wholememory_get_local_offset(&entry_start, + wholememory_tensor_get_memory_handle(wholememory_tensor)); + WHOLEMEMORY_CHECK_NOTHROW(entry_start % (embedding_stride * element_size) == 0); + entry_start /= (embedding_stride * element_size); + *local_entry_start = entry_start; + return WHOLEMEMORY_SUCCESS; + } + *local_entry_start = 0; + return WHOLEMEMORY_SUCCESS; } wholememory_error_code_t wholememory_tensor_get_subtensor( diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_func.cu b/cpp/src/wholememory_ops/functions/bucket_ids_func.cu index e2ca6faa3..6bd6b6c44 100644 --- a/cpp/src/wholememory_ops/functions/bucket_ids_func.cu +++ b/cpp/src/wholememory_ops/functions/bucket_ids_func.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,23 +28,50 @@ namespace wholememory_ops { +template +__device__ __forceinline__ int dest_rank(IndexT entry_idx, + size_t total_entry_count, + const size_t* embedding_entry_offsets, + int world_size) +{ + size_t estimated_entry_per_rank = total_entry_count / world_size; + int estimated_rank = max(world_size - 1, int(entry_idx / estimated_entry_per_rank)); + if (embedding_entry_offsets[estimated_rank] > entry_idx) { + for (int i = estimated_rank - 1; i >= 0; i--) { + if (embedding_entry_offsets[i] <= entry_idx) { return i; } + } + } else { + for (int i = estimated_rank + 1; i <= world_size; i++) { + if (embedding_entry_offsets[i] > entry_idx) { return i - 1; } + } + } + return 0; +} + template __global__ void bucket_ids_for_ranks_kernel(const IndexT* indices, size_t indice_count, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, int world_size) { - extern __shared__ int rank_count_shared[]; + extern __shared__ char shmem[]; + int* rank_count_shared = reinterpret_cast(shmem); for (int idx = threadIdx.x; idx < world_size; idx += blockDim.x) { rank_count_shared[idx] = 0; } + size_t* embedding_entry_offsets_shared = + reinterpret_cast(shmem + sizeof(size_t) * world_size); + for (int idx = threadIdx.x; idx < world_size + 1; idx += blockDim.x) { + embedding_entry_offsets_shared[idx] = embedding_entry_offsets[idx]; + } __syncthreads(); + size_t total_entry_count = embedding_entry_offsets_shared[world_size]; for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < indice_count; idx += blockDim.x * gridDim.x) { IndexT node_idx = indices[idx]; if (node_idx < 0) continue; - int rank = node_idx / embedding_entry_count_per_rank; + int rank = dest_rank(node_idx, total_entry_count, embedding_entry_offsets_shared, world_size); assert(rank >= 0 && rank < world_size); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 atomicAdd_block(&rank_count_shared[rank], 1); @@ -63,7 +90,7 @@ template void bucket_ids_for_ranks_temp_fn(void* indices, wholememory_array_description_t indice_desc, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, int world_size, int sm_count, cudaStream_t stream) @@ -73,12 +100,11 @@ void bucket_ids_for_ranks_temp_fn(void* indices, block_count = std::min(block_count, sm_count * 4); IndexT* indices_ptr = static_cast(indices); indices_ptr += indice_desc.storage_offset; - bucket_ids_for_ranks_kernel<<>>( - indices_ptr, - indice_desc.size, - dev_rank_id_count_ptr, - embedding_entry_count_per_rank, - world_size); + bucket_ids_for_ranks_kernel<<>>( + indices_ptr, indice_desc.size, dev_rank_id_count_ptr, embedding_entry_offsets, world_size); } REGISTER_DISPATCH_ONE_TYPE(BucketIdForRanks, bucket_ids_for_ranks_temp_fn, SINT3264) @@ -86,7 +112,7 @@ REGISTER_DISPATCH_ONE_TYPE(BucketIdForRanks, bucket_ids_for_ranks_temp_fn, SINT3 wholememory_error_code_t bucket_ids_for_ranks(void* indices, wholememory_array_description_t indice_desc, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, int world_size, cudaDeviceProp* prop, cudaStream_t stream) @@ -101,7 +127,7 @@ wholememory_error_code_t bucket_ids_for_ranks(void* indices, indices, indice_desc, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + embedding_entry_offsets, world_size, sm_count, stream); diff --git a/cpp/src/wholememory_ops/functions/bucket_ids_func.h b/cpp/src/wholememory_ops/functions/bucket_ids_func.h index 3a1f684c0..a8443e3e4 100644 --- a/cpp/src/wholememory_ops/functions/bucket_ids_func.h +++ b/cpp/src/wholememory_ops/functions/bucket_ids_func.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ namespace wholememory_ops { wholememory_error_code_t bucket_ids_for_ranks(void* indices, wholememory_array_description_t indice_desc, int64_t* dev_rank_id_count_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, int world_size, cudaDeviceProp* prop, cudaStream_t stream); diff --git a/cpp/src/wholememory_ops/functions/embedding_cache_func.cu b/cpp/src/wholememory_ops/functions/embedding_cache_func.cu index a79f6500b..fd9e8464f 100644 --- a/cpp/src/wholememory_ops/functions/embedding_cache_func.cu +++ b/cpp/src/wholememory_ops/functions/embedding_cache_func.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -348,10 +348,9 @@ wholememory_error_code_t update_cache_direct_same_comm( auto* raw_embedding_desc = wholememory_tensor_get_tensor_description(wholememory_tensor_get_root(wm_raw_memory_embedding)); - size_t embedding_entry_count_per_rank = 0; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_determine_entry_partition_plan( - &embedding_entry_count_per_rank, raw_embedding_desc->sizes[0], world_size)); - + size_t embedding_entry_start = 0; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_local_entry_start(&embedding_entry_start, wm_raw_memory_embedding)); int indices_num_run = 0; temp_memory_handle unique_indice_handle(p_env_fns), unique_count_handle(p_env_fns); try { @@ -380,7 +379,7 @@ wholememory_error_code_t update_cache_direct_same_comm( &unique_cache_set_start_handle, &unique_cache_set_count_handle, &cache_set_num_run, - world_rank * embedding_entry_count_per_rank, + embedding_entry_start, cache_set_coverage, &thrust_allocator, p_env_fns, @@ -414,7 +413,7 @@ wholememory_error_code_t update_cache_direct_same_comm( static_cast(embedding_local_pointer), embedding_dim_in_int4, cache_set_num_run, - world_rank * embedding_entry_count_per_rank, + embedding_entry_start, cache_set_coverage, stream); @@ -525,7 +524,7 @@ wholememory_error_code_t update_cache_different_comm( wholememory_array_description_t indice_desc, wholememory_tensor_t wm_raw_memory_embedding, wholememory_comm_t cache_comm, - size_t embedding_entry_count_per_cache_rank, + size_t* embedding_entry_offsets, const wholememory::embedding_cache_local_data* cache_local_data, int cache_set_coverage, wholememory_env_func_t* p_env_fns, @@ -554,7 +553,6 @@ wholememory_error_code_t update_cache_different_comm( WHOLEMEMORY_ERROR("SortUniqueLocalIndicesTempFunc failed."); return WHOLEMEMORY_LOGIC_ERROR; } - temp_memory_handle unique_cache_set_lid_handle(p_env_fns), unique_cache_set_start_handle(p_env_fns), unique_cache_set_count_handle(p_env_fns); int cache_set_num_run; @@ -566,7 +564,7 @@ wholememory_error_code_t update_cache_different_comm( &unique_cache_set_start_handle, &unique_cache_set_count_handle, &cache_set_num_run, - cache_world_rank * embedding_entry_count_per_cache_rank, + embedding_entry_offsets[cache_world_rank], cache_set_coverage, &thrust_allocator, p_env_fns, @@ -595,7 +593,7 @@ wholememory_error_code_t update_cache_different_comm( static_cast(wholememory_tensor_get_data_pointer(cache_local_data->access_count_)), local_write_cache_index_ptr, global_load_gid_ptr, - cache_world_rank * embedding_entry_count_per_cache_rank, + embedding_entry_offsets[cache_world_rank], cache_set_coverage, cache_set_num_run, stream); @@ -697,11 +695,12 @@ wholememory_error_code_t writeback_cache_direct_same_comm( auto* raw_embedding_desc = wholememory_tensor_get_tensor_description(wholememory_tensor_get_root(wm_raw_memory_embedding)); - size_t embedding_entry_count_per_rank = 0; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_determine_entry_partition_plan( - &embedding_entry_count_per_rank, raw_embedding_desc->sizes[0], world_size)); - WHOLEMEMORY_CHECK_NOTHROW(embedding_entry_count_per_rank % cache_set_coverage == 0); + size_t embedding_entry_count = 0; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_local_entry_count(&embedding_entry_count, wm_raw_memory_embedding)); + WHOLEMEMORY_CHECK_NOTHROW(embedding_entry_count % cache_set_coverage == 0); + wholememory_tensor_t raw_local_tensor; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_tensor_map_local_tensor(wm_raw_memory_embedding, &raw_local_tensor)); diff --git a/cpp/src/wholememory_ops/functions/embedding_cache_func.h b/cpp/src/wholememory_ops/functions/embedding_cache_func.h index 957ba5242..edf71e74e 100644 --- a/cpp/src/wholememory_ops/functions/embedding_cache_func.h +++ b/cpp/src/wholememory_ops/functions/embedding_cache_func.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,7 +55,7 @@ wholememory_error_code_t update_cache_direct_same_comm( * @param wm_raw_memory_embedding : the WholeMemory Tensor that is to be cached which stores all * embeddings. * @param cache_comm : communicator of cache - * @param embedding_entry_count_per_cache_rank : embedding entries covered by each cache rank + * @param embedding_entry_offsets : embedding entry offset of each cache rank * @param cache_local_data : embedding_cache_local_data of wm_raw_memory_embedding * @param cache_set_coverage : cache set coverage * @param p_env_fns : env fns @@ -67,7 +67,7 @@ wholememory_error_code_t update_cache_different_comm( wholememory_array_description_t indice_desc, wholememory_tensor_t wm_raw_memory_embedding, wholememory_comm_t cache_comm, - size_t embedding_entry_count_per_cache_rank, + size_t* embedding_entry_offsets, const wholememory::embedding_cache_local_data* cache_local_data, int cache_set_coverage, wholememory_env_func_t* p_env_fns, diff --git a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu index 137b10470..173948874 100644 --- a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu +++ b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.cu @@ -161,7 +161,7 @@ wholememory_error_code_t bucket_and_exchange_ids_func( int64_t* host_rank_id_count_ptr, temp_memory_handle* dev_recv_indices_buffer_handle, int64_t* dev_raw_indice_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_comm_t wm_comm, wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, @@ -178,7 +178,7 @@ wholememory_error_code_t bucket_and_exchange_ids_func( WHOLEMEMORY_RETURN_ON_FAIL(bucket_ids_for_ranks(indices, indice_desc, dev_rank_id_count_ptr, - embedding_entry_count_per_rank, + embedding_entry_offsets, world_size, get_device_prop(-1), stream)); diff --git a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.h b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.h index a2c8ff3f3..69a2d92f0 100644 --- a/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.h +++ b/cpp/src/wholememory_ops/functions/exchange_ids_nccl_func.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace wholememory_ops { * @param dev_recv_indices_buffer_handle : temp_memory_handle to create buffer for received indices. * @param dev_raw_indice_ptr : pointer to allocated int64_t array to storage raw indices mapping of * sort - * @param embedding_entry_count_per_rank : entry count of embedding count per rank + * @param embedding_entry_offsets : embedding entry offsets * @param wm_comm : WholeMemory Communicator * @param p_thrust_allocator : thrust allocator * @param p_env_fns : EnvFns @@ -48,7 +48,7 @@ wholememory_error_code_t bucket_and_exchange_ids_func( int64_t* host_rank_id_count_ptr, temp_memory_handle* dev_recv_indices_buffer_handle, int64_t* dev_raw_indice_ptr, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_comm_t wm_comm, wm_thrust_allocator* p_thrust_allocator, wholememory_env_func_t* p_env_fns, diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh index a4979f7be..140b257f8 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh @@ -260,23 +260,25 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, OutputT* output, wholememory_matrix_description_t output_desc) { - auto block = cooperative_groups::this_thread_block(); - auto mywarp = cooperative_groups::tiled_partition<32>(block); - __shared__ char shm_in_char[16384]; - OutputT* all_sh = reinterpret_cast(shm_in_char); - OutputT* my_shared; + auto block = cooperative_groups::this_thread_block(); + auto mywarp = cooperative_groups::tiled_partition<32>(block); + constexpr size_t shm_max_size = 16384; + __shared__ char shm_in_char[shm_max_size]; int warp_id = (threadIdx.x + blockIdx.x * blockDim.x) / 32; int lane_id = threadIdx.x % 32; int embedding_size = embedding_desc.sizes[1]; int64_t embedding_stride = embedding_desc.stride; int64_t output_stride = output_desc.stride; - int shm_size = 16384 / sizeof(OutputT); + wholememory::device_reference embedding_dev_ref(embedding_gref); typed_data_vector embeddings; typed_data_vector outputs; + int shm_size = shm_max_size / sizeof(OutputT); + OutputT* all_sh = reinterpret_cast(shm_in_char); + OutputT* my_shared; bool use_shm = true; if (shm_size / (blockDim.x / 32) < output_desc.sizes[1]) { // use_shm = false; @@ -342,6 +344,7 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, int sub_warp_num = subwarp.meta_group_size() * gridDim.x; int lane_id_in_sub_warp = subwarp.thread_rank(); + wholememory::device_reference embedding_dev_ref(embedding_gref); int embedding_size = embedding_desc.sizes[1]; @@ -358,11 +361,10 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, if (embedding_table_idx < 0) continue; int64_t embedding_offset = embedding_desc.storage_offset + embedding_table_idx * embedding_stride; - + EmbeddingT* emb_ptr = &embedding_dev_ref[embedding_offset]; for (int emb_idx = lane_id_in_sub_warp * ALIGNMENT; emb_idx < embedding_size; emb_idx += ALIGNMENT * SUB_WARP_SIZE) { - mov_data(&embeddings, - &embedding_dev_ref[embedding_offset + emb_idx]); + mov_data(&embeddings, &emb_ptr[emb_idx]); #pragma unroll for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) { typed_data_vector_at(outputs, sub_idx) = @@ -522,11 +524,10 @@ __global__ void scatter_func_kernel(const InputT* input, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc) { - auto block = cooperative_groups::this_thread_block(); - auto mywarp = cooperative_groups::tiled_partition<32>(block); - __shared__ char shm_in_char[24576]; - InputT* all_sh = reinterpret_cast(shm_in_char); - InputT* my_shared; + auto block = cooperative_groups::this_thread_block(); + auto mywarp = cooperative_groups::tiled_partition<32>(block); + constexpr size_t shm_max_size = 24576; + __shared__ char shm_in_char[shm_max_size]; int warp_id = (threadIdx.x + blockIdx.x * blockDim.x) / 32; int lane_id = threadIdx.x % 32; @@ -535,11 +536,13 @@ __global__ void scatter_func_kernel(const InputT* input, int64_t input_stride = input_desc.stride; int async_copy_align = sizeof(InputT) > 4 ? 1 : 4 / sizeof(InputT); - int shm_size = 24576 / sizeof(InputT); + wholememory::device_reference embedding_dev_ref(embedding_gref); + int shm_size = shm_max_size / sizeof(InputT); + InputT* all_sh = reinterpret_cast(shm_in_char); + InputT* my_shared; int batch_size = (shm_size / (blockDim.x / 32) - async_copy_align) / input_stride; // indices batch size in lines - wholememory::device_reference embedding_dev_ref(embedding_gref); typed_data_vector embeddings; typed_data_vector inputs; diff --git a/cpp/src/wholememory_ops/functions/map_indices_func.cu b/cpp/src/wholememory_ops/functions/map_indices_func.cu index 1a1418179..e07ac40f3 100644 --- a/cpp/src/wholememory_ops/functions/map_indices_func.cu +++ b/cpp/src/wholememory_ops/functions/map_indices_func.cu @@ -28,7 +28,7 @@ __global__ void storage_idx2wm_emb_idx_kernel(IndexT* indice, IndexT* mapped_indice, int64_t indice_size, int world_size, - int64_t entry_per_rank, + int64_t entry_start, int round_robin_size) { int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -38,7 +38,7 @@ __global__ void storage_idx2wm_emb_idx_kernel(IndexT* indice, IndexT table_off = target_idx % round_robin_size; int rank_id = table_idx % world_size; int rank_table_idx = table_idx / world_size; - IndexT wmidx = entry_per_rank * rank_id + round_robin_size * rank_table_idx + table_off; + IndexT wmidx = entry_start + round_robin_size * rank_table_idx + table_off; mapped_indice[i] = wmidx; } return; @@ -49,7 +49,7 @@ void storage_idx2wm_emb_idx_temp_fn(void* indice_ptr, void* mapped_indice_ptr, int64_t indice_size, int world_size, - int64_t entry_per_rank, + int64_t entry_start, int round_robin_size, cudaStream_t stream) { @@ -59,7 +59,7 @@ void storage_idx2wm_emb_idx_temp_fn(void* indice_ptr, IndexT* indice = static_cast(indice_ptr); IndexT* mapped_indice = static_cast(mapped_indice_ptr); storage_idx2wm_emb_idx_kernel<<>>( - indice, mapped_indice, indice_size, world_size, entry_per_rank, round_robin_size); + indice, mapped_indice, indice_size, world_size, entry_start, round_robin_size); WM_CUDA_CHECK(cudaStreamSynchronize(stream)); return; } @@ -85,14 +85,16 @@ wholememory_error_code_t storage_index2wm_embedding_index(wholememory_tensor_t i WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, handle)); WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_comm)); - int64_t entry_per_rank = wholememory_tensor_get_entry_per_partition(allocated_embedding); + size_t entry_start = 0; + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_tensor_get_local_entry_start(&entry_start, allocated_embedding)); DISPATCH_ONE_TYPE(indice_desc->dtype, storageidx2wmembidx, indice_ptr, mapped_indice_ptr, indice_size, world_size, - entry_per_rank, + entry_start, round_robin_size, (cudaStream_t)stream_int); WM_CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh index 5fa93ee12..1c0cbb8dc 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh +++ b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,41 +27,107 @@ class nvshmem_device_reference { __device__ __forceinline__ explicit nvshmem_device_reference( const wholememory_nvshmem_ref_t& nvshmem_ref) : pointer_(static_cast(nvshmem_ref.pointer)), - typed_stride_(nvshmem_ref.stride / sizeof(DataTypeT)) + typed_stride_(nvshmem_ref.stride / sizeof(DataTypeT)), + rank_memory_offsets_(nvshmem_ref.rank_memory_offsets), + world_size_(nvshmem_ref.world_size), + same_chunk_(nvshmem_ref.same_chunk) { assert(nvshmem_ref.stride % sizeof(DataTypeT) == 0); + if (!same_chunk_) { + estimated_stride_ = rank_memory_offsets_[world_size_] / world_size_; + cache_rank_ = 0; + cache_offset_ = 0; + cache_size_ = rank_memory_offsets_[1] - rank_memory_offsets_[0]; + } } __device__ nvshmem_device_reference() = delete; __device__ __forceinline__ DataTypeT load(size_t index) { - size_t rank = index / typed_stride_; - - return nvshmem_get(pointer_ + index - rank * typed_stride_, rank); + size_t rank = dest_rank(index); + if (same_chunk_) + return nvshmem_get(pointer_ + index - rank * typed_stride_, rank); + else + return nvshmem_get( + pointer_ + index - rank_memory_offsets_[rank] / sizeof(DataTypeT), rank); } __device__ __forceinline__ void store(size_t index, DataTypeT val) { - size_t rank = index / typed_stride_; - return nvshmem_put(pointer_ + index - rank * typed_stride_, val, rank); + size_t rank = dest_rank(index); + if (same_chunk_) + return nvshmem_put(pointer_ + index - rank * typed_stride_, rank); + else + return nvshmem_put( + pointer_ + index - rank_memory_offsets_[rank] / sizeof(DataTypeT), val, rank); } __device__ __forceinline__ DataTypeT* symmetric_address(size_t index) { - size_t rank = index / typed_stride_; - return pointer_ + index - rank * typed_stride_; + size_t rank = dest_rank(index); + if (same_chunk_) + return pointer_ + index - rank * typed_stride_; + else + return pointer_ + index - rank_memory_offsets_[rank] / sizeof(DataTypeT); + } + + __device__ __forceinline__ void mov_offsets_to_shmem(char* shmem) + { + if (same_chunk_) return; + size_t* shmem_offsets = reinterpret_cast(shmem); + for (int i = threadIdx.x; i <= world_size_; i += blockDim.x) { + shmem_offsets[i] = rank_memory_offsets_[i]; + } + __syncthreads(); + rank_memory_offsets_ = shmem_offsets; } __device__ __forceinline__ size_t dest_rank(size_t index) { - size_t rank = index / typed_stride_; - return rank; + if (same_chunk_) { + return index / typed_stride_; + } else { + size_t rank = 0; + size_t offset = index * sizeof(DataTypeT); + if (offset >= cache_offset_ && offset < cache_offset_ + cache_size_) { + rank = cache_rank_; + } else { + int estimated_rank = max(world_size_ - 1, int(offset / estimated_stride_)); + if (rank_memory_offsets_[estimated_rank] > offset) { + for (int i = estimated_rank - 1; i >= 0; i--) { + if (rank_memory_offsets_[i] <= offset) { + rank = i; + break; + } + } + } else { + for (int i = estimated_rank + 1; i <= world_size_; i++) { + if (rank_memory_offsets_[i] > offset) { + rank = i - 1; + break; + } + } + } + cache_rank_ = rank; + cache_offset_ = rank_memory_offsets_[rank]; + cache_size_ = rank_memory_offsets_[rank + 1] - rank_memory_offsets_[rank]; + } + return rank; + } } private: DataTypeT* pointer_; size_t typed_stride_; + size_t* rank_memory_offsets_; + int world_size_; + + size_t estimated_stride_; + bool same_chunk_; + int cache_rank_; + size_t cache_offset_; + size_t cache_size_; }; } // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu index 74072280c..7ce32b24a 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,24 +31,23 @@ void nvshmem_gather_floating_int32_temp_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - nvshmem_gather_temp_get_mem_sort_idx_func( - wm_comm, - embeding_nvshmem_ptr, - embedding_desc, - indices, - indice_count, - output, - temp_output, - output_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - gather_sms); + nvshmem_gather_temp_get_mem_sort_idx_func(wm_comm, + embeding_nvshmem_ptr, + embedding_desc, + indices, + indice_count, + output, + temp_output, + output_desc, + embedding_entry_offsets, + p_env_fns, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncFloatingInt32, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_gather_floating_int32_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_gather_floating_int32_func( output, temp_output, output_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, gather_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu index 65c9fec77..6a5f42b09 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,24 +31,23 @@ void nvshmem_gather_floating_int64_temp_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - nvshmem_gather_temp_get_mem_sort_idx_func( - wm_comm, - embeding_nvshmem_ptr, - embedding_desc, - indices, - indice_count, - output, - temp_output, - output_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - gather_sms); + nvshmem_gather_temp_get_mem_sort_idx_func(wm_comm, + embeding_nvshmem_ptr, + embedding_desc, + indices, + indice_count, + output, + temp_output, + output_desc, + embedding_entry_offsets, + p_env_fns, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncFloatingInt64, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_gather_floating_int64_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_gather_floating_int64_func( output, temp_output, output_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, gather_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu index b97e47760..65b9c598e 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,24 +31,23 @@ void nvshmem_gather_integer_int32_temp_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - nvshmem_gather_temp_get_mem_sort_idx_func( - wm_comm, - embeding_nvshmem_ptr, - embedding_desc, - indices, - indice_count, - output, - temp_output, - output_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - gather_sms); + nvshmem_gather_temp_get_mem_sort_idx_func(wm_comm, + embeding_nvshmem_ptr, + embedding_desc, + indices, + indice_count, + output, + temp_output, + output_desc, + embedding_entry_offsets, + p_env_fns, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncIntegerInt32, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_gather_integer_int32_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_gather_integer_int32_func( output, temp_output, output_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, gather_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu index a1876a322..9cfad1b66 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,24 +31,23 @@ void nvshmem_gather_integer_int64_temp_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - nvshmem_gather_temp_get_mem_sort_idx_func( - wm_comm, - embeding_nvshmem_ptr, - embedding_desc, - indices, - indice_count, - output, - temp_output, - output_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - gather_sms); + nvshmem_gather_temp_get_mem_sort_idx_func(wm_comm, + embeding_nvshmem_ptr, + embedding_desc, + indices, + indice_count, + output, + temp_output, + output_desc, + embedding_entry_offsets, + p_env_fns, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncIntegerInt64, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_gather_integer_int64_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_gather_integer_int64_func( output, temp_output, output_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, gather_sms); @@ -113,7 +112,7 @@ __global__ void scatter_func_with_nvshmem_sort_idxs_kernel( const int max_blocks_for_local, const int intra_node_ranks, const int node_rank, - size_t embedding_entry_per_rank, + size_t* embedding_entry_offsets, const int threads_per_group); }; // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh index a0091c31c..8dbee95f3 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh @@ -122,22 +122,23 @@ __global__ void gather_func_with_nvshmem_sort_idxs_kernel( const int max_blocks_for_local, const int intra_node_ranks, const int node_rank, - size_t embedding_entry_per_rank, + size_t* embedding_entry_offsets, EmbeddingT* __restrict__ temp_output, wholememory_matrix_description_t output_desc, const int threads_per_group) { - const int64_t local_index_lowerbound = node_rank * intra_node_ranks * embedding_entry_per_rank; + const int64_t local_index_lowerbound = embedding_entry_offsets[node_rank * intra_node_ranks]; const int64_t local_index_upperbound = - (node_rank + 1) * intra_node_ranks * embedding_entry_per_rank; + embedding_entry_offsets[(node_rank + 1) * intra_node_ranks]; const int64_t local_index_start = LowerBound(sorted_index, indice_count, local_index_lowerbound); const int64_t local_index_length = UpperBound( sorted_index + local_index_start, indice_count - local_index_start, local_index_upperbound - 1); - int embedding_size = embedding_desc.sizes[1]; int64_t embedding_stride = embedding_desc.stride; int64_t output_stride = output_desc.stride; + extern __shared__ char shmem[]; nvshmem_device_reference embedding_nvshmem_device_ref{embeding_nvshmem_ref}; + embedding_nvshmem_device_ref.mov_offsets_to_shmem(shmem); if (blockIdx.x >= max_blocks_for_local) { const int64_t thread_id = (blockIdx.x - max_blocks_for_local) * blockDim.x + threadIdx.x; for (int64_t row_id = thread_id; row_id < indice_count - local_index_length; @@ -313,7 +314,7 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) @@ -336,7 +337,6 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, wm_comm, &thrust_allocator, stream); - int intra_node_rank_num = wm_comm->intra_node_rank_num; int node_id = wm_comm->world_rank / wm_comm->intra_node_rank_num; @@ -357,7 +357,7 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, const int, const int, const int, - size_t, + size_t*, EmbeddingT*, wholememory_matrix_description_t, const int) = nullptr; @@ -416,19 +416,21 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, block_threshold = 1; if (num_blocks == 1) num_blocks = 2; } - - gather_nvshmem_kernel_fn<<>>(embeding_nvshmem_ptr, - embedding_desc, - sorted_index, - dev_raw_indice_ptr, - indice_count, - block_threshold, - intra_node_rank_num, - node_id, - embedding_entry_count_per_rank, - ret_data, - temp_output_desc, - num_threads_per_feature); + size_t shared_mem_size = + embeding_nvshmem_ptr.same_chunk ? 0 : ((embeding_nvshmem_ptr.world_size + 1) * sizeof(size_t)); + gather_nvshmem_kernel_fn<<>>( + embeding_nvshmem_ptr, + embedding_desc, + sorted_index, + dev_raw_indice_ptr, + indice_count, + block_threshold, + intra_node_rank_num, + node_id, + embedding_entry_offsets, + ret_data, + temp_output_desc, + num_threads_per_feature); if (!use_ibgda_flag) { nvshmemx_quiet_on_stream(stream); // wait transfer } @@ -467,12 +469,12 @@ __global__ void scatter_func_with_nvshmem_sort_idxs_kernel( const int max_blocks_for_local, const int intra_node_ranks, const int node_rank, - size_t embedding_entry_per_rank, + size_t* embedding_entry_offsets, const int threads_per_group) { - const int64_t local_index_lowerbound = node_rank * intra_node_ranks * embedding_entry_per_rank; + const int64_t local_index_lowerbound = embedding_entry_offsets[node_rank * intra_node_ranks]; const int64_t local_index_upperbound = - (node_rank + 1) * intra_node_ranks * embedding_entry_per_rank; + embedding_entry_offsets[(node_rank + 1) * intra_node_ranks]; const int64_t local_index_start = LowerBound(sorted_index, indice_count, local_index_lowerbound); const int64_t local_index_length = UpperBound( sorted_index + local_index_start, indice_count - local_index_start, local_index_upperbound - 1); @@ -480,7 +482,9 @@ __global__ void scatter_func_with_nvshmem_sort_idxs_kernel( int embedding_size = embedding_desc.sizes[1]; int64_t embedding_stride = embedding_desc.stride; int64_t input_stride = temp_input_desc.stride; + extern __shared__ char shmem[]; nvshmem_device_reference embedding_nvshmem_device_ref{embeding_nvshmem_ref}; + embedding_nvshmem_device_ref.mov_offsets_to_shmem(shmem); if (blockIdx.x >= max_blocks_for_local) { const int64_t thread_id = (blockIdx.x - max_blocks_for_local) * blockDim.x + threadIdx.x; for (int64_t row_id = thread_id; row_id < indice_count - local_index_length; @@ -554,7 +558,7 @@ void nvshmem_scatter_temp_put_mem_sort_idx_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -620,7 +624,7 @@ void nvshmem_scatter_temp_put_mem_sort_idx_func(wholememory_comm_t wm_comm, const int, const int, const int, - size_t, + size_t*, const int) = nullptr; switch (alignment) { @@ -679,19 +683,21 @@ void nvshmem_scatter_temp_put_mem_sort_idx_func(wholememory_comm_t wm_comm, if (num_blocks == 1) num_blocks = 2; } - scatter_nvshmem_kernel_fn<<>>(temp_input_data, - temp_input_desc, - embeding_nvshmem_ptr, - embedding_desc, - sorted_index, - dev_raw_indice_ptr, - indice_count, - block_threshold, - intra_node_rank_num, - node_id, - embedding_entry_count_per_rank, - - num_threads_per_feature); + size_t shared_mem_size = + embeding_nvshmem_ptr.same_chunk ? 0 : ((embeding_nvshmem_ptr.world_size + 1) * sizeof(size_t)); + scatter_nvshmem_kernel_fn<<>>( + temp_input_data, + temp_input_desc, + embeding_nvshmem_ptr, + embedding_desc, + sorted_index, + dev_raw_indice_ptr, + indice_count, + block_threshold, + intra_node_rank_num, + node_id, + embedding_entry_offsets, + num_threads_per_feature); if (!use_ibgda_flag) { nvshmemx_quiet_on_stream(stream); // wait transfer } diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu index 3fe3a96fa..760df036f 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,24 +31,23 @@ void nvshmem_scatter_floating_int32_temp_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) { - nvshmem_scatter_temp_put_mem_sort_idx_func( - wm_comm, - input, - temp_input, - input_desc, - indices, - indice_count, - embeding_nvshmem_ptr, - embedding_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - scatter_sms); + nvshmem_scatter_temp_put_mem_sort_idx_func(wm_comm, + input, + temp_input, + input_desc, + indices, + indice_count, + embeding_nvshmem_ptr, + embedding_desc, + embedding_entry_offsets, + p_env_fns, + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncFloatingInt32, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int32_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int32_func( indices_desc.size, embeding_nvshmem_ptr, embedding_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu index 51107a5fc..5c22a3975 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,24 +31,23 @@ void nvshmem_scatter_floating_int64_temp_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) { - nvshmem_scatter_temp_put_mem_sort_idx_func( - wm_comm, - input, - temp_input, - input_desc, - indices, - indice_count, - embeding_nvshmem_ptr, - embedding_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - scatter_sms); + nvshmem_scatter_temp_put_mem_sort_idx_func(wm_comm, + input, + temp_input, + input_desc, + indices, + indice_count, + embeding_nvshmem_ptr, + embedding_desc, + embedding_entry_offsets, + p_env_fns, + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncFloatingInt64, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int64_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int64_func( indices_desc.size, embeding_nvshmem_ptr, embedding_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu index 4530442be..7532332e6 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,24 +31,23 @@ void nvshmem_scatter_integer_int32_temp_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) { - nvshmem_scatter_temp_put_mem_sort_idx_func( - wm_comm, - input, - temp_input, - input_desc, - indices, - indice_count, - embeding_nvshmem_ptr, - embedding_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - scatter_sms); + nvshmem_scatter_temp_put_mem_sort_idx_func(wm_comm, + input, + temp_input, + input_desc, + indices, + indice_count, + embeding_nvshmem_ptr, + embedding_desc, + embedding_entry_offsets, + p_env_fns, + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncIntegerInt32, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int32_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int32_func( indices_desc.size, embeding_nvshmem_ptr, embedding_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu index 6d98b9b6b..a17b49df7 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,24 +31,23 @@ void nvshmem_scatter_integer_int64_temp_func(wholememory_comm_t wm_comm, int64_t indice_count, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) { - nvshmem_scatter_temp_put_mem_sort_idx_func( - wm_comm, - input, - temp_input, - input_desc, - indices, - indice_count, - embeding_nvshmem_ptr, - embedding_desc, - embedding_entry_count_per_rank, - p_env_fns, - stream, - scatter_sms); + nvshmem_scatter_temp_put_mem_sort_idx_func(wm_comm, + input, + temp_input, + input_desc, + indices, + indice_count, + embeding_nvshmem_ptr, + embedding_desc, + embedding_entry_offsets, + p_env_fns, + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncIntegerInt64, @@ -65,7 +64,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms) @@ -86,7 +85,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func( indices_desc.size, embeding_nvshmem_ptr, embedding_desc, - embedding_entry_count_per_rank, + embedding_entry_offsets, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/gather_op_impl_nccl.cu b/cpp/src/wholememory_ops/gather_op_impl_nccl.cu index e842a829a..1b8abaab2 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_nccl.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_nccl.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,22 +49,9 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor wm_thrust_allocator thrust_allocator(p_env_fns); - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); size_t embedding_entry_size = element_size * wholememory_desc.stride; - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); - - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_comm; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); @@ -83,18 +70,44 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor static_cast(dev_raw_indice.device_malloc(indice_desc.size, WHOLEMEMORY_DT_INT64)); int64_t total_recv_count = 0; + + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_rank_partition_offsets(host_embedding_entry_offsets_ptr, wholememory_handle)); + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets_ptr[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets_ptr[i] /= embedding_entry_size; + } + + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice, + stream)); WHOLEMEMORY_RETURN_ON_FAIL(bucket_and_exchange_ids_func(indices, indice_desc, host_recv_rank_id_count_ptr, host_rank_id_count_ptr, &dev_recv_indice_buffer, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_comm, &thrust_allocator, p_env_fns, stream)); - // Local Gather for (int i = 0; i < world_size; i++) { total_recv_count += host_recv_rank_id_count_ptr[i]; diff --git a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu index 8a683a8c1..789dcd487 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ wholememory_error_code_t nvshmem_gather_floating_int32_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); @@ -66,7 +66,7 @@ wholememory_error_code_t nvshmem_gather_floating_int64_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); @@ -80,7 +80,7 @@ wholememory_error_code_t nvshmem_gather_integer_int64_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); @@ -93,7 +93,7 @@ wholememory_error_code_t nvshmem_gather_integer_int32_func( void* output, void* temp_output, wholememory_matrix_description_t output_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); @@ -119,24 +119,41 @@ wholememory_error_code_t wholememory_gather_nvshmem( embedding_is_float == output_is_float, "embedding and output should be same number type, e.g. floating number or integer number."); if (indice_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); - size_t embedding_entry_size = element_size * wholememory_desc.stride; + wholememory_comm_t wm_comm; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); + int world_size; + WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_comm)); - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); - wholememory_comm_t wm_comm; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_rank_partition_offsets(host_embedding_entry_offsets_ptr, wholememory_handle)); + + size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); + size_t embedding_entry_size = element_size * wholememory_desc.stride; + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets_ptr[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets_ptr[i] /= embedding_entry_size; + } + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice, + stream)); wholememory_nvshmem_ref_t embedding_nvshmem_ref; WHOLEMEMORY_RETURN_ON_FAIL( @@ -161,7 +178,7 @@ wholememory_error_code_t wholememory_gather_nvshmem( void*, void*, wholememory_matrix_description_t, - size_t, + size_t*, wholememory_env_func_t*, cudaStream_t, int) = nullptr; @@ -187,7 +204,7 @@ wholememory_error_code_t wholememory_gather_nvshmem( output, temp_output_ptr, output_desc, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, p_env_fns, stream, gather_sms); diff --git a/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu b/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu index 80dc20784..77926df9e 100644 --- a/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu +++ b/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int32_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms); @@ -63,7 +63,7 @@ wholememory_error_code_t nvshmem_scatter_floating_int64_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms); @@ -77,7 +77,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int32_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms); @@ -91,7 +91,7 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func( wholememory_array_description_t indices_desc, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, wholememory_matrix_description_t embedding_desc, - size_t embedding_entry_count_per_rank, + size_t* embedding_entry_offsets, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int scatter_sms); @@ -122,29 +122,41 @@ wholememory_error_code_t wholememory_scatter_nvshmem( return WHOLEMEMORY_INVALID_INPUT; } - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); - size_t embedding_entry_size = element_size * wholememory_desc.stride; - - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); - - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_comm; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); int world_size; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_size(&world_size, wm_comm)); - int world_rank; - WHOLEMEMORY_RETURN_ON_FAIL(wholememory_communicator_get_rank(&world_rank, wm_comm)); + + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_rank_partition_offsets(host_embedding_entry_offsets_ptr, wholememory_handle)); + + size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); + size_t embedding_entry_size = element_size * wholememory_desc.stride; + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets_ptr[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets_ptr[i] /= embedding_entry_size; + } + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice, + stream)); + wholememory_nvshmem_ref_t embedding_nvshmem_ref; WHOLEMEMORY_RETURN_ON_FAIL( wholememory_get_nvshmem_reference(&embedding_nvshmem_ref, wholememory_handle)); @@ -168,7 +180,7 @@ wholememory_error_code_t wholememory_scatter_nvshmem( wholememory_array_description_t, wholememory_nvshmem_ref_t, wholememory_matrix_description_t, - size_t, + size_t*, wholememory_env_func_t*, cudaStream_t, int); @@ -195,7 +207,7 @@ wholememory_error_code_t wholememory_scatter_nvshmem( indices_desc, embedding_nvshmem_ref, wholememory_desc, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, p_env_fns, stream, scatter_sms); diff --git a/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu b/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu index 47765de17..95e2fe6de 100644 --- a/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu +++ b/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,22 +53,9 @@ wholememory_error_code_t wholememory_scatter_nccl(void* input, wm_thrust_allocator thrust_allocator(p_env_fns); - size_t embedding_size_per_rank; - WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_get_partition_plan(&embedding_size_per_rank, wholememory_handle)); - size_t element_size = wholememory_dtype_get_element_size(wholememory_desc.dtype); size_t embedding_entry_size = element_size * wholememory_desc.stride; - WHOLEMEMORY_EXPECTS_NOTHROW( - embedding_size_per_rank % embedding_entry_size == 0, - "embedding_size_per_rank=%ld is not multiple of embedding_entry_size=%ldx%ld", - embedding_size_per_rank, - element_size, - wholememory_desc.stride); - - size_t embedding_entry_count_per_rank = embedding_size_per_rank / embedding_entry_size; - wholememory_comm_t wm_comm; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_get_communicator(&wm_comm, wholememory_handle)); @@ -87,13 +74,39 @@ wholememory_error_code_t wholememory_scatter_nccl(void* input, static_cast(dev_raw_indice.device_malloc(indices_desc.size, WHOLEMEMORY_DT_INT64)); int64_t total_recv_count = 0; + + temp_memory_handle dev_embedding_entry_offsets_handle(p_env_fns); + size_t* dev_embedding_entry_offsets_ptr = static_cast( + dev_embedding_entry_offsets_handle.device_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + temp_memory_handle host_embedding_entry_offsets_handle(p_env_fns); + size_t* host_embedding_entry_offsets_ptr = static_cast( + host_embedding_entry_offsets_handle.host_malloc(world_size + 1, WHOLEMEMORY_DT_INT64)); + + WHOLEMEMORY_RETURN_ON_FAIL( + wholememory_get_rank_partition_offsets(host_embedding_entry_offsets_ptr, wholememory_handle)); + for (int i = 0; i < world_size + 1; i++) { + size_t offset = host_embedding_entry_offsets_ptr[i]; + WHOLEMEMORY_EXPECTS_NOTHROW( + offset % embedding_entry_size == 0, + "embedding memory offset of rank%d=%ld is not multiple of embedding_entry_size=%ldx%ld", + i, + offset, + element_size, + wholememory_desc.stride); + host_embedding_entry_offsets_ptr[i] /= embedding_entry_size; + } + WM_CUDA_CHECK(cudaMemcpyAsync(dev_embedding_entry_offsets_ptr, + host_embedding_entry_offsets_ptr, + (world_size + 1) * sizeof(size_t), + cudaMemcpyHostToDevice, + stream)); WHOLEMEMORY_RETURN_ON_FAIL(bucket_and_exchange_ids_func(indices, indices_desc, host_recv_rank_id_count_ptr, host_rank_id_count_ptr, &dev_recv_indice_buffer, dev_raw_indice_ptr, - embedding_entry_count_per_rank, + dev_embedding_entry_offsets_ptr, wm_comm, &thrust_allocator, p_env_fns, diff --git a/cpp/tests/wholememory_ops/embedding_test_utils.cu b/cpp/tests/wholememory_ops/embedding_test_utils.cu index 8ee7293ed..8cd286fbd 100644 --- a/cpp/tests/wholememory_ops/embedding_test_utils.cu +++ b/cpp/tests/wholememory_ops/embedding_test_utils.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -528,5 +528,22 @@ void host_random_init_float(float* data, int64_t len, float max_value, float min } } +void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count) +{ + std::default_random_engine random_engine(0); + std::uniform_int_distribution uniform(90, 100); + size_t acc_size = 0; + size_t random_sum = 0; + for (int i = 0; i < partition_count; i++) { + partition_sizes[i] = (size_t)uniform(random_engine); + random_sum += partition_sizes[i]; + } + for (int i = 0; i < partition_count; i++) { + partition_sizes[i] = (size_t)((partition_sizes[i] / (double)random_sum) * total_size); + acc_size += partition_sizes[i]; + } + partition_sizes[0] += total_size - acc_size; +} + } // namespace testing } // namespace wholememory_ops diff --git a/cpp/tests/wholememory_ops/embedding_test_utils.hpp b/cpp/tests/wholememory_ops/embedding_test_utils.hpp index 209bc5916..62a02ce75 100644 --- a/cpp/tests/wholememory_ops/embedding_test_utils.hpp +++ b/cpp/tests/wholememory_ops/embedding_test_utils.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,5 +63,7 @@ void host_check_embedding_same(void* host_embedding, void host_random_init_float(float* data, int64_t len, float max_value, float min_value); +void host_random_partition(size_t* partition_sizes, size_t total_size, int partition_count); + } // namespace testing } // namespace wholememory_ops diff --git a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu index 83e57e14b..92ab24095 100644 --- a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu @@ -114,6 +114,7 @@ struct EmbeddingBackwardTestParams { WHOLEMEMORY_SUCCESS); return cache_policy; } + int get_rank_partition_method() const { return rank_partition_method; } EmbeddingBackwardTestParams& set_use_cache() { use_cache = true; @@ -139,6 +140,11 @@ struct EmbeddingBackwardTestParams { lr_ = lr; return *this; } + EmbeddingBackwardTestParams& use_random_partition() + { + rank_partition_method = 1; + return *this; + } wholememory_array_description_t indice_description; wholememory_matrix_description_t embedding_description; wholememory_matrix_description_t grad_description; @@ -154,6 +160,7 @@ struct EmbeddingBackwardTestParams { float lr_ = 0.1; std::map optimizer_params; + int rank_partition_method = 0; // 0-default, 1-random }; class WholeMemoryEmbeddingBackwardParameterTests @@ -586,13 +593,18 @@ TEST_P(WholeMemoryEmbeddingBackwardParameterTests, EmbeddingGatherGradientApplyT optimizer, param_name_value.first.c_str(), ¶m_name_value.second), WHOLEMEMORY_SUCCESS); } - + std::vector rank_partition(world_size); + wholememory_ops::testing::host_random_partition( + rank_partition.data(), embedding_tensor_description.sizes[0], world_size); + size_t* rank_partition_ptr = nullptr; + if (params.get_rank_partition_method() == 1) { rank_partition_ptr = rank_partition.data(); } EXPECT_EQ(wholememory_create_embedding(&wm_embedding, &embedding_tensor_description, wm_comm, params.memory_type, params.memory_location, - cache_policy), + cache_policy, + rank_partition_ptr), WHOLEMEMORY_SUCCESS); EXPECT_EQ(wholememory_embedding_set_optimizer(wm_embedding, optimizer), WHOLEMEMORY_SUCCESS); wholememory_tensor_t embedding_tensor = @@ -602,19 +614,18 @@ TEST_P(WholeMemoryEmbeddingBackwardParameterTests, EmbeddingGatherGradientApplyT WHOLEMEMORY_SUCCESS); wholememory_handle_t embedding_handle = wholememory_tensor_get_memory_handle(embedding_tensor); - auto entry_per_partition = wholememory_tensor_get_entry_per_partition(embedding_tensor); - int64_t total_entry_count = params.embedding_description.sizes[0]; - int64_t rank_start_entry = - std::min(world_rank * entry_per_partition, total_entry_count); - int64_t rank_end_entry = - std::min((world_rank + 1) * entry_per_partition, total_entry_count); - int64_t rank_entry_count = rank_end_entry - rank_start_entry; - + size_t rank_entry_count = 0; + size_t rank_start_entry = 0; + EXPECT_EQ(wholememory_tensor_get_local_entry_count(&rank_entry_count, embedding_tensor), + WHOLEMEMORY_SUCCESS); + EXPECT_EQ(wholememory_tensor_get_local_entry_start(&rank_start_entry, embedding_tensor), + WHOLEMEMORY_SUCCESS); + rank_entry_count = std::min( + rank_entry_count, params.embedding_description.sizes[0] - rank_start_entry); auto* dst_base_ptr = static_cast(wholememory_tensor_get_data_pointer(local_embed_tensor)); size_t dst_stride = wholememory_tensor_get_tensor_description(local_embed_tensor)->strides[0]; size_t embedding_copy_size = embedding_dim * sizeof(float); - for (int64_t i = 0; i < rank_entry_count; i++) { WM_CUDA_CHECK_NO_THROW(cudaMemcpy(dst_base_ptr + i * dst_stride, start_embedding_table[rank_start_entry + i].data(), @@ -738,6 +749,30 @@ INSTANTIATE_TEST_SUITE_P( #endif EmbeddingBackwardTestParams().set_entry_count(500).set_indice_count(400).set_embedding_dim(4), EmbeddingBackwardTestParams().set_embedding_dim(3), + EmbeddingBackwardTestParams() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .set_optimizer_type(WHOLEMEMORY_OPT_RMSPROP) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .set_optimizer_type(WHOLEMEMORY_OPT_ADAGRAD) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .set_optimizer_type(WHOLEMEMORY_OPT_LAZY_ADAM) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_optimizer_type(WHOLEMEMORY_OPT_RMSPROP) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_optimizer_type(WHOLEMEMORY_OPT_ADAGRAD) + .use_random_partition(), + EmbeddingBackwardTestParams() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_optimizer_type(WHOLEMEMORY_OPT_LAZY_ADAM) + .use_random_partition(), EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131), EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131).set_optimizer_type( WHOLEMEMORY_OPT_RMSPROP), diff --git a/cpp/tests/wholememory_ops/wholememory_embedding_tests.cu b/cpp/tests/wholememory_ops/wholememory_embedding_tests.cu index 03f798775..152b2de1a 100644 --- a/cpp/tests/wholememory_ops/wholememory_embedding_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_embedding_tests.cu @@ -127,6 +127,7 @@ struct EmbeddingTestParams { WHOLEMEMORY_SUCCESS); return cache_policy; } + int get_rank_partition_method() const { return rank_partition_method; } EmbeddingTestParams& non_cache() { cache_type = 0; @@ -147,6 +148,11 @@ struct EmbeddingTestParams { cache_group_count = count; return *this; } + EmbeddingTestParams& use_random_partition() + { + rank_partition_method = 1; + return *this; + } wholememory_array_description_t indice_description; wholememory_matrix_description_t embedding_description; wholememory_matrix_description_t output_description; @@ -155,8 +161,9 @@ struct EmbeddingTestParams { wholememory_memory_type_t cache_memory_type = WHOLEMEMORY_MT_CHUNKED; wholememory_memory_location_t cache_memory_location = WHOLEMEMORY_ML_DEVICE; float cache_ratio = 0.2; - int cache_type = 0; // 0: no cache, 1: device cache, 2: local cache - int cache_group_count = 1; + int cache_type = 0; // 0: no cache, 1: device cache, 2: local cache + int cache_group_count = 1; + int rank_partition_method = 0; // 0-default, 1-random }; class WholeMemoryEmbeddingParameterTests : public ::testing::TestWithParam {}; @@ -238,13 +245,18 @@ TEST_P(WholeMemoryEmbeddingParameterTests, EmbeddingGatherTest) wholememory_tensor_description_t embedding_tensor_description; wholememory_copy_matrix_desc_to_tensor(&embedding_tensor_description, ¶ms.embedding_description); - + std::vector rank_partition(world_size); + wholememory_ops::testing::host_random_partition( + rank_partition.data(), embedding_tensor_description.sizes[0], world_size); + size_t* rank_partition_ptr = nullptr; + if (params.get_rank_partition_method() == 1) { rank_partition_ptr = rank_partition.data(); } EXPECT_EQ(wholememory_create_embedding(&wm_embedding, &embedding_tensor_description, wm_comm, params.memory_type, params.memory_location, - cache_policy), + cache_policy, + rank_partition_ptr), WHOLEMEMORY_SUCCESS); wholememory_tensor_t embedding_tensor = @@ -353,6 +365,19 @@ INSTANTIATE_TEST_SUITE_P( #if 1 EmbeddingTestParams().non_cache(), EmbeddingTestParams().non_cache().set_memory_location(WHOLEMEMORY_ML_DEVICE), + EmbeddingTestParams() + .non_cache() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .use_random_partition(), + EmbeddingTestParams() + .non_cache() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .use_random_partition(), + EmbeddingTestParams() + .non_cache() + .set_memory_location(WHOLEMEMORY_ML_DEVICE) + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .use_random_partition(), EmbeddingTestParams().device_cache(), EmbeddingTestParams().device_cache().set_cache_memory_type(WHOLEMEMORY_MT_DISTRIBUTED), EmbeddingTestParams().local_cache(), diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu index ada9c87e1..f86c4b93f 100644 --- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu @@ -50,7 +50,7 @@ typedef struct WholeMemoryGatherTestParam { { return embedding_stride * wholememory_dtype_get_element_size(embedding_type); } - + int get_rank_partition_method() const { return rank_partition_method; } WholeMemoryGatherTestParam& set_memory_type(wholememory_memory_type_t new_memory_type) { memory_type = new_memory_type; @@ -109,6 +109,11 @@ typedef struct WholeMemoryGatherTestParam { distributed_backend = new_distributed_backend; return *this; } + WholeMemoryGatherTestParam& use_random_partition() + { + rank_partition_method = 1; + return *this; + } wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_CHUNKED; wholememory_memory_location_t memory_location = WHOLEMEMORY_ML_DEVICE; int64_t embedding_entry_count = 1000000LL; @@ -123,6 +128,7 @@ typedef struct WholeMemoryGatherTestParam { int64_t indices_storage_offset = 0; int64_t output_storage_offset = 0; wholememory_distributed_backend_t distributed_backend = WHOLEMEMORY_DB_NCCL; + int rank_partition_method = 0; // 0-default, 1-random } WholeMemoryGatherTestParam; class WholeMemoryGatherParameterTests @@ -164,14 +170,19 @@ TEST_P(WholeMemoryGatherParameterTests, GatherTest) auto indices_desc = params.get_indices_desc(); auto output_desc = params.get_output_desc(); size_t embedding_entry_size = params.get_embedding_granularity(); + std::vector rank_partition(world_size); + wholememory_ops::testing::host_random_partition( + rank_partition.data(), embedding_desc.sizes[0], world_size); + size_t* rank_partition_ptr = nullptr; + if (params.get_rank_partition_method() == 1) { rank_partition_ptr = rank_partition.data(); } EXPECT_EQ(wholememory_malloc(&embedding_handle, wholememory_get_memory_size_from_matrix(&embedding_desc), wm_comm, params.memory_type, params.memory_location, - embedding_entry_size), + embedding_entry_size, + rank_partition_ptr), WHOLEMEMORY_SUCCESS); - cudaStream_t stream; EXPECT_EQ(cudaStreamCreate(&stream), cudaSuccess); @@ -301,6 +312,12 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).use_random_partition(), + WholeMemoryGatherTestParam().set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED).use_random_partition(), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .use_random_partition(), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_memory_location(WHOLEMEMORY_ML_HOST) @@ -415,6 +432,10 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_distributed_backend(WHOLEMEMORY_DB_NVSHMEM), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_distributed_backend(WHOLEMEMORY_DB_NVSHMEM) + .use_random_partition(), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_embedding_dim(11) diff --git a/cpp/tests/wholememory_ops/wholememory_scatter_tests.cu b/cpp/tests/wholememory_ops/wholememory_scatter_tests.cu index a30b1b90f..656d608ea 100644 --- a/cpp/tests/wholememory_ops/wholememory_scatter_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_scatter_tests.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,6 +47,7 @@ typedef struct WholeMemoryScatterTestParam { { return embedding_stride * wholememory_dtype_get_element_size(embedding_type); } + int get_rank_partition_method() const { return rank_partition_method; } WholeMemoryScatterTestParam& set_memory_type(wholememory_memory_type_t new_memory_type) { @@ -107,6 +108,11 @@ typedef struct WholeMemoryScatterTestParam { distributed_backend = new_distributed_backend; return *this; } + WholeMemoryScatterTestParam& use_random_partition() + { + rank_partition_method = 1; + return *this; + } wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_CHUNKED; wholememory_memory_location_t memory_location = WHOLEMEMORY_ML_DEVICE; int64_t embedding_entry_count = 1000000LL; @@ -121,6 +127,7 @@ typedef struct WholeMemoryScatterTestParam { int64_t indices_storage_offset = 0; int64_t input_storage_offset = 0; wholememory_distributed_backend_t distributed_backend = WHOLEMEMORY_DB_NCCL; + int rank_partition_method = 0; // 0-default, 1-random } WholeMemoryScatterTestParam; class WholeMemoryScatterParameterTests @@ -161,12 +168,18 @@ TEST_P(WholeMemoryScatterParameterTests, ScatterTest) auto indices_desc = params.get_indices_desc(); auto input_desc = params.get_input_desc(); size_t embedding_entry_size = params.get_embedding_granularity(); + std::vector rank_partition(world_size); + wholememory_ops::testing::host_random_partition( + rank_partition.data(), embedding_desc.sizes[0], world_size); + size_t* rank_partition_ptr = nullptr; + if (params.get_rank_partition_method() == 1) { rank_partition_ptr = rank_partition.data(); } EXPECT_EQ(wholememory_malloc(&embedding_handle, wholememory_get_memory_size_from_matrix(&embedding_desc), wm_comm, params.memory_type, params.memory_location, - embedding_entry_size), + embedding_entry_size, + rank_partition_ptr), WHOLEMEMORY_SUCCESS); cudaStream_t stream; @@ -304,6 +317,14 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryScatterTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryScatterTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).use_random_partition(), + WholeMemoryScatterTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .use_random_partition(), + WholeMemoryScatterTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .use_random_partition(), WholeMemoryScatterTestParam().set_memory_type(WHOLEMEMORY_MT_CONTINUOUS).set_embedding_dim(128), WholeMemoryScatterTestParam().set_memory_type(WHOLEMEMORY_MT_CHUNKED).set_embedding_dim(128), WholeMemoryScatterTestParam() @@ -404,6 +425,10 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryScatterTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_distributed_backend(WHOLEMEMORY_DB_NVSHMEM), + WholeMemoryScatterTestParam() + .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) + .set_distributed_backend(WHOLEMEMORY_DB_NVSHMEM) + .use_random_partition(), WholeMemoryScatterTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_indices_count(0) diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index dc72eb32c..61039d83c 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -122,7 +122,8 @@ cdef extern from "wholememory/wholememory.h": wholememory_comm_t comm, wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, - size_t data_granularity) + size_t data_granularity, + size_t * rank_entry_partition) cdef wholememory_error_code_t wholememory_free(wholememory_handle_t wholememory_handle) @@ -140,26 +141,30 @@ cdef extern from "wholememory/wholememory.h": size_t * local_offset, wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_get_local_size(size_t * local_size, + wholememory_handle_t wholememory_handle) + + cdef wholememory_error_code_t wholememory_get_local_offset(size_t * local_offset, + wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, size_t * rank_memory_size, size_t * rank_memory_offset, int rank, wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_equal_entry_partition_plan(size_t* entry_per_rank, + size_t total_entry_count, + int world_size) + cdef wholememory_error_code_t wholememory_get_global_pointer(void** global_ptr, wholememory_handle_t wholememory_handle) - cdef wholememory_error_code_t wholememory_determine_partition_plan(size_t * size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size) - - cdef wholememory_error_code_t wholememory_determine_entry_partition_plan(size_t * entry_per_rank, - size_t total_entry_count, - int world_size) + cdef wholememory_error_code_t wholememory_get_rank_partition_sizes(size_t * rank_mem_sizes, + wholememory_handle_t wholememory_handle) - cdef wholememory_error_code_t wholememory_get_partition_plan(size_t * size_per_rank, - wholememory_handle_t wholememory_handle) + cdef wholememory_error_code_t wholememory_get_rank_partition_offsets(size_t * rank_mem_offsets, + wholememory_handle_t wholememory_handle) cdef int fork_get_device_count() @@ -549,7 +554,8 @@ cdef extern from "wholememory/wholememory_tensor.h": wholememory_tensor_description_t *tensor_description, wholememory_comm_t comm, wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location) + wholememory_memory_location_t memory_location, + size_t * tensor_entry_partition) cdef wholememory_error_code_t wholememory_destroy_tensor(wholememory_tensor_t wholememory_tensor) @@ -568,6 +574,18 @@ cdef extern from "wholememory/wholememory_tensor.h": cdef wholememory_tensor_description_t * wholememory_tensor_get_tensor_description( wholememory_tensor_t wholememory_tensor) + cdef wholememory_error_code_t wholememory_tensor_get_entry_offsets( + size_t * entry_offsets, wholememory_tensor_t wholememory_tensor); + + cdef wholememory_error_code_t wholememory_tensor_get_entry_partition_sizes( + size_t * entry_partition, wholememory_tensor_t wholememory_tensor); + + cdef wholememory_error_code_t wholememory_tensor_get_local_entry_count( + size_t * local_entry_count, wholememory_tensor_t wholememory_tensor); + + cdef wholememory_error_code_t wholememory_tensor_get_local_entry_start( + size_t * local_entry_start, wholememory_tensor_t wholememory_tensor); + cdef wholememory_error_code_t wholememory_tensor_get_subtensor(wholememory_tensor_t wholememory_tensor, int64_t *starts, int64_t *ends, @@ -643,6 +661,7 @@ cdef extern from "wholememory/embedding.h": wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, wholememory_embedding_cache_policy_t cache_policy, + size_t * embedding_entry_partition, int user_defined_sms, int round_robin_size) @@ -815,16 +834,21 @@ cdef class PyWholeMemoryEmbedding: WholeMemoryMemoryType memory_type, WholeMemoryMemoryLocation memory_location, WholeMemoryCachePolicy cache_policy, + cython.size_t[:] embedding_entry_partition, int user_defined_sms, int round_robin_size): self.memory_type = memory_type self.memory_location = memory_location + cdef size_t* partition_ptr = NULL + if embedding_entry_partition is not None and embedding_entry_partition.size > 0: + partition_ptr = &embedding_entry_partition[0] check_wholememory_error_code(wholememory_create_embedding(&self.wm_embedding, &tensor_desc.tensor_description, comm.comm_id, self.memory_type, self.memory_location, cache_policy.cache_policy, + partition_ptr, user_defined_sms, round_robin_size)) @@ -872,6 +896,7 @@ def create_embedding(PyWholeMemoryTensorDescription tensor_desc, WholeMemoryMemoryType memory_type, WholeMemoryMemoryLocation memory_location, WholeMemoryCachePolicy cache_policy, + cython.size_t[:] embedding_entry_partition, int user_defined_sms, int round_robin_size): wm_embedding = PyWholeMemoryEmbedding() @@ -880,6 +905,7 @@ def create_embedding(PyWholeMemoryTensorDescription tensor_desc, memory_type, memory_location, cache_policy, + embedding_entry_partition, user_defined_sms, round_robin_size) return wm_embedding @@ -1322,11 +1348,6 @@ cdef class PyWholeMemoryHandle: def get_memory_location(self): return WholeMemoryMemoryLocation(wholememory_get_memory_location(self.wholememory_handle)) - def get_partition_plan(self): - cdef size_t size_per_rank - check_wholememory_error_code(wholememory_get_partition_plan(&size_per_rank, self.wholememory_handle)) - return size_per_rank - def get_global_flatten_tensor(self, object import_dlpack_fn, WholeMemoryDataType data_type, @@ -1514,12 +1535,15 @@ cdef class PyWholeMemoryTensor: def storage_offset(self): return self.tensor_description.storage_offset - def get_partition_plan(self): - mem_size_per_rank = self.get_wholememory_handle().get_partition_plan() - element_size = wholememory_dtype_get_element_size(self.tensor_description.dtype) - vector_size = element_size * self.stride()[0] - assert mem_size_per_rank % vector_size == 0 - return mem_size_per_rank // vector_size + def get_local_entry_count(self): + cdef size_t local_entry_count = 0 + check_wholememory_error_code(wholememory_tensor_get_local_entry_count(&local_entry_count, self.wholememory_tensor)) + return local_entry_count + + def get_local_entry_start(self): + cdef size_t local_entry_start = 0 + check_wholememory_error_code(wholememory_tensor_get_local_entry_start(&local_entry_start, self.wholememory_tensor)) + return local_entry_start def get_sub_tensor(self, starts, ends): cdef int64_t start_array[2] @@ -1662,10 +1686,10 @@ def split_communicator(PyWholeMemoryComm comm,int color,int key): def communicator_set_distributed_backend(PyWholeMemoryComm py_comm,WholeMemoryDistributedBackend distributed_backend): check_wholememory_error_code(wholememory_communicator_set_distributed_backend(py_comm.comm_id,int(distributed_backend))) -def determine_partition_plan(int64_t entry_count, +def equal_partition_plan(int64_t entry_count, int world_size): cdef size_t per_rank_count - check_wholememory_error_code(wholememory_determine_entry_partition_plan(&per_rank_count, + check_wholememory_error_code(wholememory_equal_entry_partition_plan(&per_rank_count, entry_count, world_size)) return per_rank_count @@ -1674,11 +1698,15 @@ def malloc(cython.size_t total_size, PyWholeMemoryComm py_comm, WholeMemoryMemoryType memory_type, WholeMemoryMemoryLocation memory_location, - cython.size_t data_granularity): + cython.size_t data_granularity, + cython.size_t[:] rank_entry_partition=None): handle = PyWholeMemoryHandle() + cdef size_t* partition_ptr = NULL + if rank_entry_partition is not None and rank_entry_partition.size > 0: + partition_ptr = &rank_entry_partition[0] check_wholememory_error_code(wholememory_malloc(&handle.wholememory_handle, total_size, py_comm.comm_id, int(memory_type), int(memory_location), - data_granularity)) + data_granularity, partition_ptr)) return handle def free(PyWholeMemoryHandle handle): @@ -1688,18 +1716,23 @@ def create_wholememory_array(WholeMemoryDataType dtype, int64_t size, PyWholeMemoryComm comm, WholeMemoryMemoryType mem_type, - WholeMemoryMemoryLocation mem_location): + WholeMemoryMemoryLocation mem_location, + cython.size_t[:] tensor_entry_partition=None): wholememory_tensor = PyWholeMemoryTensor() wholememory_tensor.tensor_description.dtype = int(dtype) wholememory_tensor.tensor_description.storage_offset = 0 wholememory_tensor.tensor_description.dim = 1 wholememory_tensor.tensor_description.strides[0] = 1 wholememory_tensor.tensor_description.sizes[0] = size + cdef size_t* partition_ptr = NULL + if tensor_entry_partition is not None and tensor_entry_partition.size > 0: + partition_ptr = &tensor_entry_partition[0] check_wholememory_error_code(wholememory_create_tensor(&wholememory_tensor.wholememory_tensor, &wholememory_tensor.tensor_description, comm.comm_id, int(mem_type), - int(mem_location))) + int(mem_location), + partition_ptr)) return wholememory_tensor def create_wholememory_matrix(WholeMemoryDataType dtype, @@ -1708,7 +1741,8 @@ def create_wholememory_matrix(WholeMemoryDataType dtype, int64_t stride, PyWholeMemoryComm comm, WholeMemoryMemoryType mem_type, - WholeMemoryMemoryLocation mem_location): + WholeMemoryMemoryLocation mem_location, + cython.size_t[:] tensor_entry_partition=None): wholememory_tensor = PyWholeMemoryTensor() wholememory_tensor.tensor_description.dtype = int(dtype) wholememory_tensor.tensor_description.storage_offset = 0 @@ -1719,17 +1753,22 @@ def create_wholememory_matrix(WholeMemoryDataType dtype, wholememory_tensor.tensor_description.strides[1] = 1 wholememory_tensor.tensor_description.sizes[0] = row wholememory_tensor.tensor_description.sizes[1] = column + cdef size_t* partition_ptr = NULL + if tensor_entry_partition is not None and tensor_entry_partition.size > 0: + partition_ptr = &tensor_entry_partition[0] check_wholememory_error_code(wholememory_create_tensor(&wholememory_tensor.wholememory_tensor, &wholememory_tensor.tensor_description, comm.comm_id, int(mem_type), - int(mem_location))) + int(mem_location), + partition_ptr)) return wholememory_tensor def create_wholememory_tensor(PyWholeMemoryTensorDescription tensor_description, PyWholeMemoryComm comm, WholeMemoryMemoryType mem_type, - WholeMemoryMemoryLocation mem_location): + WholeMemoryMemoryLocation mem_location, + cython.size_t[:] tensor_entry_partition=None): if tensor_description.dim() != 1 and tensor_description.dim() != 2: raise NotImplementedError('WholeMemory currently only support 1D or 2D tensor') if tensor_description.stride()[tensor_description.dim() - 1] != 1: @@ -1738,11 +1777,15 @@ def create_wholememory_tensor(PyWholeMemoryTensorDescription tensor_description, raise ValueError('storage_offset be 0 when created') wholememory_tensor = PyWholeMemoryTensor() wholememory_tensor.tensor_description = tensor_description.tensor_description + cdef size_t* partition_ptr = NULL + if tensor_entry_partition is not None and tensor_entry_partition.size > 0: + partition_ptr = &tensor_entry_partition[0] check_wholememory_error_code(wholememory_create_tensor(&wholememory_tensor.wholememory_tensor, &wholememory_tensor.tensor_description, comm.comm_id, int(mem_type), - int(mem_location))) + int(mem_location), + partition_ptr)) return wholememory_tensor def make_tensor_as_wholememory(PyWholeMemoryTensorDescription tensor_description, diff --git a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py index 438a485e1..f9f87f721 100644 --- a/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py +++ b/python/pylibwholegraph/pylibwholegraph/test_utils/test_comm.py @@ -12,6 +12,7 @@ # limitations under the License. import torch +import numpy as np import pylibwholegraph.binding.wholememory_binding as wmb from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack from packaging import version @@ -129,21 +130,11 @@ def copy_host_1D_tensor_to_wholememory( torch_import_from_dlpack, wmb.WholeMemoryMemoryLocation.MlDevice, world_rank ) assert local_tensor_cuda.dim() == 1 - wm_array_size = wm_array.shape[0] - - local_start_ref = min( - wmb.determine_partition_plan(wm_array_size, world_size) * world_rank, - wm_array_size, - ) - local_end = min( - wmb.determine_partition_plan(wm_array_size, world_size) * (world_rank + 1), - wm_array_size, - ) - local_count = local_end - local_start - + local_count = wm_array.get_local_entry_count() + local_start_ref = wm_array.get_local_entry_start() assert local_start == local_start_ref assert local_tensor_cuda.shape[0] == local_count - local_tensor_cuda.copy_(host_tensor[local_start:local_end]) + local_tensor_cuda.copy_(host_tensor[local_start : local_start + local_count]) wm_comm.barrier() @@ -196,3 +187,13 @@ def int_to_wholememory_type(value: int): return wmb.WholeMemoryMemoryType.MtDistributed else: raise ValueError("invalid int_to_wholememory_type value") + + +def random_partition(total_entry_count: int, world_size: int) -> np.array: + np.random.seed(42) + random_array = np.random.uniform(90, 100, size=world_size) + random_sum = np.sum(random_array) + partition = ((random_array / random_sum) * total_entry_count).astype(np.uintp) + diff = total_entry_count - np.sum(partition) + partition[0] += diff + return partition diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py index e9bed3a5b..4ffa4ed86 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py @@ -16,6 +16,7 @@ from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack +from pylibwholegraph.test_utils.test_comm import random_partition import torch import numpy as np import os @@ -56,6 +57,7 @@ def load_routine_func( embedding_stride, storage_offset, round_robin_size=0, + entry_partition=None ): wm_comm, _ = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size @@ -78,12 +80,18 @@ def load_routine_func( + first_rank_extra_embedding_entry_count * world_size ) - per_rank_entry = wmb.determine_partition_plan(extra_embedding_count, world_size) - rank_start_entry = min(per_rank_entry * world_rank, extra_embedding_count) - rank_end_entry = min(per_rank_entry * (world_rank + 1), extra_embedding_count) - rank_entry_count = rank_end_entry - rank_start_entry + if entry_partition is None: + per_rank_entry = wmb.equal_partition_plan(extra_embedding_count, world_size) + rank_start_entry = min(per_rank_entry * world_rank, extra_embedding_count) + rank_end_entry = min(per_rank_entry * (world_rank + 1), extra_embedding_count) + rank_entry_count = rank_end_entry - rank_start_entry + else: + rank_start_entry = np.sum(entry_partition[:world_rank]) + rank_entry_count = entry_partition[world_rank] + rank_end_entry = rank_start_entry + rank_entry_count if round_robin_size != 0: + per_rank_entry = wmb.equal_partition_plan(extra_embedding_count, world_size) first_rank_extra_embedding_entry_count = embedding_entry_count % ( world_size * round_robin_size ) @@ -137,6 +145,7 @@ def load_routine_func( wm_comm, mt, ml, + entry_partition ) wholememory_tensor = wholememory_root_tensor.get_sub_tensor( @@ -173,6 +182,7 @@ def load_routine_func( @pytest.mark.parametrize("embedding_stride", [16, 32, 64]) @pytest.mark.parametrize("storage_offset", [0, 3]) @pytest.mark.parametrize("round_robin_size", [256, 1024, 0]) +@pytest.mark.parametrize("partition_method", ['random', 'default']) def test_wholememory_load( file_part_count, embedding_entry_count, @@ -180,6 +190,7 @@ def test_wholememory_load( embedding_stride, storage_offset, round_robin_size, + partition_method ): if embedding_stride < storage_offset + embedding_dim: pytest.skip( @@ -189,9 +200,18 @@ def test_wholememory_load( pytest.skip( "Skipping due to round_robin_size!=0 and storage offset !=0 , the configuration is not valid." ) + if partition_method != 'default' and round_robin_size != 0: + pytest.skip( + "Skipping due to round_robin_size!=0 and partition method != 'default' , the configuration is not valid." + ) global gpu_count if not gpu_count: gpu_count = 1 + + entry_partition = None + if partition_method == 'random': + entry_partition = random_partition(embedding_entry_count, gpu_count) + extra_embedding_count = embedding_entry_count if round_robin_size != 0: first_rank_extra_embedding_entry_count = embedding_entry_count % ( @@ -229,7 +249,7 @@ def test_wholememory_load( ) if round_robin_size != 0: - entry_per_rank = wmb.determine_partition_plan(extra_embedding_count, gpu_count) + entry_per_rank = wmb.equal_partition_plan(extra_embedding_count, gpu_count) cpu_embedding_tensor_base_extra = torch.empty( (extra_embedding_count, embedding_dim), dtype=torch.int, device="cpu" @@ -260,6 +280,7 @@ def test_wholememory_load( embedding_stride=embedding_stride, storage_offset=storage_offset, round_robin_size=round_robin_size, + entry_partition=entry_partition ) multiprocess_run(gpu_count, load_routine_func_partial) @@ -278,6 +299,7 @@ def store_routine_func( embedding_dim, embedding_stride, storage_offset, + entry_partition ): (wm_comm, _) = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size @@ -297,6 +319,7 @@ def store_routine_func( wm_comm, mt, ml, + entry_partition ) local_root_tensor, local_root_offset = wholememory_root_tensor.get_local_tensor( torch_import_from_dlpack, wmb.WholeMemoryMemoryLocation.MlHost, world_rank @@ -324,13 +347,19 @@ def store_routine_func( @pytest.mark.parametrize("embedding_dim", [16, 31, 33]) @pytest.mark.parametrize("embedding_stride", [16, 32, 64]) @pytest.mark.parametrize("storage_offset", [0, 3]) +@pytest.mark.parametrize("partition_method", ['random']) def test_wholememory_store( - embedding_entry_count, embedding_dim, embedding_stride, storage_offset + embedding_entry_count, embedding_dim, embedding_stride, storage_offset, partition_method ): if embedding_stride < storage_offset + embedding_dim: pytest.skip( "Skipping due to embedding_stride, embedding_dim and storage_offset configuration not valid." ) + + global gpu_count + entry_partition = None + if partition_method == 'random': + entry_partition = random_partition(embedding_entry_count, gpu_count) file_name_prefix = "pytest_store_temp_file" store_routine_func_partial = partial( store_routine_func, @@ -339,9 +368,9 @@ def test_wholememory_store( embedding_dim=embedding_dim, embedding_stride=embedding_stride, storage_offset=storage_offset, + entry_partition=entry_partition ) - global gpu_count multiprocess_run(gpu_count, store_routine_func_partial) embedding_entry_offset = 0 file_part_count = gpu_count diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py index f6a7b2f12..dce75fabc 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,19 +14,20 @@ import pylibwholegraph.binding.wholememory_binding as wmb from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm +from pylibwholegraph.test_utils.test_comm import random_partition # Run with: # python3 -m pytest ../tests/pylibwholegraph/test_wholememory_tensor.py -s -def array_test_case(wm_comm, dt, mt, ml, size): +def array_test_case(wm_comm, dt, mt, ml, size, entry_partition): world_rank = wm_comm.get_rank() print( "Rank=%d testing array size=%d dt=%s, mt=%s, ml=%s" % (world_rank, size, dt, mt, ml) ) - wm_array = wmb.create_wholememory_array(dt, size, wm_comm, mt, ml) + wm_array = wmb.create_wholememory_array(dt, size, wm_comm, mt, ml, entry_partition) assert wm_array.dtype == dt assert wm_array.dim() == 1 assert len(wm_array.shape) == 1 @@ -47,14 +48,14 @@ def array_test_case(wm_comm, dt, mt, ml, size): wmb.destroy_wholememory_tensor(wm_array) -def matrix_test_case(wm_comm, dt, mt, ml, mat_size): +def matrix_test_case(wm_comm, dt, mt, ml, mat_size, entry_partition): world_rank = wm_comm.get_rank() print( "Rank=%d testing matrix size=%s dt=%s, mt=%s, ml=%s" % (world_rank, mat_size, dt, mt, ml) ) wm_matrix = wmb.create_wholememory_matrix( - dt, mat_size[0], mat_size[1], -1, wm_comm, mt, ml + dt, mat_size[0], mat_size[1], -1, wm_comm, mt, ml, entry_partition ) assert wm_matrix.dtype == dt @@ -93,7 +94,8 @@ def routine_func(world_rank: int, world_size: int): single_array_size = 128 * 1024 * 1024 * world_size single_matrix_size = (1024 * 1024 * world_size, 128) dt = wmb.WholeMemoryDataType.DtFloat - + array_entry_partition = random_partition(single_array_size, world_size) + matrix_entry_partition = random_partition(single_matrix_size[0], world_size) print("") for mt in [ @@ -106,8 +108,8 @@ def routine_func(world_rank: int, world_size: int): wmb.WholeMemoryMemoryLocation.MlDevice, ]: if wm_comm.support_type_location(mt, ml): - array_test_case(wm_comm, dt, mt, ml, single_array_size) - matrix_test_case(wm_comm, dt, mt, ml, single_matrix_size) + array_test_case(wm_comm, dt, mt, ml, single_array_size, array_entry_partition) + matrix_test_case(wm_comm, dt, mt, ml, single_matrix_size, matrix_entry_partition) wmb.finalize() diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py index b4102e17c..a1fbad89e 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,6 +15,7 @@ from pylibwholegraph.utils.multiprocess import multiprocess_run from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm from pylibwholegraph.torch.dlpack_utils import torch_import_from_dlpack +from pylibwholegraph.test_utils.test_comm import random_partition import torch import pylibwholegraph.torch.wholememory_ops as wm_ops @@ -45,6 +46,7 @@ def scatter_gather_test_cast( embedding_dim, indice_count, use_python_binding=True, + entry_partition=None ): world_rank = wm_comm.get_rank() world_size = wm_comm.get_size() @@ -53,7 +55,7 @@ def scatter_gather_test_cast( % (world_rank, embedding_count, embedding_dim, indice_count, dt, mt, ml) ) wm_embedding = wmb.create_wholememory_matrix( - dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml + dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition ) scatter_indice = torch.arange( @@ -86,22 +88,15 @@ def scatter_gather_test_cast( torch_import_from_dlpack, wmb.WholeMemoryMemoryLocation.MlDevice, world_rank ) - local_ref_start = min( - wmb.determine_partition_plan(embedding_count, world_size) * world_rank, - embedding_count, - ) - local_ref_end = min( - wmb.determine_partition_plan(embedding_count, world_size) * (world_rank + 1), - embedding_count, - ) - local_ref_count = local_ref_end - local_ref_start + local_ref_start = wm_embedding.get_local_entry_start() + local_ref_count = wm_embedding.get_local_entry_count() assert local_start == local_ref_start assert local_tensor_cuda.dim() == 2 assert local_tensor_cuda.shape[0] == local_ref_count assert local_tensor_cuda.shape[1] == embedding_dim local_tensor = local_tensor_cuda.cpu() - local_indices = torch.arange(local_ref_start, local_ref_end, dtype=torch.int64) + local_indices = torch.arange(local_ref_start, local_ref_start + local_ref_count, dtype=torch.int64) local_tensor_ref = gen_int_embedding(local_indices, embedding_dim, torch.float) # print('\nlocal_tensor %s =%s\nlocal_tensor_ref %s =%s' % ( # local_tensor.shape, local_tensor, local_tensor_ref.shape, local_tensor_ref)) @@ -142,6 +137,7 @@ def routine_func(world_rank: int, world_size: int): embedding_dim = 256 indice_count = 100001 dt = wmb.WholeMemoryDataType.DtFloat + entry_partition = random_partition(embedding_count, world_size) print("") @@ -156,7 +152,7 @@ def routine_func(world_rank: int, world_size: int): ]: if wm_comm.support_type_location(mt, ml): scatter_gather_test_cast( - wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, True + wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, True, entry_partition ) # scatter_gather_test_cast(wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, False) wmb.finalize() diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py index 634508408..825c8cbaa 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py @@ -384,9 +384,10 @@ def create_embedding( sizes: List[int], *, cache_policy: Union[WholeMemoryCachePolicy, None] = None, + embedding_entry_partition: Union[List[int], None] = None, random_init: bool = False, gather_sms: int = -1, - round_robin_size: int = 0, + round_robin_size: int = 0 ): r""" Create embedding @@ -396,6 +397,9 @@ def create_embedding( :param dtype: data type :param sizes: size of the embedding, must be 2D :param cache_policy: cache policy + :param embedding_entry_partition: rank partition based on entry; embedding_entry_partition[i] determines the + entry count of rank i and shoud be a positive integer; the sum of embedding_entry_partition should equal to + total entry count; entries will be equally partitioned if None :param gather_sms: the number of SMs used in gather process :param round_robin_size: continuous embedding size of a rank using round robin shard strategy :return: WholeMemoryEmbedding @@ -415,7 +419,12 @@ def create_embedding( raise AssertionError ("The caching feature is not supported yet when using NVSHMEM." "Please consider disable it by passing cache_policy = None.") - + if embedding_entry_partition is not None and cache_policy is not None: + print("embedding_entry_partition is ignored because cache_policy is specified") + embedding_entry_partition = None + if embedding_entry_partition is not None and round_robin_size != 0: + print("round_robin_size is ignored because embedding_entry_partition is specified") + round_robin_size = 0 wm_embedding = WholeMemoryEmbedding( wmb.create_embedding( tensor_desc, @@ -423,8 +432,9 @@ def create_embedding( str_to_wmb_wholememory_memory_type(memory_type), str_to_wmb_wholememory_location(memory_location), wmb_cache_policy, + embedding_entry_partition=embedding_entry_partition, user_defined_sms=gather_sms, - round_robin_size=round_robin_size, + round_robin_size=round_robin_size ), cache_policy, ) @@ -447,8 +457,9 @@ def create_embedding_from_filelist( last_dim_size: int, *, cache_policy: Union[WholeMemoryCachePolicy, None] = None, + embedding_entry_partition: Union[List[int], None] = None, gather_sms: int = -1, - round_robin_size: int = 0, + round_robin_size: int = 0 ): r""" Create embedding from file list @@ -459,6 +470,9 @@ def create_embedding_from_filelist( :param dtype: data type :param last_dim_size: size of last dim :param cache_policy: cache policy + :param embedding_entry_partition: rank partition based on entry; embedding_entry_partition[i] determines the + entry count of rank i and shoud be a positive integer; the sum of embedding_entry_partition should equal to + total entry count; entries will be equally partitioned if None :param gather_sms: the number of SMs used in gather process :param round_robin_size: continuous embedding size of a rank using round robin shard strategy :return: @@ -466,6 +480,12 @@ def create_embedding_from_filelist( if isinstance(filelist, str): filelist = [filelist] assert last_dim_size > 0 + if embedding_entry_partition is not None and cache_policy is not None: + print("embedding_entry_partition is ignored because cache_policy is specified") + embedding_entry_partition = None + if embedding_entry_partition is not None and round_robin_size != 0: + print("round_robin_size is ignored because embedding_entry_partition is specified") + round_robin_size = 0 element_size = torch.tensor([], dtype=dtype).element_size() file_entry_size = element_size * last_dim_size total_file_size = 0 @@ -485,8 +505,9 @@ def create_embedding_from_filelist( dtype, [total_entry_count, last_dim_size], cache_policy=cache_policy, + embedding_entry_partition=embedding_entry_partition, gather_sms=gather_sms, - round_robin_size=round_robin_size, + round_robin_size=round_robin_size ) wm_embedding.get_embedding_tensor().from_filelist(filelist, round_robin_size) return wm_embedding diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index 84ee59eee..be0b1bfff 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -204,6 +204,7 @@ def create_wholememory_tensor( sizes: List[int], dtype: torch.dtype, strides: List[int], + tensor_entry_partition: Union[List[int], None] = None ): """ Create empty WholeMemory Tensor. Now only support dim = 1 or 2 @@ -213,6 +214,9 @@ def create_wholememory_tensor( :param sizes: size of the tensor :param dtype: data type of the tensor :param strides: strides of the tensor + :param tensor_entry_partition: rank partition based on entry; tensor_entry_partition[i] determines the + entry count of rank i and shoud be a positive integer; the sum of tensor_entry_partition should equal to + total entry count; entries will be equally partitioned if None :return: Allocated WholeMemoryTensor """ dim = len(sizes) @@ -235,7 +239,7 @@ def create_wholememory_tensor( wm_location = str_to_wmb_wholememory_location(memory_location) return WholeMemoryTensor( - wmb.create_wholememory_tensor(td, comm.wmb_comm, wm_memory_type, wm_location) + wmb.create_wholememory_tensor(td, comm.wmb_comm, wm_memory_type, wm_location, tensor_entry_partition) ) @@ -247,6 +251,7 @@ def create_wholememory_tensor_from_filelist( dtype: torch.dtype, last_dim_size: int = 0, last_dim_strides: int = -1, + tensor_entry_partition: Union[List[int], None] = None ): """ Create WholeMemory Tensor from list of binary files. @@ -257,6 +262,9 @@ def create_wholememory_tensor_from_filelist( :param dtype: data type of the tensor :param last_dim_size: 0 for create 1-D array, positive value for create matrix column size :param last_dim_strides: stride of last_dim, -1 for same as size of last dim. + :param tensor_entry_partition: rank partition based on entry; tensor_entry_partition[i] determines the + entry count of rank i and shoud be a positive integer; the sum of tensor_entry_partition should equal to + total entry count; entries will be equally partitioned if None :return: WholeMemoryTensor """ if isinstance(filelist, str): @@ -284,7 +292,7 @@ def create_wholememory_tensor_from_filelist( sizes = [total_entry_count, last_dim_size] strides = [last_dim_strides, 1] wm_tensor = create_wholememory_tensor( - comm, memory_type, memory_location, sizes, dtype, strides + comm, memory_type, memory_location, sizes, dtype, strides, tensor_entry_partition ) wm_tensor.from_filelist(filelist) return wm_tensor From 529b8bdb2ae3217b2597fafc84e2af7de4e2f0ee Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 7 Aug 2024 19:38:38 -0500 Subject: [PATCH 04/17] Use tool.scikit-build.cmake.version, set scikit-build-core minimum-version (#203) Contributes to https://github.com/rapidsai/build-planning/issues/58. `scikit-build-core==0.10.0` was released today (https://github.com/scikit-build/scikit-build-core/releases/tag/v0.10.0), and wheel-building configurations across RAPIDS are incompatible with it. This proposes upgrading to that version and fixing configuration here in a way that: * is compatible with that new `scikit-build-core` version * takes advantage of the forward-compatibility mechanism (`minimum-version`) that `scikit-build-core` provides, to reduce the risk of needing to do this again in the future Authors: - James Lamb (https://github.com/jameslamb) Approvers: - https://github.com/jakirkham URL: https://github.com/rapidsai/wholegraph/pull/203 --- conda/environments/all_cuda-118_arch-x86_64.yaml | 2 +- conda/environments/all_cuda-125_arch-x86_64.yaml | 2 +- conda/recipes/pylibwholegraph/conda_build_config.yaml | 2 +- dependencies.yaml | 4 ++-- python/pylibwholegraph/pyproject.toml | 5 +++-- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 100d086bc..f06e05c79 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -43,7 +43,7 @@ dependencies: - pytorch=2.0.0 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - recommonmark -- scikit-build-core>=0.7.0 +- scikit-build-core>=0.10.0 - sphinx-copybutton - sphinx-markdown-tables - sphinx<6 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 376a92ae7..f631a6cc9 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -42,7 +42,7 @@ dependencies: - python>=3.9,<3.12 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - recommonmark -- scikit-build-core>=0.7.0 +- scikit-build-core>=0.10.0 - sphinx-copybutton - sphinx-markdown-tables - sphinx<6 diff --git a/conda/recipes/pylibwholegraph/conda_build_config.yaml b/conda/recipes/pylibwholegraph/conda_build_config.yaml index 46f3a251b..b5e529cbc 100644 --- a/conda/recipes/pylibwholegraph/conda_build_config.yaml +++ b/conda/recipes/pylibwholegraph/conda_build_config.yaml @@ -14,7 +14,7 @@ cmake_version: - ">=3.26.4,!=3.30.0" scikit_build_core_version: - - ">=0.7.0" + - ">=0.10.0" c_stdlib: - sysroot diff --git a/dependencies.yaml b/dependencies.yaml index 7236b0fe8..24834d1be 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -336,10 +336,10 @@ dependencies: - rapids-build-backend>=0.3.0,<0.4.0.dev0 - output_types: conda packages: - - scikit-build-core>=0.7.0 + - scikit-build-core>=0.10.0 - output_types: [requirements, pyproject] packages: - - scikit-build-core[pyproject]>=0.7.0 + - scikit-build-core[pyproject]>=0.10.0 python_build_wheel: common: - output_types: [pyproject] diff --git a/python/pylibwholegraph/pyproject.toml b/python/pylibwholegraph/pyproject.toml index 19b48cb9f..20b0b3aa9 100644 --- a/python/pylibwholegraph/pyproject.toml +++ b/python/pylibwholegraph/pyproject.toml @@ -16,7 +16,7 @@ build-backend = "rapids_build_backend.build" requires = [ "rapids-build-backend>=0.3.0,<0.4.0.dev0", - "scikit-build-core[pyproject]>=0.7.0", + "scikit-build-core[pyproject]>=0.10.0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project] @@ -52,7 +52,8 @@ requires = [ [tool.scikit-build] build-dir = "build/{wheel_tag}" cmake.build-type = "Release" -cmake.minimum-version = "3.26.4" +cmake.version = "CMakeLists.txt" +minimum-version = "build-system.requires" ninja.make-fallback = true sdist.exclude = ["*tests*"] sdist.reproducible = true From 928b5d69bcb6475f0ae5d46b1a873f374f4674f7 Mon Sep 17 00:00:00 2001 From: Tommy Li Date: Thu, 8 Aug 2024 05:51:18 -0700 Subject: [PATCH 05/17] Add horovodrun launch agent for Wholegraph (#200) We have many users running the [Kubeflow training operator](https://github.com/kubeflow/training-operator) who are also interested in using Wholegraph. For our MPIJobs users, many of them still use [HorovodRun](https://github.com/horovod/horovod/tree/master) as the startup command. Therefore, we want to add HorovodRun as one of the Wholegraph launch agents so our users can use Wholegraph on top of Kubeflow. The new function will be similar to the existing MPI launcher agent, where the horovod library is only imported on demand. The horovod.tensorflow library will be used solely for the Horovod initialization command due to the issue with horovod.torch (see https://github.com/horovod/horovod/issues/4009). After the Horovod initialization, the program can continue to run normal PyTorch code within each rank just like the mpi4py. fixes #201 Authors: - Tommy Li (https://github.com/Tomcli) Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/200 --- .../torch/distributed_launch.py | 42 ++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/distributed_launch.py b/python/pylibwholegraph/pylibwholegraph/torch/distributed_launch.py index 8fbcbfa4d..7abb5067c 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/distributed_launch.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/distributed_launch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -86,7 +86,7 @@ def add_distributed_launch_options(parser: ArgumentParser): "--launch-agent", dest="launch_agent", default="mpi", - help="launch agent used, mpi, pytorch or spawn", + help="launch agent used, mpi, horovodrun, pytorch or spawn", ) # command line flags parser.add_argument( @@ -215,6 +215,38 @@ def distributed_launch_mpi(args, main_func): main_func() +def distributed_launch_horovodrun(args, main_func): + # Horovod is used to launch the job and set up the distributed environment, not to run + # framework-specific distributed training code. + # + # Using horovod.tensorflow for launching Kubeflow MPIJobs with Horovodrun. + # horovod.torch is not used because it is not compatible with certain versions of PyTorch. + import horovod.tensorflow as hvd + + hvd.init() + + global distributed_config + distributed_config.rank = hvd.rank() + distributed_config.world_size = hvd.size() + distributed_config.local_rank = hvd.local_rank() + distributed_config.local_size = hvd.local_size() + distributed_config.master_addr = get_value_from_option_and_env( + args.master_addr, args.launch_env_name_master_addr, "", "localhost" + ) + distributed_config.master_port = int( + get_value_from_option_and_env( + args.master_port, args.launch_env_name_master_port, -1, 12335 + ) + ) + + os.environ["RANK"] = str(distributed_config.rank) + os.environ["WORLD_SIZE"] = str(distributed_config.world_size) + os.environ["MASTER_ADDR"] = distributed_config.master_addr + os.environ["MASTER_PORT"] = str(distributed_config.master_port) + + main_func() + + def distributed_launch_pytorch( args, main_func, @@ -310,6 +342,7 @@ def distributed_launch_spawn(args, main_func): def distributed_launch(args, main_func): assert ( args.launch_agent == "mpi" + or args.launch_agent == "horovodrun" or args.launch_agent == "pytorch" or args.launch_agent == "spawn" ) @@ -318,6 +351,11 @@ def distributed_launch(args, main_func): # when using MPI, command is like: # mpirun python [train_script.py] distributed_launch_mpi(args, main_func) + elif args.launch_agent == "horovodrun": + # use horovodrun to launch multiprocess + # when using horovodrun, command is like: + # horovodrun python [train_script.py] --launch_agent=horovodrun + distributed_launch_horovodrun(args, main_func) elif args.launch_agent == "pytorch": # use pytorch DDP to launch multiprocess # when using pytorch DDP, assume two nodes with 8 GPU each, command is like: From e616d9880bbcf198f83b046370064c79520648e2 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Thu, 8 Aug 2024 10:24:27 -0500 Subject: [PATCH 06/17] Improve update-version.sh (#204) A few small tweaks to `update-version.sh` for alignment across RAPIDS. This PR removes the `UCX_PY` version HTTP call from `update-version.sh` because it is not used. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/wholegraph/pull/204 --- ci/release/update-version.sh | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 1ba99e790..607ce165b 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -26,16 +26,14 @@ CURRENT_MINOR=$(echo $CURRENT_TAG | awk '{split($0, a, "."); print a[2]}') CURRENT_PATCH=$(echo $CURRENT_TAG | awk '{split($0, a, "."); print a[3]}') CURRENT_SHORT_TAG=${CURRENT_MAJOR}.${CURRENT_MINOR} -#Get . for next version +# Get . for next version NEXT_MAJOR=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[1]}') NEXT_MINOR=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[2]}') NEXT_SHORT_TAG=${NEXT_MAJOR}.${NEXT_MINOR} -NEXT_UCX_PY_VERSION="$(curl -sL https://version.gpuci.io/rapids/${NEXT_SHORT_TAG}).*" # Need to distutils-normalize the versions for some use cases CURRENT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${CURRENT_SHORT_TAG}'))") NEXT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${NEXT_SHORT_TAG}'))") -echo "current is ${CURRENT_SHORT_TAG_PEP440}, next is ${NEXT_SHORT_TAG_PEP440}" echo "Preparing release $CURRENT_TAG => $NEXT_FULL_TAG" @@ -68,17 +66,16 @@ DEPENDENCIES=( ) for DEP in "${DEPENDENCIES[@]}"; do for FILE in dependencies.yaml conda/environments/*.yaml; do - sed_runner "/-.* ${DEP}\(-cu[[:digit:]]\{2\}\)\{0,1\}==/ s/==.*/==${NEXT_SHORT_TAG_PEP440}.*,>=0.0.0a0/g" ${FILE} + sed_runner "/-.* ${DEP}\(-cu[[:digit:]]\{2\}\)\{0,1\}==/ s/==.*/==${NEXT_SHORT_TAG_PEP440}.*,>=0.0.0a0/g" "${FILE}" done for FILE in python/**/pyproject.toml; do - sed_runner "/\"${DEP}\(-cu[[:digit:]]\{2\}\)\{0,1\}==/ s/==.*\"/==${NEXT_SHORT_TAG_PEP440}.*,>=0.0.0a0\"/g" ${FILE} + sed_runner "/\"${DEP}\(-cu[[:digit:]]\{2\}\)\{0,1\}==/ s/==.*\"/==${NEXT_SHORT_TAG_PEP440}.*,>=0.0.0a0\"/g" "${FILE}" done done # Doxyfile update sed_runner "/^PROJECT_NUMBER/ s|=.*|= ${NEXT_SHORT_TAG}|" cpp/Doxyfile - # CI files for FILE in .github/workflows/*.yaml; do sed_runner "/shared-workflows/ s/@.*/@branch-${NEXT_SHORT_TAG}/g" "${FILE}" From 604d7a8d0260fff83ec36ffcfae97c7de4e3f853 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 8 Aug 2024 11:49:52 -0400 Subject: [PATCH 07/17] Update pre-commit hooks (#206) This PR updates pre-commit hooks to the latest versions that are supported without causing style check errors. Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/wholegraph/pull/206 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 56f2ca814..68680215b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: types_or: [c, c++, cuda] args: ["-fallback-style=none", "-style=file", "-i"] - repo: https://github.com/rapidsai/pre-commit-hooks - rev: v0.2.0 + rev: v0.3.1 hooks: - id: verify-copyright files: | From bc1389907f8b5be7205c8d30c5f50655ff280a9a Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Fri, 16 Aug 2024 22:28:44 +0800 Subject: [PATCH 08/17] fix_mnnvl_with_uuid (#207) Authors: - Chuang Zhu (https://github.com/chuangz0) Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/207 --- cpp/src/wholememory/communicator.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/cpp/src/wholememory/communicator.cpp b/cpp/src/wholememory/communicator.cpp index d08fe0804..dabb9ba1b 100644 --- a/cpp/src/wholememory/communicator.cpp +++ b/cpp/src/wholememory/communicator.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -557,7 +558,7 @@ void exchange_rank_info(wholememory_comm_t wm_comm) wm_comm->clique_info.clique_rank = -1; wm_comm->clique_info.clique_rank_num = 0; - std::set clique_ids{}; + std::set clique_uuids{}; for (int r = 0; r < wm_comm->world_size; r++) { WHOLEMEMORY_CHECK(r == p_rank_info.get()[r].rank); @@ -583,16 +584,21 @@ void exchange_rank_info(wholememory_comm_t wm_comm) if (wm_comm->clique_info.clique_rank_num == 0) { wm_comm->clique_info.clique_first_rank = r; } wm_comm->clique_info.clique_rank_num++; } - clique_ids.insert(p_rank_info.get()[r].fabric_info.cliqueId); + clique_uuids.insert( + std::string(reinterpret_cast(p_rank_info.get()[r].fabric_info.clusterUuid), + NVML_GPU_FABRIC_UUID_LEN)); #endif } #if CUDA_VERSION >= 12030 - wm_comm->clique_info.clique_num = clique_ids.size(); - int id = 0; - for (auto clique_id : clique_ids) { - if (clique_id == ri.fabric_info.cliqueId) { wm_comm->clique_info.clique_id = id; } + wm_comm->clique_info.clique_num = clique_uuids.size(); + + std::string uuid = std::string(reinterpret_cast(ri.fabric_info.clusterUuid), + NVML_GPU_FABRIC_UUID_LEN); + int id = 0; + for (auto clique_uuid : clique_uuids) { + if (clique_uuid == uuid) { wm_comm->clique_info.clique_id = id; } id++; } From 9b877a45519d3aba707884fddf9c3a837ac32719 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 23 Aug 2024 15:13:15 -0400 Subject: [PATCH 09/17] Drop Python 3.9 support (#209) Contributes to https://github.com/rapidsai/build-planning/issues/88 Finishes the work of dropping Python 3.9 support. This project stopped building / testing against Python 3.9 as of https://github.com/rapidsai/shared-workflows/pull/235. This PR updates configuration and docs to reflect that. ## Notes for Reviewers ### How I tested this Checked that there were no remaining uses like this: ```shell git grep -E '3\.9' git grep '39' git grep 'py39' ``` And similar for variations on Python 3.8 (to catch things that were missed the last time this was done). Authors: - James Lamb (https://github.com/jameslamb) Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/wholegraph/pull/209 --- conda/environments/all_cuda-118_arch-x86_64.yaml | 2 +- conda/environments/all_cuda-125_arch-x86_64.yaml | 2 +- dependencies.yaml | 6 +----- python/pylibwholegraph/pyproject.toml | 6 +----- 4 files changed, 4 insertions(+), 12 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index f06e05c79..a225b5343 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -38,7 +38,7 @@ dependencies: - pytest - pytest-forked - pytest-xdist -- python>=3.9,<3.12 +- python>=3.10,<3.12 - pytorch-cuda=11.8 - pytorch=2.0.0 - rapids-build-backend>=0.3.0,<0.4.0.dev0 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index f631a6cc9..835d35839 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -39,7 +39,7 @@ dependencies: - pytest - pytest-forked - pytest-xdist -- python>=3.9,<3.12 +- python>=3.10,<3.12 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - recommonmark - scikit-build-core>=0.10.0 diff --git a/dependencies.yaml b/dependencies.yaml index 24834d1be..3ceafa51c 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -193,10 +193,6 @@ dependencies: specific: - output_types: conda matrices: - - matrix: - py: "3.9" - packages: - - python=3.9 - matrix: py: "3.10" packages: @@ -207,7 +203,7 @@ dependencies: - python=3.11 - matrix: packages: - - python>=3.9,<3.12 + - python>=3.10,<3.12 run: common: - output_types: [conda, requirements] diff --git a/python/pylibwholegraph/pyproject.toml b/python/pylibwholegraph/pyproject.toml index 20b0b3aa9..353911508 100644 --- a/python/pylibwholegraph/pyproject.toml +++ b/python/pylibwholegraph/pyproject.toml @@ -27,15 +27,11 @@ authors = [ { name = "NVIDIA Corporation" }, ] license = { text = "Apache 2.0" } -requires-python = ">=3.6" +requires-python = ">=3.10" classifiers = [ "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] From 44e8a4c852c3dcd83bd2e5252758e45f99866fa9 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Sat, 24 Aug 2024 00:41:30 +0200 Subject: [PATCH 10/17] Remove NumPy <2 pin (#208) This PR removes the NumPy<2 pin. `wholegraph` does not appear to be a heavy user of NumPy or CuPy, so it should be fine to simply remove the pin. For other RAPIDS projects with heavier dependency, CuPy 13.3.0 was required (just released) to have sufficient good CuPy/NumPy interoperability. Authors: - Sebastian Berg (https://github.com/seberg) Approvers: - https://github.com/jakirkham URL: https://github.com/rapidsai/wholegraph/pull/208 --- conda/environments/all_cuda-118_arch-x86_64.yaml | 2 +- conda/environments/all_cuda-125_arch-x86_64.yaml | 2 +- dependencies.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index a225b5343..fe6feb654 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -30,7 +30,7 @@ dependencies: - nbsphinx - nccl - ninja -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - nvcc_linux-64=11.8 - pre-commit diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 835d35839..7636a6ce8 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -32,7 +32,7 @@ dependencies: - nbsphinx - nccl - ninja -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - pre-commit - pydata-sphinx-theme diff --git a/dependencies.yaml b/dependencies.yaml index 3ceafa51c..c5585bc3f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -223,7 +223,7 @@ dependencies: - output_types: [conda, requirements] packages: - ninja - - numpy>=1.23,<2.0a0 + - numpy>=1.23,<3.0a0 - pytest - pytest-forked - pytest-xdist From 31f24e8210e5d35220598878df071ca14d808814 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 27 Aug 2024 13:39:15 -0400 Subject: [PATCH 11/17] Update rapidsai/pre-commit-hooks (#213) This PR updates rapidsai/pre-commit-hooks to the version 0.4.0. Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/wholegraph/pull/213 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 68680215b..66696b06f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: types_or: [c, c++, cuda] args: ["-fallback-style=none", "-style=file", "-i"] - repo: https://github.com/rapidsai/pre-commit-hooks - rev: v0.3.1 + rev: v0.4.0 hooks: - id: verify-copyright files: | From 92d680287cbcfd655a51885b4c3e13cef36758a3 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 4 Sep 2024 21:58:07 -0500 Subject: [PATCH 12/17] Add support for Python 3.12 (#214) Contributes to https://github.com/rapidsai/build-planning/issues/40 This PR adds support for Python 3.12. ## Notes for Reviewers This is part of ongoing work to add Python 3.12 support across RAPIDS. It temporarily introduces a build/test matrix including Python 3.12, from https://github.com/rapidsai/shared-workflows/pull/213. A follow-up PR will revert back to pointing at the `branch-24.10` branch of `shared-workflows` once all RAPIDS repos have added Python 3.12 support. ### This will fail until all dependencies have been updates to Python 3.12 CI here is expected to fail until all of this project's upstream dependencies support Python 3.12. This can be merged whenever all CI jobs are passing. Authors: - James Lamb (https://github.com/jameslamb) Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/wholegraph/pull/214 --- .github/workflows/build.yaml | 12 ++++++------ .github/workflows/pr.yaml | 18 +++++++++--------- .github/workflows/test.yaml | 6 +++--- .../environments/all_cuda-118_arch-x86_64.yaml | 2 +- .../environments/all_cuda-125_arch-x86_64.yaml | 2 +- dependencies.yaml | 6 +++++- python/pylibwholegraph/pyproject.toml | 1 + 7 files changed, 26 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 454185f4a..2c5ceeb6a 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -38,7 +38,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -49,7 +49,7 @@ jobs: if: github.ref_type == 'branch' needs: [python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@python-3.12 with: arch: "amd64" branch: ${{ inputs.branch }} @@ -62,7 +62,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -70,7 +70,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -80,7 +80,7 @@ jobs: wheel-publish-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@python-3.12 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 6e5c86c54..ec7848ddf 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -21,41 +21,41 @@ jobs: - wheel-build-pylibwholegraph - wheel-test-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@python-3.12 checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@python-3.12 with: enable_check_generated_files: false conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@python-3.12 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@python-3.12 with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@python-3.12 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@python-3.12 with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@python-3.12 with: build_type: pull-request arch: "amd64" @@ -64,14 +64,14 @@ jobs: wheel-build-pylibwholegraph: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@python-3.12 with: build_type: pull-request script: ci/build_wheel.sh wheel-test-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@python-3.12 with: build_type: pull-request script: ci/test_wheel.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f2d7e1cc7..e65996647 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@python-3.12 with: build_type: nightly branch: ${{ inputs.branch }} @@ -24,7 +24,7 @@ jobs: sha: ${{ inputs.sha }} conda-pytorch-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@python-3.12 with: build_type: nightly branch: ${{ inputs.branch }} @@ -32,7 +32,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.10 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@python-3.12 with: build_type: nightly branch: ${{ inputs.branch }} diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index fe6feb654..d989d880e 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -38,7 +38,7 @@ dependencies: - pytest - pytest-forked - pytest-xdist -- python>=3.10,<3.12 +- python>=3.10,<3.13 - pytorch-cuda=11.8 - pytorch=2.0.0 - rapids-build-backend>=0.3.0,<0.4.0.dev0 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 7636a6ce8..5b152cd31 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -39,7 +39,7 @@ dependencies: - pytest - pytest-forked - pytest-xdist -- python>=3.10,<3.12 +- python>=3.10,<3.13 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - recommonmark - scikit-build-core>=0.10.0 diff --git a/dependencies.yaml b/dependencies.yaml index c5585bc3f..8aaf92cd9 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -202,8 +202,12 @@ dependencies: packages: - python=3.11 - matrix: + py: "3.12" packages: - - python>=3.10,<3.12 + - python=3.12 + - matrix: + packages: + - python>=3.10,<3.13 run: common: - output_types: [conda, requirements] diff --git a/python/pylibwholegraph/pyproject.toml b/python/pylibwholegraph/pyproject.toml index 353911508..0c4233529 100644 --- a/python/pylibwholegraph/pyproject.toml +++ b/python/pylibwholegraph/pyproject.toml @@ -34,6 +34,7 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] [tool.rapids-build-backend] From dd25ee69d93c70f84c2e7502e37f1f5e294932dd Mon Sep 17 00:00:00 2001 From: Ray Douglass <3107146+raydouglass@users.noreply.github.com> Date: Mon, 9 Sep 2024 12:16:20 -0400 Subject: [PATCH 13/17] Ensure pylibwholegraph conda packages have the license (#215) Just adds the existing license to the `pylibwholegraph` conda recipe. Authors: - Ray Douglass (https://github.com/raydouglass) Approvers: - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/wholegraph/pull/215 --- conda/recipes/pylibwholegraph/meta.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/conda/recipes/pylibwholegraph/meta.yaml b/conda/recipes/pylibwholegraph/meta.yaml index d3f9a49b7..8149fdc95 100644 --- a/conda/recipes/pylibwholegraph/meta.yaml +++ b/conda/recipes/pylibwholegraph/meta.yaml @@ -75,4 +75,6 @@ requirements: about: home: https://rapids.ai/ + license: Apache-2.0 + license_file: ../../../LICENSE summary: pylibwholegraph library From b33326f2ccd285add01ae27c0843511392a106c8 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 18 Sep 2024 14:44:10 -0500 Subject: [PATCH 14/17] Use CI workflow branch 'branch-24.10' again [skip ci] (#216) --- .github/workflows/build.yaml | 12 ++++++------ .github/workflows/pr.yaml | 18 +++++++++--------- .github/workflows/test.yaml | 6 +++--- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 2c5ceeb6a..454185f4a 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -38,7 +38,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -49,7 +49,7 @@ jobs: if: github.ref_type == 'branch' needs: [python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 with: arch: "amd64" branch: ${{ inputs.branch }} @@ -62,7 +62,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -70,7 +70,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -80,7 +80,7 @@ jobs: wheel-publish-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index ec7848ddf..6e5c86c54 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -21,41 +21,41 @@ jobs: - wheel-build-pylibwholegraph - wheel-test-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.10 checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.10 with: enable_check_generated_files: false conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.10 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.10 with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.10 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.10 with: build_type: pull-request arch: "amd64" @@ -64,14 +64,14 @@ jobs: wheel-build-pylibwholegraph: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.10 with: build_type: pull-request script: ci/build_wheel.sh wheel-test-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.10 with: build_type: pull-request script: ci/test_wheel.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e65996647..f2d7e1cc7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.10 with: build_type: nightly branch: ${{ inputs.branch }} @@ -24,7 +24,7 @@ jobs: sha: ${{ inputs.sha }} conda-pytorch-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.10 with: build_type: nightly branch: ${{ inputs.branch }} @@ -32,7 +32,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@python-3.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.10 with: build_type: nightly branch: ${{ inputs.branch }} From 09e90be8894833757d3d0a6119d0e24260b0d129 Mon Sep 17 00:00:00 2001 From: Jake Awe <50372925+AyodeAwe@users.noreply.github.com> Date: Tue, 24 Sep 2024 14:11:42 -0500 Subject: [PATCH 15/17] update update-version.sh to use packaging lib (#219) --- ci/release/update-version.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 607ce165b..3b66b2304 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -32,8 +32,8 @@ NEXT_MINOR=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[2]}') NEXT_SHORT_TAG=${NEXT_MAJOR}.${NEXT_MINOR} # Need to distutils-normalize the versions for some use cases -CURRENT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${CURRENT_SHORT_TAG}'))") -NEXT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${NEXT_SHORT_TAG}'))") +CURRENT_SHORT_TAG_PEP440=$(python -c "from packaging.version import Version; print(Version('${CURRENT_SHORT_TAG}'))") +NEXT_SHORT_TAG_PEP440=$(python -c "from packaging.version import Version; print(Version('${NEXT_SHORT_TAG}'))") echo "Preparing release $CURRENT_TAG => $NEXT_FULL_TAG" From 73266e23ddfacf5fee3091b8da2a44af851a68cb Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 25 Sep 2024 09:11:35 -0500 Subject: [PATCH 16/17] bump NCCL floor to 2.18.1.1, relax PyTorch pin (#218) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Contributes to https://github.com/rapidsai/build-planning/issues/102 Fixes #217 ## Notes for Reviewers ### How I tested this Temporarily added a CUDA 11.4.3 test job to CI here (the same specs as the failing nightly), by pointing at the branch from https://github.com/rapidsai/shared-workflows/pull/246. Observed the exact same failures with CUDA 11.4 reported in https://github.com/rapidsai/build-planning/issues/102. ```text ... + nccl 2.10.3.1 hcad2f07_0 rapidsai-nightly 125MB ... ./WHOLEGRAPH_CSR_WEIGHTED_SAMPLE_WITHOUT_REPLACEMENT_TEST: symbol lookup error: /opt/conda/envs/test/bin/gtests/libwholegraph/../../../lib/libwholegraph.so: undefined symbol: ncclCommSplit sh -c exec "$0" ./WHOLEMEMORY_HANDLE_TEST ./WHOLEMEMORY_HANDLE_TEST: symbol lookup error: /opt/conda/envs/test/bin/gtests/libwholegraph/../../../lib/libwholegraph.so: undefined symbol: ncclCommSplit sh -c exec "$0" ./GRAPH_APPEND_UNIQUE_TEST ``` ([build link](https://github.com/rapidsai/wholegraph/actions/runs/10966022370/job/30453393224?pr=218)) Pushed a commit adding a floor of `nccl>=2.18.1.1`. Saw all tests pass with CUDA 11.4 😁 ```text ... + nccl 2.22.3.1 hee583db_1 conda-forge 131MB ... (various log messages showing all tests passed) ``` ([build link](https://github.com/rapidsai/wholegraph/actions/runs/10966210441/job/30454147250?pr=218)) Authors: - James Lamb (https://github.com/jameslamb) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - https://github.com/linhu-nv - https://github.com/jakirkham URL: https://github.com/rapidsai/wholegraph/pull/218 --- conda/environments/all_cuda-118_arch-x86_64.yaml | 4 ++-- conda/environments/all_cuda-125_arch-x86_64.yaml | 2 +- conda/recipes/libwholegraph/conda_build_config.yaml | 2 +- dependencies.yaml | 12 ++++++------ 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index d989d880e..f20d98977 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -28,7 +28,7 @@ dependencies: - librmm==24.10.*,>=0.0.0a0 - nanobind>=0.2.0 - nbsphinx -- nccl +- nccl>=2.18.1.1 - ninja - numpy>=1.23,<3.0a0 - numpydoc @@ -40,7 +40,7 @@ dependencies: - pytest-xdist - python>=3.10,<3.13 - pytorch-cuda=11.8 -- pytorch=2.0.0 +- pytorch>=2.0,<2.4.0a0 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - recommonmark - scikit-build-core>=0.10.0 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 5b152cd31..5988a9893 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -30,7 +30,7 @@ dependencies: - librmm==24.10.*,>=0.0.0a0 - nanobind>=0.2.0 - nbsphinx -- nccl +- nccl>=2.18.1.1 - ninja - numpy>=1.23,<3.0a0 - numpydoc diff --git a/conda/recipes/libwholegraph/conda_build_config.yaml b/conda/recipes/libwholegraph/conda_build_config.yaml index 35b1d6b62..8b6dd3439 100644 --- a/conda/recipes/libwholegraph/conda_build_config.yaml +++ b/conda/recipes/libwholegraph/conda_build_config.yaml @@ -17,7 +17,7 @@ doxygen_version: - ">=1.8.11" nccl_version: - - ">=2.9.9" + - ">=2.18.1.1" c_stdlib: - sysroot diff --git a/dependencies.yaml b/dependencies.yaml index 8aaf92cd9..950e1979a 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -87,7 +87,7 @@ dependencies: - libraft-headers==24.10.*,>=0.0.0a0 - librmm==24.10.*,>=0.0.0a0 - nanobind>=0.2.0 - - nccl + - &nccl nccl>=2.18.1.1 specific: - output_types: conda matrices: @@ -216,14 +216,14 @@ dependencies: common: - output_types: [conda] packages: - - nccl + - *nccl test_python: common: - output_types: [conda] packages: - c-compiler - cxx-compiler - - nccl + - *nccl - output_types: [conda, requirements] packages: - ninja @@ -285,13 +285,13 @@ dependencies: # If conda-forge supports the new cuda-* packages for CUDA 11.8 # at some point, then we can fully support/properly specify # this environment. - - pytorch=2.0.0 + - &pytorch pytorch>=2.0,<2.4.0a0 - pytorch-cuda=11.8 - matrix: arch: aarch64 cuda: "11.8" packages: - - pytorch=2.0.0 + - *pytorch - pytorch-cuda=11.8 - matrix: packages: @@ -318,7 +318,7 @@ dependencies: common: - output_types: [conda] packages: - - pytorch=2.0.0 + - *pytorch - cpuonly clang_tools: common: From f8ad9a1167c51fbb0dc9f82d2364cef62f46ec27 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 26 Sep 2024 15:18:35 -0500 Subject: [PATCH 17/17] bump NCCL floor to 2.19 (#223) Follow-up to #218 This bumps the NCCL floor here slightly higher, to `>=2.19`. Part of a RAPIDS-wide update of that floor for the 24.10 release. See https://github.com/rapidsai/build-planning/issues/102#issuecomment-2375595743 for context. cc @linhu-nv for awareness Authors: - James Lamb (https://github.com/jameslamb) Approvers: - https://github.com/jakirkham URL: https://github.com/rapidsai/wholegraph/pull/223 --- conda/environments/all_cuda-118_arch-x86_64.yaml | 2 +- conda/environments/all_cuda-125_arch-x86_64.yaml | 2 +- conda/recipes/libwholegraph/conda_build_config.yaml | 2 +- dependencies.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index f20d98977..ec0452748 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -28,7 +28,7 @@ dependencies: - librmm==24.10.*,>=0.0.0a0 - nanobind>=0.2.0 - nbsphinx -- nccl>=2.18.1.1 +- nccl>=2.19 - ninja - numpy>=1.23,<3.0a0 - numpydoc diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 5988a9893..8ec405faf 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -30,7 +30,7 @@ dependencies: - librmm==24.10.*,>=0.0.0a0 - nanobind>=0.2.0 - nbsphinx -- nccl>=2.18.1.1 +- nccl>=2.19 - ninja - numpy>=1.23,<3.0a0 - numpydoc diff --git a/conda/recipes/libwholegraph/conda_build_config.yaml b/conda/recipes/libwholegraph/conda_build_config.yaml index 8b6dd3439..ebb154c29 100644 --- a/conda/recipes/libwholegraph/conda_build_config.yaml +++ b/conda/recipes/libwholegraph/conda_build_config.yaml @@ -17,7 +17,7 @@ doxygen_version: - ">=1.8.11" nccl_version: - - ">=2.18.1.1" + - ">=2.19" c_stdlib: - sysroot diff --git a/dependencies.yaml b/dependencies.yaml index 950e1979a..70542003d 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -87,7 +87,7 @@ dependencies: - libraft-headers==24.10.*,>=0.0.0a0 - librmm==24.10.*,>=0.0.0a0 - nanobind>=0.2.0 - - &nccl nccl>=2.18.1.1 + - &nccl nccl>=2.19 specific: - output_types: conda matrices: