From d589258e2897fccd834ee68ee28fd1729a19d2b2 Mon Sep 17 00:00:00 2001 From: Ray Douglass Date: Thu, 9 Nov 2023 16:34:03 -0500 Subject: [PATCH 01/16] v24.02 Updates [skip ci] --- .github/workflows/build.yaml | 12 ++++++------ .github/workflows/pr.yaml | 18 +++++++++--------- .github/workflows/test.yaml | 6 +++--- VERSION | 2 +- ci/build_docs.sh | 2 +- .../environments/all_cuda-118_arch-x86_64.yaml | 4 ++-- .../environments/all_cuda-120_arch-x86_64.yaml | 4 ++-- cpp/CMakeLists.txt | 2 +- cpp/Doxyfile | 2 +- dependencies.yaml | 4 ++-- fetch_rapids.cmake | 2 +- python/pylibwholegraph/CMakeLists.txt | 2 +- 12 files changed, 30 insertions(+), 30 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index f399885cc..c423c49ea 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.02 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -38,7 +38,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.02 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -49,7 +49,7 @@ jobs: if: github.ref_type == 'branch' needs: [python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.02 with: arch: "amd64" branch: ${{ inputs.branch }} @@ -62,7 +62,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.02 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -70,7 +70,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.02 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -80,7 +80,7 @@ jobs: wheel-publish-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.02 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 72813cdec..3b6655c78 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -21,41 +21,41 @@ jobs: - wheel-build-pylibwholegraph - wheel-test-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.02 checks: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.02 with: enable_check_generated_files: false conda-cpp-build: needs: checks secrets: 
inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-build.yaml@branch-24.02 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.02 with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.02 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.02 with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/custom-job.yaml@branch-24.02 with: build_type: pull-request arch: "amd64" @@ -64,14 +64,14 @@ jobs: wheel-build-pylibwholegraph: needs: checks secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.02 with: build_type: pull-request script: ci/build_wheel.sh wheel-test-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.02 with: build_type: pull-request script: ci/test_wheel.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 13f23bada..52319b3fd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.02 with: build_type: nightly branch: ${{ inputs.branch }} @@ -24,7 +24,7 @@ jobs: sha: ${{ inputs.sha }} conda-pytorch-tests: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/conda-python-tests.yaml@branch-24.02 with: build_type: nightly branch: ${{ inputs.branch }} @@ -32,7 +32,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-23.12 + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.02 with: build_type: nightly branch: ${{ inputs.branch }} diff --git a/VERSION b/VERSION index a193fff41..3c6c5e2b7 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -23.12.00 +24.02.00 diff --git a/ci/build_docs.sh b/ci/build_docs.sh index 2978dd383..4f3af9c18 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -22,7 +22,7 @@ rapids-print-env rapids-logger "Downloading artifacts from previous jobs" CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) -export RAPIDS_VERSION_NUMBER="23.12" +export RAPIDS_VERSION_NUMBER="24.02" export RAPIDS_DOCS_DIR="$(mktemp -d)" rapids-mamba-retry install \ diff --git 
a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 8d6b0b1bb..85e53203d 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -23,8 +23,8 @@ dependencies: - graphviz - ipykernel - ipython -- libraft-headers==23.12.* -- librmm==23.12.* +- libraft-headers==24.2.* +- librmm==24.2.* - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/conda/environments/all_cuda-120_arch-x86_64.yaml b/conda/environments/all_cuda-120_arch-x86_64.yaml index 09f7fe7fa..f36f5e891 100644 --- a/conda/environments/all_cuda-120_arch-x86_64.yaml +++ b/conda/environments/all_cuda-120_arch-x86_64.yaml @@ -25,8 +25,8 @@ dependencies: - graphviz - ipykernel - ipython -- libraft-headers==23.12.* -- librmm==23.12.* +- libraft-headers==24.2.* +- librmm==24.2.* - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 431459876..cd919aa71 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -14,7 +14,7 @@ # limitations under the License. #============================================================================= -set(RAPIDS_VERSION "23.12") +set(RAPIDS_VERSION "24.02") set(WHOLEGRAPH_VERSION "${RAPIDS_VERSION}.00") cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) diff --git a/cpp/Doxyfile b/cpp/Doxyfile index 841151312..357e85685 100644 --- a/cpp/Doxyfile +++ b/cpp/Doxyfile @@ -38,7 +38,7 @@ PROJECT_NAME = "WholeGraph C API" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 23.12 +PROJECT_NUMBER = 24.02 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/dependencies.yaml b/dependencies.yaml index 3490469ef..000bc2b3c 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -164,8 +164,8 @@ dependencies: common: - output_types: [conda, requirements] packages: - - libraft-headers==23.12.* - - librmm==23.12.* + - libraft-headers==24.2.* + - librmm==24.2.* test_cpp: common: - output_types: [conda, requirements] diff --git a/fetch_rapids.cmake b/fetch_rapids.cmake index 2c1dd855c..1f099e8f8 100644 --- a/fetch_rapids.cmake +++ b/fetch_rapids.cmake @@ -12,7 +12,7 @@ # the License. # ============================================================================= if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/CUGRAPH_RAPIDS.cmake) - file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-23.12/RAPIDS.cmake + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-24.02/RAPIDS.cmake ${CMAKE_CURRENT_BINARY_DIR}/CUGRAPH_RAPIDS.cmake ) endif() diff --git a/python/pylibwholegraph/CMakeLists.txt b/python/pylibwholegraph/CMakeLists.txt index 5c01e0956..a8cd320e0 100644 --- a/python/pylibwholegraph/CMakeLists.txt +++ b/python/pylibwholegraph/CMakeLists.txt @@ -16,7 +16,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -set(RAPIDS_VERSION "23.12") +set(RAPIDS_VERSION "24.02") set(WHOLEGRAPH_VERSION "${RAPIDS_VERSION}.00") include(FetchContent) From 9853c624dc78dfdd081df20fb57221b7329da5f5 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Mon, 27 Nov 2023 11:49:56 -0800 Subject: [PATCH 02/16] Fix whitespace. 
--- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 77e8f8059..9d116b4dd 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -57,7 +57,7 @@ endfunction() # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${WHOLEGRAPH_MIN_VERSION_raft} FORK rapidsai - PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft} + PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft} # When PINNED_TAG above doesn't match wholegraph, # force local raft clone in build directory From c43f6d173cdc15131176ed35c7fd93fbc5e14c06 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Mon, 18 Dec 2023 12:45:53 -0600 Subject: [PATCH 03/16] Align versions for cudnn, clang-tools, cython, and doxygen with the rest of RAPIDS. (#112) This PR aligns versions for wholegraph dependencies with versions used by other RAPIDS packages. This is needed for devcontainers, to make the unified RAPIDS conda environment solvable. See https://github.com/rapidsai/devcontainers/pull/191. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/wholegraph/pull/112 --- conda/environments/all_cuda-118_arch-x86_64.yaml | 10 +++++----- conda/environments/all_cuda-120_arch-x86_64.yaml | 10 +++++----- dependencies.yaml | 14 +++++++------- python/pylibwholegraph/pyproject.toml | 2 +- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 292c809d4..c7410ce71 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -9,15 +9,15 @@ channels: dependencies: - breathe - c-compiler -- clang-tools=16.0.0 -- clangxx=16.0.0 +- clang-tools==16.0.6 +- clangxx==16.0.6 - cmake>=3.26.4 - cuda-nvtx=11.8 - cudatoolkit=11.8 -- cudnn=8.4 +- cudnn=8.8 - cxx-compiler -- cython -- doxygen=1.8.20 +- cython>=3.0.0 +- doxygen==1.9.1 - gcc_linux-64=11.* - gitpython - graphviz diff --git a/conda/environments/all_cuda-120_arch-x86_64.yaml b/conda/environments/all_cuda-120_arch-x86_64.yaml index b436d5641..568c312b4 100644 --- a/conda/environments/all_cuda-120_arch-x86_64.yaml +++ b/conda/environments/all_cuda-120_arch-x86_64.yaml @@ -9,17 +9,17 @@ channels: dependencies: - breathe - c-compiler -- clang-tools=16.0.0 -- clangxx=16.0.0 +- clang-tools==16.0.6 +- clangxx==16.0.6 - cmake>=3.26.4 - cuda-cudart-dev - cuda-nvcc - cuda-nvtx - cuda-version=12.0 -- cudnn=8.4 +- cudnn=8.8 - cxx-compiler -- cython -- doxygen=1.8.20 +- cython>=3.0.0 +- doxygen==1.9.1 - gcc_linux-64=11.* - gitpython - graphviz diff --git a/dependencies.yaml b/dependencies.yaml index 71c0dc93d..2a52b450f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -68,10 +68,10 @@ dependencies: packages: - c-compiler - cmake>=3.26.4 - - cudnn=8.4 + - cudnn=8.8 - cxx-compiler - - cython - - doxygen=1.8.20 + - cython>=3.0.0 + - &doxygen doxygen==1.9.1 - libraft-headers==24.2.* - librmm==24.2.* - nanobind>=0.2.0 @@ -252,7 +252,7 @@ dependencies: - output_types: [conda, requirements] packages: - breathe - - doxygen=1.8.20 + - *doxygen - graphviz - ipython - ipykernel @@ -274,15 +274,15 @@ dependencies: common: - output_types: [conda, requirements] packages: - - clangxx=16.0.0 - - clang-tools=16.0.0 + - clangxx==16.0.6 + - 
clang-tools==16.0.6 - gitpython python_build_wheel: common: - output_types: [pyproject] packages: - cmake>=3.26.4 - - cython>=0.29,<0.30 + - cython>=3.0.0 - ninja - setuptools - scikit-build>=0.13.1 diff --git a/python/pylibwholegraph/pyproject.toml b/python/pylibwholegraph/pyproject.toml index ccb14b831..910d4ccea 100644 --- a/python/pylibwholegraph/pyproject.toml +++ b/python/pylibwholegraph/pyproject.toml @@ -15,7 +15,7 @@ [build-system] requires = [ "cmake>=3.26.4", - "cython>=0.29,<0.30", + "cython>=3.0.0", "ninja", "scikit-build>=0.13.1", "setuptools", From ef5b3eef3f89dc2c9536fe9abce2d0c93cc9a364 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Tue, 19 Dec 2023 09:28:02 -0800 Subject: [PATCH 04/16] Don't overwrite wholegraph_ROOT if provided (#114) This change allows standard CMake specification of the C++ package directory (via `-Dwholegraph_ROOT`) to also work during the Python build. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Bradley Dice (https://github.com/bdice) - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/wholegraph/pull/114 --- python/pylibwholegraph/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pylibwholegraph/CMakeLists.txt b/python/pylibwholegraph/CMakeLists.txt index a8cd320e0..758fe3713 100644 --- a/python/pylibwholegraph/CMakeLists.txt +++ b/python/pylibwholegraph/CMakeLists.txt @@ -113,7 +113,9 @@ include(../../cpp/cmake/thirdparty/get_raft.cmake) #include(${CMAKE_CURRENT_LIST_DIR}/../cmake/thirdparty/nanobind.cmake) # use _ROOT here to take precedence over any other package -set(wholegraph_ROOT "$ENV{LIBWHOLEGRAPH_DIR}") +if (DEFINED ENV{LIBWHOLEGRAPH_DIR}) + set(wholegraph_ROOT "$ENV{LIBWHOLEGRAPH_DIR}") +endif() find_package(wholegraph "${RAPIDS_VERSION}.0" EXACT) message("WholeGraph") if (WHOLEGRAPH_FOUND) From aaceac529dc3a87286fd1cb2a435f42d3404d735 Mon Sep 17 00:00:00 2001 From: dongxuy04 <78518666+dongxuy04@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:16:01 +0800 Subject: [PATCH 05/16] added Direct IO support for WholeMemory loading (#113) Add Direct IO support option for WholeMemory loading from disk. Using Direct IO may be faster on some high performance file systems. Authors: - https://github.com/dongxuy04 - Brad Rees (https://github.com/BradReesWork) Approvers: - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/113 --- cpp/src/wholememory/file_io.cpp | 469 +++++++++++++++++++++++++------- 1 file changed, 372 insertions(+), 97 deletions(-) diff --git a/cpp/src/wholememory/file_io.cpp b/cpp/src/wholememory/file_io.cpp index 3274811e1..0540a3f5d 100644 --- a/cpp/src/wholememory/file_io.cpp +++ b/cpp/src/wholememory/file_io.cpp @@ -15,14 +15,17 @@ */ #include "file_io.h" +#include #include #include +#include #include #include #include "communicator.hpp" #include "error.hpp" +#include "integer_utils.hpp" #include "logger.hpp" namespace wholememory { @@ -38,6 +41,15 @@ static size_t StatFileSize(const char* filename) return filesize; } +static size_t StatFileBlockSize(const char* filename) +{ + auto blocksize = static_cast(-1); + struct stat statbuf {}; + if (stat(filename, &statbuf) < 0) { return blocksize; } + blocksize = statbuf.st_blksize; + return blocksize; +} + static size_t get_handle_partial_size(size_t handle_size, size_t memory_offset, size_t memory_entry_stride, @@ -62,6 +74,317 @@ static size_t get_handle_partial_size(size_t handle_size, return partial_size; } +/*! 
+ * Read from file list to local memory of WholeMemory. File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. + */ +static void read_file_list_to_local_memory(char* local_ptr, + size_t local_size, + size_t local_offset, + size_t entry_size, + size_t memory_entry_stride, + size_t memory_offset, + int file_count, + const char** file_names, + const std::vector& file_sizes, + size_t suggested_buffer_size, + int wm_rank) +{ + size_t buffer_size; + size_t buffer_entry_count = 1; + if (suggested_buffer_size < entry_size) { + buffer_size = entry_size; + } else { + buffer_entry_count = suggested_buffer_size / entry_size; + buffer_size = buffer_entry_count * entry_size; + } + std::vector file_read_buffer(buffer_size); + + size_t local_entry_memory_start_index = local_offset / memory_entry_stride; + size_t local_entry_file_start_index = + local_entry_memory_start_index - memory_offset / memory_entry_stride; + size_t local_entry_count = local_size / memory_entry_stride; + char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; + if (wm_rank == 0) { + local_entry_count -= memory_offset / memory_entry_stride; + local_write_ptr += (memory_offset / memory_entry_stride) * memory_entry_stride; + } + size_t local_entry_idx = 0; + + size_t file_entry_offset = 0; + size_t total_read_bytes = 0; + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + // already outside reading window + if (file_entry_offset >= local_entry_file_start_index + local_entry_count) break; + // in reading window + if (file_entry_offset + file_entry_count > local_entry_file_start_index) { + size_t file_read_start_offset = 0; + FILE* fp = fopen(file_names[i], "rb"); + if (fp == nullptr) { WHOLEMEMORY_ERROR("Open file %s for read failed.", file_names[i]); } + // maybe in window end, remove possible tailing data that don't belong to current rank. + size_t to_read_file_entry_count = std::min( + file_entry_count, local_entry_file_start_index + local_entry_count - file_entry_offset); + // if in window begin, remove possible data that belongs to previous rank and skip disk + // data. + if (file_entry_offset < local_entry_file_start_index) { + size_t skip_entry_count = local_entry_file_start_index - file_entry_offset; + + file_read_start_offset = skip_entry_count * entry_size; + + if (fseeko(fp, file_read_start_offset, SEEK_SET) != 0) { + WHOLEMEMORY_ERROR( + "File %s seek to %ld failed.", file_names[i], skip_entry_count * entry_size); + } + to_read_file_entry_count -= skip_entry_count; + } + // now all data in file_entry_count need to be read. 
+ size_t bytes_to_read = to_read_file_entry_count * entry_size; + size_t left_entry_count = to_read_file_entry_count; + while (left_entry_count > 0) { + size_t read_entry_count = std::min(left_entry_count, buffer_entry_count); + + int ret = fread(file_read_buffer.data(), entry_size, read_entry_count, fp); + if (ret != read_entry_count) { + WHOLEMEMORY_ERROR( + "File %s line %d: reading from file %s, read_entry_count=%ld, entry_size=%ld, " + "returned %d, error=%s\n", + __FILE__, + __LINE__, + file_names[i], + read_entry_count, + entry_size, + ret, + strerror(errno)); + } + + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(local_write_ptr, + memory_entry_stride, + file_read_buffer.data(), + entry_size, + entry_size, + read_entry_count, + cudaMemcpyDefault)); + } else { + WM_CUDA_CHECK(cudaMemcpy(local_write_ptr, + file_read_buffer.data(), + read_entry_count * entry_size, + cudaMemcpyDefault)); + } + local_write_ptr += read_entry_count * memory_entry_stride; + + left_entry_count -= read_entry_count; + } + fclose(fp); + WHOLEMEMORY_INFO( + "Rank=%d done Reading %ld bytes from file %s size=%ld, starting from offset=%ld.", + wm_rank, + bytes_to_read, + file_names[i], + file_sizes[i], + file_read_start_offset); + total_read_bytes += bytes_to_read; + } + file_entry_offset += file_entry_count; + } + WHOLEMEMORY_INFO( + "Rank=%d done reading total %ld bytes from needed files.", wm_rank, total_read_bytes); +} + +/*! + * Read from file list to local memory of WholeMemory using DirectIO. Using DirectIO may have better + * performance by bypassing system cache if it is bottleneck. File list are binary files, which are + * considered to be concatenated together. All ranks in WholeMemory will read the files in parallel + * and load each part into local memory of each rank. + * @param local_ptr : Pointer to local memory of WholeMemory + * @param local_size : Local memory size + * @param local_offset : The offset of local memory in WholeMemory. + * @param entry_size : The entry size of each data entry. + * @param memory_entry_stride : The stride of each entry in WholeMemory + * @param memory_offset : The start offset to place the read data. Should be in range [0, + * memory_entry_stride) + * @param file_count : Total file count of the file list + * @param file_names : File names of the file list. + * @param file_sizes : Sizes of each file. + * @param suggested_buffer_size : Suggested buffer size to read. + * @param wm_rank : WholeMemory rank. 
+ */ +static void read_file_list_to_local_memory_directio(char* local_ptr, + size_t local_size, + size_t local_offset, + size_t entry_size, + size_t memory_entry_stride, + size_t memory_offset, + int file_count, + const char** file_names, + const std::vector& file_sizes, + size_t suggested_buffer_size, + int wm_rank) +{ + if (memory_offset + entry_size > memory_entry_stride) { + WHOLEMEMORY_FAIL_NOTHROW("Direct io mode only support reading all entries."); + } + size_t local_entry_start_index = local_offset / memory_entry_stride; + size_t local_entry_count = local_size / memory_entry_stride; + char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; + + static size_t kAlignSize = 16 * 1024 * 1024; + suggested_buffer_size = round_up_unsafe(suggested_buffer_size, kAlignSize); + + char* block_buffer; + WHOLEMEMORY_CHECK_NOTHROW(posix_memalign(reinterpret_cast(&block_buffer), + kAlignSize, + suggested_buffer_size) == 0); + + size_t file_entry_offset = 0; + size_t read_entry_count = 0; + for (int i = 0; i < file_count; i++) { + size_t file_entry_count = file_sizes[i] / entry_size; + // already outside reading window + if (file_entry_offset >= local_entry_start_index + local_entry_count) break; + // reading window not reached + if (file_entry_offset + file_entry_count <= local_entry_start_index) { + file_entry_offset += file_entry_count; + continue; + } + // in reading window + auto block_size = StatFileBlockSize(file_names[i]); + if (block_size == 0 || block_size == (size_t)-1 || kAlignSize % block_size != 0) { + WHOLEMEMORY_FAIL_NOTHROW( + "block_size=%ld for file %s, but alignment is %ld", block_size, file_names[i], kAlignSize); + } + size_t buffer_block_count = suggested_buffer_size / block_size; + int fd = open(file_names[i], O_DIRECT | O_RDONLY); + if (fd < 0) { WHOLEMEMORY_FAIL_NOTHROW("Open file %s with direct io failed.", file_names[i]); } + + // maybe in window end, remove possible tailing data that don't belong to current rank. + size_t to_read_file_entry_count = + std::min(file_entry_count, local_entry_start_index + local_entry_count - file_entry_offset); + + size_t file_read_end = to_read_file_entry_count * entry_size; + // if in window begin, remove possible data that belongs to previous rank and skip disk + // data. 
+ size_t file_read_start = 0; + if (file_entry_offset < local_entry_start_index) { + size_t skip_entry_count = local_entry_start_index - file_entry_offset; + to_read_file_entry_count -= skip_entry_count; + file_read_start = skip_entry_count * entry_size; + } + + size_t file_block_read_offset = file_read_start / block_size * block_size; + size_t skip_head_size = file_read_start - file_block_read_offset; + + char* local_mem_write_entry_for_file = local_write_ptr + read_entry_count * memory_entry_stride; + size_t first_mem_entry_offset = 0; + size_t useful_data_bytes_read = 0; + size_t physical_data_bytes_read = 0; + while (file_block_read_offset < file_read_end) { + size_t left_size = file_read_end - file_block_read_offset; + size_t left_block_count = div_rounding_up_unsafe(left_size, block_size); + size_t read_block_count = std::min(left_block_count, buffer_block_count); + size_t physical_read_size = read_block_count * block_size; + physical_data_bytes_read += physical_read_size; + + ssize_t pread_size = pread64(fd, block_buffer, physical_read_size, file_block_read_offset); + if (pread_size != physical_read_size && + file_block_read_offset + pread_size != file_sizes[i]) { + WHOLEMEMORY_FAIL_NOTHROW( + "rank=%d, pread_size=%ld, physical_read_size=%ld, file_block_read_offset=%ld, " + "file_sizes[i]=%ld, file=%s", + wm_rank, + pread_size, + physical_read_size, + file_block_read_offset, + file_sizes[i], + file_names[i]); + } + + size_t drop_tail_size = 0; + if (file_block_read_offset + physical_read_size > file_read_end) { + drop_tail_size = file_block_read_offset + physical_read_size - file_read_end; + } + + char* useful_data_ptr = block_buffer + skip_head_size; + size_t useful_data_size = physical_read_size - skip_head_size - drop_tail_size; + + useful_data_bytes_read += useful_data_size; + + if (first_mem_entry_offset != 0) { + // process head + size_t entry_left_size = entry_size - first_mem_entry_offset; + WM_CUDA_CHECK_NO_THROW(cudaMemcpy(local_mem_write_entry_for_file + first_mem_entry_offset, + useful_data_ptr, + entry_left_size, + cudaMemcpyDefault)); + local_mem_write_entry_for_file += memory_entry_stride; + useful_data_ptr += entry_left_size; + useful_data_size -= entry_left_size; + entry_left_size = 0; + } + + size_t full_entry_count = useful_data_size / entry_size; + size_t full_entry_size = full_entry_count * entry_size; + + if (full_entry_size > 0) { + if (entry_size != memory_entry_stride) { + WM_CUDA_CHECK(cudaMemcpy2D(local_mem_write_entry_for_file, + memory_entry_stride, + useful_data_ptr, + entry_size, + entry_size, + full_entry_count, + cudaMemcpyDefault)); + } else { + WM_CUDA_CHECK(cudaMemcpy( + local_mem_write_entry_for_file, useful_data_ptr, full_entry_size, cudaMemcpyDefault)); + } + local_mem_write_entry_for_file += memory_entry_stride * full_entry_count; + useful_data_ptr += full_entry_size; + useful_data_size -= full_entry_size; + } + + size_t tail_entry_size = useful_data_size % entry_size; + if (tail_entry_size != 0) { + // process tail + WM_CUDA_CHECK_NO_THROW(cudaMemcpy( + local_mem_write_entry_for_file, useful_data_ptr, tail_entry_size, cudaMemcpyDefault)); + first_mem_entry_offset = tail_entry_size; + } + + file_block_read_offset += physical_read_size; + skip_head_size = 0; + } + + WHOLEMEMORY_INFO( + "Rank=%d done Reading %ld useful bytes by reading %ld block bytes using DirectIO from file " + "%s size=%ld.", + wm_rank, + useful_data_bytes_read, + physical_data_bytes_read, + file_names[i], + file_sizes[i]); + + close(fd); + file_entry_offset += 
file_entry_count; + read_entry_count += to_read_file_entry_count; + } + free(block_buffer); +} + wholememory_error_code_t load_file_to_handle(wholememory_handle_t wholememory_handle, size_t memory_offset, size_t memory_entry_stride, @@ -153,107 +476,59 @@ wholememory_error_code_t load_file_to_handle(wholememory_handle_t wholememory_ha (void**)(&local_ptr), &local_size, &local_offset, wholememory_handle) == WHOLEMEMORY_SUCCESS); - constexpr int kSuggestedBufferSize = 16 * 1024 * 1024; - size_t buffer_size; - size_t buffer_entry_count = 1; - if (kSuggestedBufferSize < entry_size) { - buffer_size = entry_size; - } else { - buffer_entry_count = kSuggestedBufferSize / entry_size; - buffer_size = buffer_entry_count * entry_size; + int suggested_buffer_size_mb = 16; + const char* buffer_size_env_var = std::getenv("WG_LOAD_BUFFER_SIZE_MB"); + if (buffer_size_env_var != nullptr) { + try { + suggested_buffer_size_mb = std::stoi(buffer_size_env_var); + } catch (const std::invalid_argument& e) { + suggested_buffer_size_mb = 16; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_BUFFER_SIZE_MB value %s is not valid, using default %d", + buffer_size_env_var, + suggested_buffer_size_mb); + } + if (suggested_buffer_size_mb < 1) { + suggested_buffer_size_mb = 16; + WHOLEMEMORY_WARN( + "Environment variable WG_LOAD_BUFFER_SIZE_MB value %s is not valid, using default %d", + buffer_size_env_var, + suggested_buffer_size_mb); + } } - std::vector file_read_buffer(buffer_size); + size_t suggested_buffer_size = static_cast(suggested_buffer_size_mb) * 1024 * 1024; - size_t local_entry_memory_start_index = local_offset / memory_entry_stride; - size_t local_entry_file_start_index = - local_entry_memory_start_index - memory_offset / memory_entry_stride; - size_t local_entry_count = local_size / memory_entry_stride; - char* local_write_ptr = local_ptr + memory_offset % memory_entry_stride; - if (wm_rank == 0) { - local_entry_count -= memory_offset / memory_entry_stride; - local_write_ptr += (memory_offset / memory_entry_stride) * memory_entry_stride; + const char* directio_env_var = std::getenv("WG_LOAD_USE_DIRECTIO"); + bool use_direct_io = false; + if (directio_env_var != nullptr && directio_env_var[0] == '1' && directio_env_var[1] == '\0') { + use_direct_io = true; } - size_t local_entry_idx = 0; - - size_t file_entry_offset = 0; - size_t total_read_bytes = 0; - for (int i = 0; i < file_count; i++) { - size_t file_entry_count = file_sizes[i] / entry_size; - // already outside reading window - if (file_entry_offset >= local_entry_file_start_index + local_entry_count) break; - // in reading window - if (file_entry_offset + file_entry_count > local_entry_file_start_index) { - size_t file_read_start_offset = 0; - FILE* fp = fopen(file_names[i], "rb"); - if (fp == nullptr) { WHOLEMEMORY_ERROR("Open file %s for read failed.", file_names[i]); } - // maybe in window end, remove possible tailing data that don't belong to current rank. - size_t to_read_file_entry_count = std::min( - file_entry_count, local_entry_file_start_index + local_entry_count - file_entry_offset); - // if in window begin, remove possible data that belongs to previous rank and skip disk - // data. 
- if (file_entry_offset < local_entry_file_start_index) { - size_t skip_entry_count = local_entry_file_start_index - file_entry_offset; - - file_read_start_offset = skip_entry_count * entry_size; - - if (fseeko(fp, file_read_start_offset, SEEK_SET) != 0) { - WHOLEMEMORY_ERROR( - "File %s seek to %ld failed.", file_names[i], skip_entry_count * entry_size); - } - to_read_file_entry_count -= skip_entry_count; - } - // now all data in file_entry_count need to be read. - size_t bytes_to_read = to_read_file_entry_count * entry_size; - size_t left_entry_count = to_read_file_entry_count; - while (left_entry_count > 0) { - size_t read_entry_count = std::min(left_entry_count, buffer_entry_count); - - int ret = fread(file_read_buffer.data(), entry_size, read_entry_count, fp); - if (ret != read_entry_count) { - WHOLEMEMORY_ERROR( - "File %s line %d: reading from file %s, read_entry_count=%ld, entry_size=%ld, " - "returned %d, error=%s\n", - __FILE__, - __LINE__, - file_names[i], - read_entry_count, - entry_size, - ret, - strerror(errno)); - } - - if (entry_size != memory_entry_stride) { - WM_CUDA_CHECK(cudaMemcpy2D(local_write_ptr, - memory_entry_stride, - file_read_buffer.data(), - entry_size, - entry_size, - read_entry_count, - cudaMemcpyDefault)); - } else { - WM_CUDA_CHECK(cudaMemcpy(local_write_ptr, - file_read_buffer.data(), - read_entry_count * entry_size, - cudaMemcpyDefault)); - } - local_write_ptr += read_entry_count * memory_entry_stride; - - left_entry_count -= read_entry_count; - } - fclose(fp); - WHOLEMEMORY_INFO( - "Rank=%d done Reading %ld bytes from file %s size=%ld, starting from offset=%ld.", - wm_rank, - bytes_to_read, - file_names[i], - file_sizes[i], - file_read_start_offset); - total_read_bytes += bytes_to_read; - } - file_entry_offset += file_entry_count; + if (!use_direct_io) { + read_file_list_to_local_memory(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank); + } else { + read_file_list_to_local_memory_directio(local_ptr, + local_size, + local_offset, + entry_size, + memory_entry_stride, + memory_offset, + file_count, + file_names, + file_sizes, + suggested_buffer_size, + wm_rank); } - WHOLEMEMORY_INFO( - "Rank=%d done reading total %ld bytes from needed files.", wm_rank, total_read_bytes); + wm_comm->barrier(); } catch (wholememory::logic_error& wle) { WHOLEMEMORY_ERROR("Logic error: %s", wle.what()); From 7025eafa75567f7353ae97b5f9077ef9f9879649 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 11 Jan 2024 13:54:01 -0600 Subject: [PATCH 06/16] refactor CUDA versions in dependencies.yaml (#115) Contributes to https://github.com/rapidsai/build-planning/issues/7. Proposes splitting the `cuda-version` dependency in `dependencies.yaml` out to its own thing, separate from the bits of the CUDA Toolkit this project needs. ### Benefits of this change * prevents accidental inclusion of multiple `cuda-version` version in environments * reduces update effort (via enabling more use of globs like `"12.*"`) * improves the chance that errors like "`conda` recipe is missing a dependency" are caught in CI ### Notes for Reviewers This change was intended to just re-organize `dependencies.yaml`, but I do think the one additional change it introduces to `all_cuda-118_arch-x86_64.yaml` is a good one. I *think* requiring the `cuda-version` metapackage in all environments is useful to prevent against environment solves that result in runtime issues. 
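
For reference, a condensed sketch of how the split looks in `dependencies.yaml` after this change (abbreviated from the diff below; only the 11.8 and 12.x matrices are shown):

```yaml
dependencies:
  cuda_version:
    specific:
      - output_types: conda
        matrices:
          - matrix:
              cuda: "11.8"
            packages:
              - cuda-version=11.8
          - matrix:
              cuda: "12.0"
            packages:
              - cuda-version=12.0
  cuda:
    specific:
      - output_types: conda
        matrices:
          - matrix:
              cuda: "11.8"
            packages:
              - cudatoolkit
              - cuda-nvtx=11.8
          - matrix:
              cuda: "12.*"
            packages:
              - cuda-cudart-dev
              - cuda-nvtx
```
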
References: * https://github.com/conda-forge/cuda-version-feedstock/blob/902045016fbc6e4dd1350a7390b0411f376d1a19/recipe/meta.yaml#L13-L18 * https://docs.conda.io/projects/conda-build/en/stable/resources/define-metadata.html#run-constrained Authors: - James Lamb (https://github.com/jameslamb) Approvers: - Ray Douglass (https://github.com/raydouglass) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/wholegraph/pull/115 --- .pre-commit-config.yaml | 2 +- .../all_cuda-118_arch-x86_64.yaml | 3 +- dependencies.yaml | 50 ++++++++++++++----- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6943ae3b0..eef7a0285 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: pass_filenames: false additional_dependencies: [gitpython] - repo: https://github.com/rapidsai/dependency-file-generator - rev: v1.5.1 + rev: v1.8.0 hooks: - id: rapids-dependency-file-generator args: ["--clean"] diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index c7410ce71..825ec1f7d 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -13,7 +13,8 @@ dependencies: - clangxx==16.0.6 - cmake>=3.26.4 - cuda-nvtx=11.8 -- cudatoolkit=11.8 +- cuda-version=11.8 +- cudatoolkit - cudnn=8.8 - cxx-compiler - cython>=3.0.0 diff --git a/dependencies.yaml b/dependencies.yaml index 2a52b450f..17ed61598 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -8,7 +8,8 @@ files: includes: - checks - build - - cudatoolkit + - cuda + - cuda_version - py_version - run - test_python @@ -17,11 +18,11 @@ files: test_cpp: output: none includes: - - cudatoolkit + - cuda_version test_python: output: none includes: - - cudatoolkit + - cuda_version - py_version - test_python checks: @@ -32,7 +33,7 @@ files: docs: output: none includes: - - cudatoolkit + - cuda_version - docs - py_version - pytorch_cpu @@ -40,7 +41,8 @@ files: output: none includes: - build - - cudatoolkit + - cuda + - cuda_version - py_version - run - pytorch_cpu @@ -107,39 +109,61 @@ dependencies: cuda: "11.8" packages: - nvcc_linux-aarch64=11.8 + - matrix: + cuda: "12.*" + packages: + - cuda-nvcc + cuda_version: + specific: + - output_types: conda + matrices: + - matrix: + cuda: "11.2" + packages: + - cuda-version=11.2 + - matrix: + cuda: "11.4" + packages: + - cuda-version=11.4 + - matrix: + cuda: "11.5" + packages: + - cuda-version=11.5 + - matrix: + cuda: "11.8" + packages: + - cuda-version=11.8 - matrix: cuda: "12.0" packages: - cuda-version=12.0 - - cuda-nvcc - cudatoolkit: + cuda: specific: - output_types: conda matrices: - matrix: cuda: "11.2" packages: - - cudatoolkit=11.2 + - cudatoolkit - cuda-nvtx=11.4 # oldest available - matrix: cuda: "11.4" packages: - - cudatoolkit=11.4 + - cudatoolkit - cuda-nvtx=11.4 # oldest available - matrix: cuda: "11.5" packages: - - cudatoolkit=11.5 + - cudatoolkit - cuda-nvtx=11.5 - matrix: cuda: "11.8" packages: - - cudatoolkit=11.8 + - cudatoolkit - cuda-nvtx=11.8 - matrix: - cuda: "12.0" + cuda: "12.*" packages: - - cuda-version=12.0 - cuda-cudart-dev - cuda-nvtx checks: From 0ddab62f9e3c44a51cc9755ce3a7b10caef6fb40 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 12 Jan 2024 11:55:40 -0500 Subject: [PATCH 07/16] Remove usages of rapids-env-update (#117) Reference: https://github.com/rapidsai/ops/issues/2766 Replace rapids-env-update with rapids-configure-conda-channels, 
rapids-configure-sccache, and rapids-date-string. Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) URL: https://github.com/rapidsai/wholegraph/pull/117 --- ci/build_cpp.sh | 6 +++++- ci/build_python.sh | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ci/build_cpp.sh b/ci/build_cpp.sh index 4e1b7bd2a..bd45a3f57 100755 --- a/ci/build_cpp.sh +++ b/ci/build_cpp.sh @@ -3,7 +3,11 @@ set -euo pipefail -source rapids-env-update +rapids-configure-conda-channels + +source rapids-configure-sccache + +source rapids-date-string export CMAKE_GENERATOR=Ninja diff --git a/ci/build_python.sh b/ci/build_python.sh index b79ba92b1..efb7bfe4a 100755 --- a/ci/build_python.sh +++ b/ci/build_python.sh @@ -3,7 +3,11 @@ set -euo pipefail -source rapids-env-update +rapids-configure-conda-channels + +source rapids-configure-sccache + +source rapids-date-string export CMAKE_GENERATOR=Ninja From 4a92d47c9580ede6f475f8ee382342013d08f092 Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Fri, 19 Jan 2024 23:37:03 +0800 Subject: [PATCH 08/16] fix inferencesample option (#107) fix inferencesample option Authors: - Chuang Zhu (https://github.com/chuangz0) Approvers: - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/107 --- cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh | 2 +- cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu | 1 + python/pylibwholegraph/pylibwholegraph/torch/common_options.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh index d2d040a0e..5fa93ee12 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh +++ b/cpp/src/wholememory_ops/functions/nvshmem_device_reference.cuh @@ -29,7 +29,7 @@ class nvshmem_device_reference { : pointer_(static_cast(nvshmem_ref.pointer)), typed_stride_(nvshmem_ref.stride / sizeof(DataTypeT)) { - assert(gref.stride % sizeof(DataTypeT) == 0); + assert(nvshmem_ref.stride % sizeof(DataTypeT) == 0); } __device__ nvshmem_device_reference() = delete; diff --git a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu index a860cbc6c..4051f12bd 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu @@ -185,6 +185,7 @@ wholememory_error_code_t wholememory_gather_nvshmem( p_env_fns, stream); // ungistre + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); if (nvshmemx_buffer_unregister(temp_output_ptr) != 0) { WHOLEMEMORY_ERROR("nvshmemx_buffer_unregister error in wholememory_gather_nvshmem"); } diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py index 0999fdfe5..42746add8 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py @@ -132,7 +132,7 @@ def add_common_sampler_options(argparser: ArgumentParser): argparser.add_argument( "-s", "--inferencesample", - type=int, + type=str, dest="inferencesample", default="30", help="inference sample count, -1 is all", From ec609abe85760bc523c42e0e4d64ac4616181363 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Fri, 19 Jan 2024 23:42:29 +0800 Subject: [PATCH 09/16] fix a 
bug for embedding optimizer, which leads to undefined behavior (#108) Fix a bug for embedding optimizer, it leads to undefined behavior when embedding_dim is not multiple of 32. Authors: - https://github.com/linhu-nv Approvers: - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/108 --- .../functions/embedding_optimizer_func.cu | 12 ++++++++---- .../wholememory_embedding_gradient_apply_tests.cu | 8 +++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu b/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu index e6d751280..0249ba1ac 100644 --- a/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu +++ b/cpp/src/wholememory_ops/functions/embedding_optimizer_func.cu @@ -214,7 +214,8 @@ __global__ void sgd_optimizer_step_kernel(const IndiceT* indices_ptr, int local_dim_idx = threadIdx.x; float grad_value = 0.0f; int embedding_idx = local_dim_idx + loop_start_idx; - if (embedding_idx < embedding_dim) { grad_value = grads_ptr[embedding_idx]; } + if (embedding_idx >= embedding_dim) { break; } + grad_value = grads_ptr[embedding_idx]; float embedding_value = embedding_ptr[embedding_idx]; grad_value += weight_decay * embedding_value; embedding_value -= lr * grad_value; @@ -392,7 +393,8 @@ __global__ void lazy_adam_optimizer_step_kernel(const IndiceT* indices_ptr, int local_dim_idx = threadIdx.x; float grad_value = 0.0f; int embedding_idx = local_dim_idx + loop_start_idx; - if (embedding_idx < embedding_dim) { grad_value = grads_ptr[local_dim_idx + loop_start_idx]; } + if (embedding_idx >= embedding_dim) { break; } + grad_value = grads_ptr[local_dim_idx + loop_start_idx]; float embedding_value = embedding_ptr[embedding_idx]; if (AdamW) { embedding_value -= lr * weight_decay * embedding_value; @@ -644,7 +646,8 @@ __global__ void ada_grad_optimizer_step_kernel(const IndiceT* indices_ptr, int local_dim_idx = threadIdx.x; float grad_value = 0.0f; int embedding_idx = local_dim_idx + loop_start_idx; - if (embedding_idx < embedding_dim) { grad_value = grads_ptr[embedding_idx]; } + if (embedding_idx >= embedding_dim) { break; } + grad_value = grads_ptr[embedding_idx]; float embedding_value = embedding_ptr[embedding_idx]; grad_value = grad_value + weight_decay * embedding_value; float state_sum = state_sum_ptr[embedding_idx]; @@ -841,7 +844,8 @@ __global__ void rms_prop_optimizer_step_kernel(const IndiceT* indices_ptr, int local_dim_idx = threadIdx.x; float grad_value = 0.0f; int embedding_idx = local_dim_idx + loop_start_idx; - if (embedding_idx < embedding_dim) { grad_value = grads_ptr[local_dim_idx + loop_start_idx]; } + if (embedding_idx >= embedding_dim) { break; } + grad_value = grads_ptr[local_dim_idx + loop_start_idx]; float embedding_value = embedding_ptr[embedding_idx]; grad_value = grad_value + weight_decay * embedding_value; float v = v_ptr[embedding_idx]; diff --git a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu index 453b13b41..bb6360fc0 100644 --- a/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_embedding_gradient_apply_tests.cu @@ -149,7 +149,7 @@ struct EmbeddingBackwardTestParams { wholememory_optimizer_type_t optimizer_type = WHOLEMEMORY_OPT_SGD; float cache_ratio = 0.2; bool use_cache = false; - int run_count = 1; + int run_count = 3; float lr_ = 0.1; @@ -428,7 +428,7 @@ void 
prepare_data_and_reference( int64_t end_entry = (thread_rank + 1) * total_entry_count / thread_world_size; CPUOptimizer cpu_optimizer(¶ms, start_entry, end_entry); int embedding_dim = params.grad_description.sizes[1]; - for (int step = 0; step <= params.run_count; step++) { + for (int step = 0; step < params.run_count; step++) { int step_id = std::min(step, params.run_count - 1); std::vector indices; std::vector> grads; @@ -625,7 +625,7 @@ TEST_P(WholeMemoryEmbeddingBackwardParameterTests, EmbeddingGatherGradientApplyT EXPECT_EQ(cudaStreamSynchronize(nullptr), cudaSuccess); EXPECT_EQ(wholememory_communicator_barrier(wm_comm), WHOLEMEMORY_SUCCESS); - for (int run = 0; run <= params.run_count; run++) { + for (int run = 0; run < params.run_count; run++) { int step_id = std::min(run, params.run_count - 1); auto& rank_indices_vec = step_rank_indices[step_id][world_rank]; auto& rank_grads_vec = step_rank_grads[step_id][world_rank]; @@ -737,6 +737,8 @@ INSTANTIATE_TEST_SUITE_P( EmbeddingBackwardTestParams().set_use_cache().set_indice_count(10000127).set_optimizer_type(WHOLEMEMORY_OPT_ADAGRAD), EmbeddingBackwardTestParams().set_use_cache().set_indice_count(10000127).set_optimizer_type(WHOLEMEMORY_OPT_LAZY_ADAM), #endif + EmbeddingBackwardTestParams().set_entry_count(500).set_indice_count(400).set_embedding_dim(4), + EmbeddingBackwardTestParams().set_embedding_dim(3), EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131), EmbeddingBackwardTestParams().set_use_cache().set_grad_stride(131).set_optimizer_type( WHOLEMEMORY_OPT_RMSPROP), From c592185deadecce5109abd346780e2bd6e729c6b Mon Sep 17 00:00:00 2001 From: Chang Liu Date: Mon, 22 Jan 2024 06:12:09 -0800 Subject: [PATCH 10/16] Reset WholeGraph communicators during the finalize call (#111) This PR is to address https://github.com/rapidsai/wholegraph/issues/110. Please feel free to comment and suggest here. Authors: - Chang Liu (https://github.com/chang-l) - Brad Rees (https://github.com/BradReesWork) Approvers: - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/111 --- .../pylibwholegraph/pylibwholegraph/torch/comm.py | 13 +++++++++++++ .../pylibwholegraph/torch/initialize.py | 4 +++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/comm.py b/python/pylibwholegraph/pylibwholegraph/torch/comm.py index c7cca2e7b..aa15d3a0a 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/comm.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/comm.py @@ -32,6 +32,19 @@ all_comm_local_size = 1 +def reset_communicators(): + global all_comm_world_rank, all_comm_world_size, all_comm_local_rank, all_comm_local_size + global global_communicators, local_node_communicator, local_device_communicator + global_communicators = {} + local_node_communicator = None + local_device_communicator = None + + all_comm_world_rank = 0 + all_comm_world_size = 1 + all_comm_local_rank = 0 + all_comm_local_size = 1 + + def set_world_info(world_rank: int, world_size: int, local_rank: int, local_size: int): """ Set the global world's information. 
This is used for create common used communicators, like local node communicator, diff --git a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py index 3e1238c2f..3259a0e82 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py @@ -15,7 +15,7 @@ import torch import torch.utils.dlpack import pylibwholegraph.binding.wholememory_binding as wmb -from .comm import set_world_info, get_global_communicator, get_local_node_communicator +from .comm import set_world_info, get_global_communicator, get_local_node_communicator, reset_communicators def init(world_rank: int, world_size: int, local_rank: int, local_size: int): @@ -73,3 +73,5 @@ def finalize(): :return: None """ wmb.finalize() + reset_communicators() + torch.distributed.destroy_process_group() if torch.distributed.is_initialized() else None From 8cb7c5ced9f87d383d17aa16712e3e5a2631d0be Mon Sep 17 00:00:00 2001 From: Paul Taylor <178183+trxcllnt@users.noreply.github.com> Date: Mon, 22 Jan 2024 07:47:50 -0800 Subject: [PATCH 11/16] Fix pip dependencies (#118) Move conda-only dependencies out of `pyproject` and `requirements` sections in `dependencies.yaml`. Authors: - Paul Taylor (https://github.com/trxcllnt) Approvers: - Bradley Dice (https://github.com/bdice) - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/118 --- .gitignore | 1 + dependencies.yaml | 16 +++++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index e3729dba5..879346c25 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,4 @@ cpp/.idea/ cpp/cmake-build-debug/ pylibwholegraph/.idea/ pylibwholegraph/cmake-build-debug/ +compile_commands.json diff --git a/dependencies.yaml b/dependencies.yaml index 17ed61598..30a9ea8c5 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -192,21 +192,23 @@ dependencies: packages: [] test_cpp: common: - - output_types: [conda, requirements] + - output_types: [conda] packages: - nccl test_python: common: - - output_types: [conda, requirements] + - output_types: [conda] packages: - c-compiler - cxx-compiler + - nccl + - output_types: [conda, requirements] + packages: - ninja - numpy>=1.17 - pytest - pytest-forked - pytest-xdist - - nccl specific: - output_types: [conda, requirements] matrices: @@ -273,10 +275,12 @@ dependencies: packages: docs: common: + - output_types: [conda] + packages: + - *doxygen - output_types: [conda, requirements] packages: - breathe - - *doxygen - graphviz - ipython - ipykernel @@ -297,10 +301,12 @@ dependencies: clang_tools: common: - output_types: [conda, requirements] + packages: + - gitpython + - output_types: conda packages: - clangxx==16.0.6 - clang-tools==16.0.6 - - gitpython python_build_wheel: common: - output_types: [pyproject] From 503cdcd1e6f38000967b60de59bd811c379c523c Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Wed, 7 Feb 2024 05:52:54 -0800 Subject: [PATCH 12/16] Exclude tests from builds (#127) --- python/pylibwholegraph/setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pylibwholegraph/setup.py b/python/pylibwholegraph/setup.py index 0a1f3fa42..628583288 100644 --- a/python/pylibwholegraph/setup.py +++ b/python/pylibwholegraph/setup.py @@ -50,7 +50,8 @@ def run(self): include=[ "pylibwholegraph", "pylibwholegraph.*", - ] + ], + exclude=["*tests*"], ), package_data={ "pylibwholegraph": ["VERSION", 
"torch_cpp_ext/*.cpp", From 58602ede32fc4a43a54ac4d4b6754508b413d0b8 Mon Sep 17 00:00:00 2001 From: Ray Douglass <3107146+raydouglass@users.noreply.github.com> Date: Thu, 8 Feb 2024 17:16:10 -0500 Subject: [PATCH 13/16] Revert "Exclude tests from builds (#127)" (#130) This reverts commit 503cdcd1e6f38000967b60de59bd811c379c523c. --- python/pylibwholegraph/setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pylibwholegraph/setup.py b/python/pylibwholegraph/setup.py index 628583288..0a1f3fa42 100644 --- a/python/pylibwholegraph/setup.py +++ b/python/pylibwholegraph/setup.py @@ -50,8 +50,7 @@ def run(self): include=[ "pylibwholegraph", "pylibwholegraph.*", - ], - exclude=["*tests*"], + ] ), package_data={ "pylibwholegraph": ["VERSION", "torch_cpp_ext/*.cpp", From 40cc75d9d7c55a545d044001ef7aee815420fb60 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Tue, 13 Feb 2024 04:07:07 +0800 Subject: [PATCH 14/16] Logging level (#123) Add log-level-control interface in pylibwholegraph to allow user to control log level in wholememory. Authors: - https://github.com/linhu-nv - Ray Douglass (https://github.com/raydouglass) Approvers: - Chuang Zhu (https://github.com/chuangz0) --- cpp/include/wholememory/wholememory.h | 3 ++- cpp/src/wholememory/initialize.cpp | 4 +++- cpp/src/wholememory/initialize.hpp | 2 +- cpp/src/wholememory/wholememory.cpp | 5 ++++- .../examples/node_classfication.py | 3 ++- .../binding/wholememory_binding.pyx | 6 +++--- .../pylibwholegraph/torch/common_options.py | 8 +++++++- .../pylibwholegraph/torch/initialize.py | 19 +++++++++++++------ 8 files changed, 35 insertions(+), 15 deletions(-) diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index 7aac0e874..885dddd8e 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -83,9 +83,10 @@ enum wholememory_distributed_backend_t { /** * Initialize WholeMemory library * @param flags : reserved should be 0 + * @param wm_log_level : wholememory log level, the default level is "info" * @return : wholememory_error_code_t */ -wholememory_error_code_t wholememory_init(unsigned int flags); +wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level = 3); /** * Finalize WholeMemory library diff --git a/cpp/src/wholememory/initialize.cpp b/cpp/src/wholememory/initialize.cpp index 2e80ab3c3..b7d1e54ac 100644 --- a/cpp/src/wholememory/initialize.cpp +++ b/cpp/src/wholememory/initialize.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include "communicator.hpp" @@ -32,7 +33,7 @@ static bool is_wm_init = false; static const std::string RAFT_NAME = "wholememory"; static cudaDeviceProp* device_props = nullptr; -wholememory_error_code_t init(unsigned int flags) noexcept +wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept { try { std::unique_lock lock(mu); @@ -50,6 +51,7 @@ wholememory_error_code_t init(unsigned int flags) noexcept WM_CUDA_CHECK(cudaGetDeviceProperties(device_props + i, i)); } is_wm_init = true; + wholememory::set_log_level(std::pow(10, wm_log_level)); return WHOLEMEMORY_SUCCESS; } catch (raft::logic_error& logic_error) { WHOLEMEMORY_ERROR("init failed, logic_error=%s", logic_error.what()); diff --git a/cpp/src/wholememory/initialize.hpp b/cpp/src/wholememory/initialize.hpp index 2b9d0366b..77870f989 100644 --- a/cpp/src/wholememory/initialize.hpp +++ b/cpp/src/wholememory/initialize.hpp @@ -21,7 +21,7 @@ 
namespace wholememory { -wholememory_error_code_t init(unsigned int flags) noexcept; +wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept; wholememory_error_code_t finalize() noexcept; diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp index dbdce12e6..2f5f33a36 100644 --- a/cpp/src/wholememory/wholememory.cpp +++ b/cpp/src/wholememory/wholememory.cpp @@ -25,7 +25,10 @@ extern "C" { #endif -wholememory_error_code_t wholememory_init(unsigned int flags) { return wholememory::init(flags); } +wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level) +{ + return wholememory::init(flags, wm_log_level); +} wholememory_error_code_t wholememory_finalize() { return wholememory::finalize(); } diff --git a/python/pylibwholegraph/examples/node_classfication.py b/python/pylibwholegraph/examples/node_classfication.py index 27b035fb9..fb77ffb88 100644 --- a/python/pylibwholegraph/examples/node_classfication.py +++ b/python/pylibwholegraph/examples/node_classfication.py @@ -130,7 +130,8 @@ def main_func(): wgth.get_world_size(), wgth.get_local_rank(), wgth.get_local_size(), - args.distributed_backend_type + args.distributed_backend_type, + args.log_level ) if args.use_cpp_ext: diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index 97d84c228..263fbd62f 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -71,7 +71,7 @@ cdef extern from "wholememory/wholememory.h": WHOLEMEMORY_DB_NONE "WHOLEMEMORY_DB_NONE" WHOLEMEMORY_DB_NCCL "WHOLEMEMORY_DB_NCCL" WHOLEMEMORY_DB_NVSHMEM "WHOLEMEMORY_DB_NVSHMEM" - cdef wholememory_error_code_t wholememory_init(unsigned int flags) + cdef wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level) cdef wholememory_error_code_t wholememory_finalize() @@ -981,8 +981,8 @@ cdef class PyWholeMemoryUniqueID: def __dlpack_device__(self): return (kDLCPU, 0) -def init(unsigned int flags): - check_wholememory_error_code(wholememory_init(flags)) +def init(unsigned int flags, unsigned int wm_log_level = 3): + check_wholememory_error_code(wholememory_init(flags, wm_log_level)) def finalize(): check_wholememory_error_code(wholememory_finalize()) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py index 42746add8..14955305b 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py @@ -9,7 +9,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. 
+# limitations under the License. from argparse import ArgumentParser @@ -68,6 +68,12 @@ def add_training_options(argparser: ArgumentParser): default="nccl", help="distributed backend type, should be: nccl, nvshmem ", ) + argparser.add_argument( + "--log-level", + dest="log_level", + default="info", + help="Logging level of wholegraph, should be: trace, debug, info, warn, error" + ) def add_common_graph_options(argparser: ArgumentParser): diff --git a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py index 3259a0e82..94ee74261 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py @@ -18,12 +18,13 @@ from .comm import set_world_info, get_global_communicator, get_local_node_communicator, reset_communicators -def init(world_rank: int, world_size: int, local_rank: int, local_size: int): - wmb.init(0) +def init(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level="info"): + log_level_dic = {"error": 1, "warn": 2, "info": 3, "debug": 4, "trace": 5} + wmb.init(0, log_level_dic[wm_log_level]) set_world_info(world_rank, world_size, local_rank, local_size) -def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int): +def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level): r"""Init WholeGraph environment for PyTorch. :param world_rank: world rank of current process :param world_size: world size of all processes @@ -44,7 +45,8 @@ def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size print("[WARNING] MASTER_PORT not set, resetting to 12335") os.environ["MASTER_PORT"] = "12335" - wmb.init(0) + log_level_dic = {"error": 1, "warn": 2, "info": 3, "debug": 4, "trace": 5} + wmb.init(0, log_level_dic[wm_log_level]) torch.set_num_threads(1) torch.cuda.set_device(local_rank) torch.distributed.init_process_group(backend="nccl", init_method="env://") @@ -52,7 +54,12 @@ def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size def init_torch_env_and_create_wm_comm( - world_rank: int, world_size: int, local_rank: int, local_size: int , distributed_backend_type="nccl" + world_rank: int, + world_size: int, + local_rank: int, + local_size: int, + distributed_backend_type="nccl", + wm_log_level="info" ): r"""Init WholeGraph environment for PyTorch and create single communicator for all ranks. :param world_rank: world rank of current process @@ -61,7 +68,7 @@ def init_torch_env_and_create_wm_comm( :param local_size: local size :return: global and local node Communicator """ - init_torch_env(world_rank, world_size, local_rank, local_size) + init_torch_env(world_rank, world_size, local_rank, local_size, wm_log_level) global_comm = get_global_communicator(distributed_backend_type) local_comm = get_local_node_communicator() From b49be9dd978dff07bcdd3191dd3e81b0738eb598 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Tue, 13 Feb 2024 04:07:21 +0800 Subject: [PATCH 15/16] allow users to control gather/scatter sms (#124) Allow users to control the number of SMs used while gathering/scattering from raw embeddings.
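For reviewers, the effective launch rule this adds is small: the gather/scatter kernels keep their previous default of one thread block per index, capped at 1568 blocks, unless the caller passes a positive SM count, and values outside (0, 1568] are rejected with a warning and fall back to the default. The sketch below only mirrors that selection logic from `embedding_base::set_gather_sms` and `gather_temp_func`/`scatter_temp_func` in the diff; `effective_block_count` is a hypothetical helper for illustration, not an API of libwholegraph or pylibwholegraph.

```python
# Illustration only: mirrors the SM/block-count selection added in this patch
# (embedding_base::set_gather_sms plus gather_temp_func/scatter_temp_func).
# effective_block_count is a hypothetical helper, not part of the library.
def effective_block_count(indice_count: int, user_defined_sms: int = -1) -> int:
    # set_gather_sms: anything outside (0, 1568] is warned about and reset to -1 (default).
    if user_defined_sms != -1 and not (0 < user_defined_sms <= 1568):
        print("[WARN] illegal SM number for gather/scatter, using default")
        user_defined_sms = -1
    # Kernel default: one block per index, capped at 1568 blocks.
    block_count = indice_count if indice_count <= 1568 else 1568
    # A user-supplied count (!= -1) overrides the default.
    if user_defined_sms != -1:
        block_count = user_defined_sms
    return block_count


print(effective_block_count(100_000))      # 1568 (default cap)
print(effective_block_count(100_000, 32))  # 32   (user-limited)
```

Passing -1 (the default in every new signature) keeps the previous behaviour, so existing callers are unaffected.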
Authors: - https://github.com/linhu-nv - Ray Douglass (https://github.com/raydouglass) Approvers: - Chuang Zhu (https://github.com/chuangz0) --- cpp/include/wholememory/embedding.h | 4 ++- cpp/include/wholememory/wholememory_op.h | 8 ++++-- cpp/src/wholememory/embedding.cpp | 23 ++++++++++++--- cpp/src/wholememory/embedding.hpp | 2 ++ .../wholememory_ops/functions/gather_func.cu | 28 +++++++++++++------ ...r_func_impl_floating_data_int32_indices.cu | 11 +++++--- ...r_func_impl_floating_data_int64_indices.cu | 11 +++++--- ...er_func_impl_integer_data_int32_indices.cu | 11 +++++--- ...er_func_impl_integer_data_int64_indices.cu | 11 +++++--- .../functions/gather_scatter_func.cuh | 8 ++++-- .../functions/gather_scatter_func.h | 6 ++-- ...r_func_impl_floating_data_int32_indices.cu | 12 +++++--- ...r_func_impl_floating_data_int64_indices.cu | 12 +++++--- ...er_func_impl_integer_data_int32_indices.cu | 12 +++++--- ...er_func_impl_integer_data_int64_indices.cu | 12 +++++--- .../functions/nvshmem_gather_scatter_func.cuh | 10 ++++--- ...r_func_impl_floating_data_int32_indices.cu | 12 +++++--- ...r_func_impl_floating_data_int64_indices.cu | 12 +++++--- ...er_func_impl_integer_data_int32_indices.cu | 12 +++++--- ...er_func_impl_integer_data_int64_indices.cu | 12 +++++--- .../wholememory_ops/functions/scatter_func.cu | 28 +++++++++++++------ ...r_func_impl_floating_data_int32_indices.cu | 11 +++++--- ...r_func_impl_floating_data_int64_indices.cu | 11 +++++--- ...er_func_impl_integer_data_int32_indices.cu | 11 +++++--- ...er_func_impl_integer_data_int64_indices.cu | 11 +++++--- cpp/src/wholememory_ops/gather_op.cpp | 9 ++++-- cpp/src/wholememory_ops/gather_op_impl.h | 12 +++++--- .../wholememory_ops/gather_op_impl_mapped.cu | 13 +++++++-- .../wholememory_ops/gather_op_impl_nccl.cu | 17 +++++++---- .../wholememory_ops/gather_op_impl_nvshmem.cu | 21 +++++++++----- cpp/src/wholememory_ops/scatter_op.cpp | 9 ++++-- cpp/src/wholememory_ops/scatter_op_impl.h | 12 +++++--- .../scatter_op_impl.nvshmem.cu | 21 +++++++++----- .../wholememory_ops/scatter_op_impl_mapped.cu | 13 +++++++-- .../wholememory_ops/scatter_op_impl_nccl.cu | 15 ++++++---- .../binding/wholememory_binding.pyx | 15 ++++++---- .../pylibwholegraph/torch/embedding.py | 6 ++++ 37 files changed, 319 insertions(+), 145 deletions(-) diff --git a/cpp/include/wholememory/embedding.h b/cpp/include/wholememory/embedding.h index 8a8b9af3a..fe70ffadd 100644 --- a/cpp/include/wholememory/embedding.h +++ b/cpp/include/wholememory/embedding.h @@ -130,6 +130,7 @@ wholememory_error_code_t wholememory_destroy_embedding_cache_policy( * @param memory_location : Memory Location of the underlying WholeMemory * @param optimizer : Optimizer to use for training, if don't train embedding, use nullptr * @param cache_policy : Cache policy for this embedding, if don't use cache, use nullptr + * @param user_defined_sms : User-defined sms number for raw embedding gather/scatter * @return : wholememory_error_code_t */ wholememory_error_code_t wholememory_create_embedding( @@ -139,7 +140,8 @@ wholememory_error_code_t wholememory_create_embedding( wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, wholememory_embedding_optimizer_t optimizer, - wholememory_embedding_cache_policy_t cache_policy); + wholememory_embedding_cache_policy_t cache_policy, + int user_defined_sms = -1); /** * Destroy WholeMemory Embedding diff --git a/cpp/include/wholememory/wholememory_op.h b/cpp/include/wholememory/wholememory_op.h index 410953c12..146245aa2 100644 --- 
a/cpp/include/wholememory/wholememory_op.h +++ b/cpp/include/wholememory/wholememory_op.h @@ -30,13 +30,15 @@ extern "C" { * @param output_tensor : output tensor to gather to, should NOT be WholeMemoryTensor * @param p_env_fns : pointers to environment functions. * @param stream : cudaStream_t to use. + * @param gather_sms : the number of stream multiprocessor used in gather kernel * @return : wholememory_error_code_t */ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_tensor, wholememory_tensor_t indices_tensor, wholememory_tensor_t output_tensor, wholememory_env_func_t* p_env_fns, - void* stream); + void* stream, + int gather_sms = -1); /** * Scatter Op @@ -45,13 +47,15 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten * @param wholememory_tensor : WholeMemory Tensor of embedding table. * @param p_env_fns : pointers to environment functions. * @param stream : cudaStream_t to use. + * @param scatter_sms : the number of stream multiprocessor used in scatter kernel * @return : wholememory_error_code_t */ wholememory_error_code_t wholememory_scatter(wholememory_tensor_t input_tensor, wholememory_tensor_t indices_tensor, wholememory_tensor_t wholememory_tensor, wholememory_env_func_t* p_env_fns, - void* stream); + void* stream, + int scatter_sms = -1); /** * Just a test function, diff --git a/cpp/src/wholememory/embedding.cpp b/cpp/src/wholememory/embedding.cpp index 29e557a2e..b54d202c6 100644 --- a/cpp/src/wholememory/embedding.cpp +++ b/cpp/src/wholememory/embedding.cpp @@ -403,6 +403,21 @@ wholememory_error_code_t embedding_base::destroy_optimizer_states() noexcept return WHOLEMEMORY_SUCCESS; } +wholememory_error_code_t embedding_base::set_gather_sms(int sms) noexcept +{ + if (sms != -1) { + if (sms <= 0) { + WHOLEMEMORY_WARN("Illegal SM number for gather/scatter! Will use default size."); + sms = -1; + } else if (sms > 1568) { + WHOLEMEMORY_WARN("SM number for gather/scatter is too large! 
Will use default size."); + sms = -1; + } + } + gather_sms_ = sms; + return WHOLEMEMORY_SUCCESS; +} + void embedding_base::deallocate() noexcept { if (optimizer != nullptr) { @@ -477,7 +492,7 @@ wholememory_error_code_t noncached_embedding::gather(wholememory_tensor_t indice cudaStream_t stream) noexcept { WHOLEMEMORY_RETURN_ON_FAIL( - wholememory_gather(allocated_embedding, indices, output, p_env_fns, stream)); + wholememory_gather(allocated_embedding, indices, output, p_env_fns, stream, gather_sms_)); return WHOLEMEMORY_SUCCESS; } @@ -845,7 +860,8 @@ wholememory_error_code_t wholememory_create_embedding( wholememory_memory_type_t memory_type, wholememory_memory_location_t memory_location, wholememory_embedding_optimizer_t optimizer, - wholememory_embedding_cache_policy_t cache_policy) + wholememory_embedding_cache_policy_t cache_policy, + int user_defined_sms) { wholememory_matrix_description_t embedding_matrix_description; if (!wholememory_convert_tensor_desc_to_matrix(&embedding_matrix_description, @@ -909,10 +925,9 @@ wholememory_error_code_t wholememory_create_embedding( } else { embedding_impl_ptr = new wholememory::noncached_embedding(); } - WHOLEMEMORY_RETURN_ON_FAIL(embedding_impl_ptr->allocate( &embedding_matrix_description, comm, memory_type, memory_location, cache_policy, optimizer)); - + embedding_impl_ptr->set_gather_sms(user_defined_sms); *wholememory_embedding = static_cast(embedding_impl_ptr); return WHOLEMEMORY_SUCCESS; } diff --git a/cpp/src/wholememory/embedding.hpp b/cpp/src/wholememory/embedding.hpp index 72d7bc456..1dae0f4db 100644 --- a/cpp/src/wholememory/embedding.hpp +++ b/cpp/src/wholememory/embedding.hpp @@ -81,6 +81,7 @@ class embedding_base : public wholememory_embedding_ { virtual wholememory_error_code_t drop_all_caches(cudaStream_t stream) const noexcept; wholememory::embedding_cache_base* get_cache_ptr() const { return cache_ptr_; } + wholememory_error_code_t set_gather_sms(int sms) noexcept; protected: virtual wholememory_error_code_t init_optimizer_states() noexcept @@ -96,6 +97,7 @@ class embedding_base : public wholememory_embedding_ { wholememory_error_code_t create_optimizer_states() noexcept; wholememory_error_code_t destroy_optimizer_states() noexcept; + int gather_sms_; wholememory_comm_t raw_embedding_comm_ = nullptr; wholememory::embedding_cache_base* cache_ptr_ = nullptr; wholememory::embedding_optimizer_impl_base* optimizer_impl_base_ = nullptr; diff --git a/cpp/src/wholememory_ops/functions/gather_func.cu b/cpp/src/wholememory_ops/functions/gather_func.cu index 052954c72..0b79f0f15 100644 --- a/cpp/src/wholememory_ops/functions/gather_func.cu +++ b/cpp/src/wholememory_ops/functions/gather_func.cu @@ -26,28 +26,32 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - 
cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, @@ -55,7 +59,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { bool embedding_is_float = wholememory_dtype_is_floating_number(embedding_desc.dtype); @@ -73,7 +78,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, wholememory_array_description_t, void*, wholememory_matrix_description_t, - cudaStream_t) = nullptr; + cudaStream_t, + int) = nullptr; if (embedding_is_float) { if (indices_desc.dtype == WHOLEMEMORY_DT_INT) { p_gather_func = gather_floating_int32_func; @@ -87,8 +93,14 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, p_gather_func = gather_integer_int64_func; } } - return p_gather_func( - embedding_gref, embedding_desc, indices, indices_desc, output, output_desc, stream); + return p_gather_func(embedding_gref, + embedding_desc, + indices, + indices_desc, + output, + output_desc, + stream, + gather_sms); } catch (const wholememory::cuda_error& rle) { return WHOLEMEMORY_LOGIC_ERROR; } catch (const wholememory::logic_error& le) { diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu index 09b149ce3..c7679c508 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu @@ -29,10 +29,11 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref, int64_t indice_count, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream); + embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt32, @@ -46,7 +47,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -63,7 +65,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding indices_desc.size, output, output_desc, - stream); + stream, + gather_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu index 39fa52750..af9d6d6ec 100644 --- 
a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu @@ -29,10 +29,11 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref, int64_t indice_count, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream); + embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt64, @@ -46,7 +47,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -63,7 +65,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding indices_desc.size, output, output_desc, - stream); + stream, + gather_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu index 3c64cc6a8..bdb7c0be8 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu @@ -29,10 +29,11 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref, int64_t indice_count, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream); + embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt32, @@ -46,7 +47,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_integer_number(embedding_desc.dtype)); @@ -63,7 +65,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ indices_desc.size, output, output_desc, - stream); + stream, + gather_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu index 811a14699..6a6c7f330 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu @@ -29,10 +29,11 @@ void gather_integer_int64_temp_func(wholememory_gref_t embedding_gref, int64_t indice_count, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { gather_temp_func( - 
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream); + embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt64, @@ -46,7 +47,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_integer_number(embedding_desc.dtype)); @@ -63,7 +65,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ indices_desc.size, output, output_desc, - stream); + stream, + gather_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh index 62a7dcb5e..87c89d9c2 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh @@ -316,7 +316,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref, int64_t indice_count, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { WHOLEMEMORY_EXPECTS(output_desc.sizes[0] == indice_count, "gather_func, output shape[0]=%ld, but indice_count=%ld", @@ -365,6 +366,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref, } int block_size = 1024; int block_count = indice_count > 1568 ? 1568 : indice_count; + if (gather_sms != -1) block_count = gather_sms; kernel_fn<<>>(embedding_gref, embedding_desc, static_cast(indices), @@ -461,7 +463,8 @@ void scatter_temp_func(const void* input, int64_t indice_count, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { WHOLEMEMORY_EXPECTS(input_desc.sizes[0] == indice_count, "scatter_func, input shape[0]=%ld, but indice_count=%ld", @@ -506,6 +509,7 @@ void scatter_temp_func(const void* input, } int block_size = 256; int block_count = indice_count > 1568 ? 
1568 : indice_count; + if (scatter_sms != -1) block_count = scatter_sms; kernel_fn<<>>(static_cast(input), input_desc, static_cast(indices), diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.h b/cpp/src/wholememory_ops/functions/gather_scatter_func.h index 0d42fbe4a..0c0b9e4a4 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.h +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.h @@ -27,7 +27,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, wholememory_array_description_t indices_desc, void* output, wholememory_matrix_description_t output_desc, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms = -1); wholememory_error_code_t scatter_func(const void* input, wholememory_matrix_description_t input_desc, @@ -35,6 +36,7 @@ wholememory_error_code_t scatter_func(const void* input, wholememory_array_description_t indices_desc, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream); + cudaStream_t stream, + int scatter_sms = -1); } // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu index 9b951a6bf..74072280c 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int32_indices.cu @@ -33,7 +33,8 @@ void nvshmem_gather_floating_int32_temp_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { nvshmem_gather_temp_get_mem_sort_idx_func( wm_comm, @@ -46,7 +47,8 @@ void nvshmem_gather_floating_int32_temp_func(wholememory_comm_t wm_comm, output_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncFloatingInt32, @@ -65,7 +67,8 @@ wholememory_error_code_t nvshmem_gather_floating_int32_func( wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -85,7 +88,8 @@ wholememory_error_code_t nvshmem_gather_floating_int32_func( output_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + gather_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu index 34b37be2b..65c9fec77 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_floating_data_int64_indices.cu @@ -33,7 +33,8 @@ void nvshmem_gather_floating_int64_temp_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { nvshmem_gather_temp_get_mem_sort_idx_func( wm_comm, @@ -46,7 +47,8 @@ void 
nvshmem_gather_floating_int64_temp_func(wholememory_comm_t wm_comm, output_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncFloatingInt64, @@ -65,7 +67,8 @@ wholememory_error_code_t nvshmem_gather_floating_int64_func( wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -85,7 +88,8 @@ wholememory_error_code_t nvshmem_gather_floating_int64_func( output_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + gather_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu index 41a81cef9..b97e47760 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int32_indices.cu @@ -33,7 +33,8 @@ void nvshmem_gather_integer_int32_temp_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { nvshmem_gather_temp_get_mem_sort_idx_func( wm_comm, @@ -46,7 +47,8 @@ void nvshmem_gather_integer_int32_temp_func(wholememory_comm_t wm_comm, output_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncIntegerInt32, @@ -65,7 +67,8 @@ wholememory_error_code_t nvshmem_gather_integer_int32_func( wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_integer_number(embedding_desc.dtype)); @@ -85,7 +88,8 @@ wholememory_error_code_t nvshmem_gather_integer_int32_func( output_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + gather_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu index 2194c3ad9..a1876a322 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_func_impl_integer_data_int64_indices.cu @@ -33,7 +33,8 @@ void nvshmem_gather_integer_int64_temp_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { nvshmem_gather_temp_get_mem_sort_idx_func( wm_comm, @@ -46,7 +47,8 @@ void nvshmem_gather_integer_int64_temp_func(wholememory_comm_t wm_comm, output_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemGatherFuncIntegerInt64, @@ -65,7 +67,8 @@ 
wholememory_error_code_t nvshmem_gather_integer_int64_func( wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_integer_number(embedding_desc.dtype)); @@ -85,7 +88,8 @@ wholememory_error_code_t nvshmem_gather_integer_int64_func( output_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + gather_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh index 219c5801d..ea905cd93 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/nvshmem_gather_scatter_func.cuh @@ -315,9 +315,8 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream - -) + cudaStream_t stream, + int gather_sms) { wm_thrust_allocator thrust_allocator(p_env_fns); @@ -453,6 +452,7 @@ void nvshmem_gather_temp_get_mem_sort_idx_func(wholememory_comm_t wm_comm, } WM_CUDA_CHECK(cudaGetLastError()); + (void)gather_sms; } template @@ -556,7 +556,8 @@ void nvshmem_scatter_temp_put_mem_sort_idx_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t embedding_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { wm_thrust_allocator thrust_allocator(p_env_fns); @@ -696,6 +697,7 @@ void nvshmem_scatter_temp_put_mem_sort_idx_func(wholememory_comm_t wm_comm, } WM_CUDA_CHECK(cudaGetLastError()); + (void)scatter_sms; } }; // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu index 51b358995..3fe3a96fa 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int32_indices.cu @@ -33,7 +33,8 @@ void nvshmem_scatter_floating_int32_temp_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t embedding_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { nvshmem_scatter_temp_put_mem_sort_idx_func( wm_comm, @@ -46,7 +47,8 @@ void nvshmem_scatter_floating_int32_temp_func(wholememory_comm_t wm_comm, embedding_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncFloatingInt32, @@ -65,7 +67,8 @@ wholememory_error_code_t nvshmem_scatter_floating_int32_func( wholememory_matrix_description_t embedding_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -85,7 +88,8 @@ wholememory_error_code_t nvshmem_scatter_floating_int32_func( embedding_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + scatter_sms); } catch (const 
wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu index a2cfe0bf4..51107a5fc 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_floating_data_int64_indices.cu @@ -33,7 +33,8 @@ void nvshmem_scatter_floating_int64_temp_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t embedding_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { nvshmem_scatter_temp_put_mem_sort_idx_func( wm_comm, @@ -46,7 +47,8 @@ void nvshmem_scatter_floating_int64_temp_func(wholememory_comm_t wm_comm, embedding_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncFloatingInt64, @@ -65,7 +67,8 @@ wholememory_error_code_t nvshmem_scatter_floating_int64_func( wholememory_matrix_description_t embedding_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -85,7 +88,8 @@ wholememory_error_code_t nvshmem_scatter_floating_int64_func( embedding_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + scatter_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu index 5340875f5..4530442be 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int32_indices.cu @@ -33,7 +33,8 @@ void nvshmem_scatter_integer_int32_temp_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t embedding_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { nvshmem_scatter_temp_put_mem_sort_idx_func( wm_comm, @@ -46,7 +47,8 @@ void nvshmem_scatter_integer_int32_temp_func(wholememory_comm_t wm_comm, embedding_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncIntegerInt32, @@ -65,7 +67,8 @@ wholememory_error_code_t nvshmem_scatter_integer_int32_func( wholememory_matrix_description_t embedding_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -85,7 +88,8 @@ wholememory_error_code_t nvshmem_scatter_integer_int32_func( embedding_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + scatter_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git 
a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu index 5ca51169e..bee8fa869 100644 --- a/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/nvshmem_scatter_func_impl_integer_data_int64_indices.cu @@ -33,7 +33,8 @@ void nvshmem_scatter_integer_int64_temp_func(wholememory_comm_t wm_comm, wholememory_matrix_description_t embedding_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { nvshmem_scatter_temp_put_mem_sort_idx_func( wm_comm, @@ -46,7 +47,8 @@ void nvshmem_scatter_integer_int64_temp_func(wholememory_comm_t wm_comm, embedding_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(NvshmemScatterFuncIntegerInt64, @@ -65,7 +67,8 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func( wholememory_matrix_description_t embedding_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -85,7 +88,8 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func( embedding_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + scatter_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/scatter_func.cu b/cpp/src/wholememory_ops/functions/scatter_func.cu index 8767fbb49..6ced57f29 100644 --- a/cpp/src/wholememory_ops/functions/scatter_func.cu +++ b/cpp/src/wholememory_ops/functions/scatter_func.cu @@ -27,14 +27,16 @@ wholememory_error_code_t scatter_integer_int32_func(const void* input, wholememory_array_description_t indices_desc, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream); + cudaStream_t stream, + int scatter_sms); wholememory_error_code_t scatter_integer_int64_func(const void* input, wholememory_matrix_description_t input_desc, void* indices, wholememory_array_description_t indices_desc, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream); + cudaStream_t stream, + int scatter_sms); wholememory_error_code_t scatter_floating_int32_func( const void* input, wholememory_matrix_description_t input_desc, @@ -42,7 +44,8 @@ wholememory_error_code_t scatter_floating_int32_func( wholememory_array_description_t indices_desc, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream); + cudaStream_t stream, + int scatter_sms); wholememory_error_code_t scatter_floating_int64_func( const void* input, wholememory_matrix_description_t input_desc, @@ -50,7 +53,8 @@ wholememory_error_code_t scatter_floating_int64_func( wholememory_array_description_t indices_desc, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream); + cudaStream_t stream, + int scatter_sms); wholememory_error_code_t scatter_func(const void* input, wholememory_matrix_description_t input_desc, void* indices, wholememory_array_description_t indices_desc,
wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { try { bool embedding_is_float = wholememory_dtype_is_floating_number(embedding_desc.dtype); @@ -76,7 +81,8 @@ wholememory_error_code_t scatter_func(const void* input, wholememory_array_description_t, wholememory_gref_t, wholememory_matrix_description_t, - cudaStream_t) = nullptr; + cudaStream_t, + int) = nullptr; if (embedding_is_float) { if (indices_desc.dtype == WHOLEMEMORY_DT_INT) { p_scatter_func = scatter_floating_int32_func; @@ -90,8 +96,14 @@ wholememory_error_code_t scatter_func(const void* input, p_scatter_func = scatter_integer_int64_func; } } - return p_scatter_func( - input, input_desc, indices, indices_desc, embedding_gref, embedding_desc, stream); + return p_scatter_func(input, + input_desc, + indices, + indices_desc, + embedding_gref, + embedding_desc, + stream, + scatter_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/scatter_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/scatter_func_impl_floating_data_int32_indices.cu index 153adc257..607086224 100644 --- a/cpp/src/wholememory_ops/functions/scatter_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/scatter_func_impl_floating_data_int32_indices.cu @@ -29,10 +29,11 @@ void scatter_floating_int32_temp_func(const void* input, int64_t indice_count, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { scatter_temp_func( - input, input_desc, indices, indice_count, embedding_gref, embedding_desc, stream); + input, input_desc, indices, indice_count, embedding_gref, embedding_desc, stream, scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(ScatterFuncFloatingInt32, @@ -47,7 +48,8 @@ wholememory_error_code_t scatter_floating_int32_func( wholememory_array_description_t indices_desc, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -64,7 +66,8 @@ wholememory_error_code_t scatter_floating_int32_func( indices_desc.size, embedding_gref, embedding_desc, - stream); + stream, + scatter_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/scatter_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/scatter_func_impl_floating_data_int64_indices.cu index 18d075e66..8aa5b578f 100644 --- a/cpp/src/wholememory_ops/functions/scatter_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/scatter_func_impl_floating_data_int64_indices.cu @@ -29,10 +29,11 @@ void scatter_floating_int64_temp_func(const void* input, int64_t indice_count, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { scatter_temp_func( - input, input_desc, indices, indice_count, embedding_gref, embedding_desc, stream); + input, input_desc, indices, indice_count, embedding_gref, embedding_desc, stream, scatter_sms); } 
REGISTER_DISPATCH_TWO_TYPES(ScatterFuncFloatingInt64, @@ -47,7 +48,8 @@ wholememory_error_code_t scatter_floating_int64_func( wholememory_array_description_t indices_desc, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype)); @@ -64,7 +66,8 @@ wholememory_error_code_t scatter_floating_int64_func( indices_desc.size, embedding_gref, embedding_desc, - stream); + stream, + scatter_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what()); fflush(stdout); diff --git a/cpp/src/wholememory_ops/functions/scatter_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/scatter_func_impl_integer_data_int32_indices.cu index 46c2c9f98..d22efa0aa 100644 --- a/cpp/src/wholememory_ops/functions/scatter_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/scatter_func_impl_integer_data_int32_indices.cu @@ -29,10 +29,11 @@ void scatter_integer_int32_temp_func(const void* input, int64_t indice_count, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { scatter_temp_func( - input, input_desc, indices, indice_count, embedding_gref, embedding_desc, stream); + input, input_desc, indices, indice_count, embedding_gref, embedding_desc, stream, scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(ScatterFuncIntegerInt32, @@ -46,7 +47,8 @@ wholememory_error_code_t scatter_integer_int32_func(const void* input, wholememory_array_description_t indices_desc, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_integer_number(embedding_desc.dtype)); @@ -63,7 +65,8 @@ wholememory_error_code_t scatter_integer_int32_func(const void* input, indices_desc.size, embedding_gref, embedding_desc, - stream); + stream, + scatter_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/functions/scatter_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/scatter_func_impl_integer_data_int64_indices.cu index ad33036a0..783da9eca 100644 --- a/cpp/src/wholememory_ops/functions/scatter_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/scatter_func_impl_integer_data_int64_indices.cu @@ -29,10 +29,11 @@ void scatter_integer_int64_temp_func(const void* input, int64_t indice_count, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { scatter_temp_func( - input, input_desc, indices, indice_count, embedding_gref, embedding_desc, stream); + input, input_desc, indices, indice_count, embedding_gref, embedding_desc, stream, scatter_sms); } REGISTER_DISPATCH_TWO_TYPES(ScatterFuncIntegerInt64, @@ -46,7 +47,8 @@ wholememory_error_code_t scatter_integer_int64_func(const void* input, wholememory_array_description_t indices_desc, wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, - cudaStream_t stream) + cudaStream_t stream, + int scatter_sms) { try { WHOLEMEMORY_CHECK(wholememory_dtype_is_integer_number(embedding_desc.dtype)); @@ -63,7 
+65,8 @@ wholememory_error_code_t scatter_integer_int64_func(const void* input, indices_desc.size, embedding_gref, embedding_desc, - stream); + stream, + scatter_sms); } catch (const wholememory::cuda_error& wle) { WHOLEMEMORY_ERROR("scatter CUDA LOGIC Error %s\n", wle.what()); return WHOLEMEMORY_LOGIC_ERROR; diff --git a/cpp/src/wholememory_ops/gather_op.cpp b/cpp/src/wholememory_ops/gather_op.cpp index 3a7a0354a..a6b2e97b5 100644 --- a/cpp/src/wholememory_ops/gather_op.cpp +++ b/cpp/src/wholememory_ops/gather_op.cpp @@ -24,7 +24,8 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten wholememory_tensor_t indices_tensor, wholememory_tensor_t output_tensor, wholememory_env_func_t* p_env_fns, - void* stream) + void* stream, + int gather_sms) { bool const has_handle = wholememory_tensor_has_handle(wholememory_tensor); wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_NONE; @@ -86,7 +87,8 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten output, output_desc, p_env_fns, - static_cast(stream)); + static_cast(stream), + gather_sms); } WHOLEMEMORY_EXPECTS_NOTHROW(!has_handle || memory_type == WHOLEMEMORY_MT_CHUNKED || @@ -103,5 +105,6 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten output, output_desc, p_env_fns, - static_cast(stream)); + static_cast(stream), + gather_sms); } diff --git a/cpp/src/wholememory_ops/gather_op_impl.h b/cpp/src/wholememory_ops/gather_op_impl.h index f19bf228d..6f85d6410 100644 --- a/cpp/src/wholememory_ops/gather_op_impl.h +++ b/cpp/src/wholememory_ops/gather_op_impl.h @@ -28,7 +28,8 @@ wholememory_error_code_t wholememory_gather_mapped( void* output, wholememory_matrix_description_t output_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememory_handle, wholememory_matrix_description_t wholememory_desc, @@ -37,7 +38,8 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor void* output, wholememory_matrix_description_t output_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t wholememory_gather_distributed( wholememory_handle_t wholememory_handle, @@ -47,7 +49,8 @@ wholememory_error_code_t wholememory_gather_distributed( void* output, wholememory_matrix_description_t output_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); #ifdef WITH_NVSHMEM_SUPPORT @@ -59,6 +62,7 @@ wholememory_error_code_t wholememory_gather_nvshmem( void* output, wholememory_matrix_description_t output_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); #endif } // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu index bc78a7bcc..38e64919d 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu @@ -31,10 +31,17 @@ wholememory_error_code_t wholememory_gather_mapped( void* output, wholememory_matrix_description_t output_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { - WHOLEMEMORY_RETURN_ON_FAIL(gather_func( - wholememory_gref, wholememory_desc, indices, indice_desc, output, output_desc, stream)); + 
WHOLEMEMORY_RETURN_ON_FAIL(gather_func(wholememory_gref, + wholememory_desc, + indices, + indice_desc, + output, + output_desc, + stream, + gather_sms)); WM_CUDA_DEBUG_SYNC_STREAM(stream); return WHOLEMEMORY_SUCCESS; } diff --git a/cpp/src/wholememory_ops/gather_op_impl_nccl.cu b/cpp/src/wholememory_ops/gather_op_impl_nccl.cu index 74d95c463..e842a829a 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_nccl.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_nccl.cu @@ -38,7 +38,8 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor void* output, wholememory_matrix_description_t output_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { if (wholememory_desc.storage_offset < 0 || @@ -122,7 +123,8 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor dev_recv_indice_desc, dev_local_gather_buffer_ptr, local_gather_buffer_desc, - stream)); + stream, + gather_sms)); // AllToAllV for embeddings size_t embedding_size = wholememory_desc.sizes[1] * wholememory_dtype_get_element_size(output_desc.dtype); @@ -152,7 +154,7 @@ wholememory_error_code_t wholememory_gather_nccl(wholememory_handle_t wholememor output_desc, stream)); WM_CUDA_CHECK(cudaGetLastError()); - WM_CUDA_CHECK(cudaStreamSynchronize(stream)); + // WM_CUDA_CHECK(cudaStreamSynchronize(stream)); } catch (wholememory::cuda_error& wce) { WHOLEMEMORY_ERROR("CUDA logic Error %s\n", wce.what()); return WHOLEMEMORY_CUDA_ERROR; @@ -174,7 +176,8 @@ wholememory_error_code_t wholememory_gather_distributed( void* output, wholememory_matrix_description_t output_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { #ifdef WITH_NVSHMEM_SUPPORT @@ -186,7 +189,8 @@ wholememory_error_code_t wholememory_gather_distributed( output, output_desc, p_env_fns, - stream); + stream, + gather_sms); } #endif return wholememory_gather_nccl(wholememory_handle, @@ -196,6 +200,7 @@ wholememory_error_code_t wholememory_gather_distributed( output, output_desc, p_env_fns, - stream); + stream, + gather_sms); } } // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu index 4051f12bd..8a683a8c1 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_nvshmem.cu @@ -55,7 +55,8 @@ wholememory_error_code_t nvshmem_gather_floating_int32_func( wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t nvshmem_gather_floating_int64_func( wholememory_comm_t wm_comm, wholememory_nvshmem_ref_t embeding_nvshmem_ptr, @@ -67,7 +68,8 @@ wholememory_error_code_t nvshmem_gather_floating_int64_func( wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t nvshmem_gather_integer_int64_func( wholememory_comm_t wm_comm, @@ -80,7 +82,8 @@ wholememory_error_code_t nvshmem_gather_integer_int64_func( wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t nvshmem_gather_integer_int32_func( wholememory_comm_t wm_comm, 
wholememory_nvshmem_ref_t embeding_nvshmem_ptr, @@ -92,7 +95,8 @@ wholememory_error_code_t nvshmem_gather_integer_int32_func( wholememory_matrix_description_t output_desc, size_t embedding_entry_count_per_rank, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int gather_sms); wholememory_error_code_t wholememory_gather_nvshmem( wholememory_handle_t wholememory_handle, @@ -102,7 +106,8 @@ wholememory_error_code_t wholememory_gather_nvshmem( void* output, wholememory_matrix_description_t output_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream) + cudaStream_t stream, + int gather_sms) { try { bool embedding_is_float = wholememory_dtype_is_floating_number(wholememory_desc.dtype); @@ -158,7 +163,8 @@ wholememory_error_code_t wholememory_gather_nvshmem( wholememory_matrix_description_t, size_t, wholememory_env_func_t*, - cudaStream_t) = nullptr; + cudaStream_t, + int) = nullptr; if (embedding_is_float) { if (indice_desc.dtype == WHOLEMEMORY_DT_INT) { @@ -183,7 +189,8 @@ wholememory_error_code_t wholememory_gather_nvshmem( output_desc, embedding_entry_count_per_rank, p_env_fns, - stream); + stream, + gather_sms); // ungistre WM_CUDA_CHECK(cudaStreamSynchronize(stream)); if (nvshmemx_buffer_unregister(temp_output_ptr) != 0) { diff --git a/cpp/src/wholememory_ops/scatter_op.cpp b/cpp/src/wholememory_ops/scatter_op.cpp index 7ef57fc02..95316693d 100644 --- a/cpp/src/wholememory_ops/scatter_op.cpp +++ b/cpp/src/wholememory_ops/scatter_op.cpp @@ -24,7 +24,8 @@ wholememory_error_code_t wholememory_scatter(wholememory_tensor_t input_tensor, wholememory_tensor_t indices_tensor, wholememory_tensor_t wholememory_tensor, wholememory_env_func_t* p_env_fns, - void* stream) + void* stream, + int scatter_sms) { bool const has_handle = wholememory_tensor_has_handle(wholememory_tensor); wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_NONE; @@ -86,7 +87,8 @@ wholememory_error_code_t wholememory_scatter(wholememory_tensor_t input_tensor, wholememory_tensor_get_memory_handle(wholememory_tensor), matrix_description, p_env_fns, - static_cast(stream)); + static_cast(stream), + scatter_sms); } WHOLEMEMORY_EXPECTS_NOTHROW(!has_handle || memory_type == WHOLEMEMORY_MT_CHUNKED || @@ -103,5 +105,6 @@ wholememory_error_code_t wholememory_scatter(wholememory_tensor_t input_tensor, gref, matrix_description, p_env_fns, - static_cast(stream)); + static_cast(stream), + scatter_sms); } diff --git a/cpp/src/wholememory_ops/scatter_op_impl.h b/cpp/src/wholememory_ops/scatter_op_impl.h index def8a4ed0..719e950b5 100644 --- a/cpp/src/wholememory_ops/scatter_op_impl.h +++ b/cpp/src/wholememory_ops/scatter_op_impl.h @@ -28,7 +28,8 @@ wholememory_error_code_t wholememory_scatter_mapped( wholememory_gref_t wholememory_gref, wholememory_matrix_description_t wholememory_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int scatter_sms); wholememory_error_code_t wholememory_scatter_nccl(void* input, wholememory_matrix_description_t input_desc, @@ -37,7 +38,8 @@ wholememory_error_code_t wholememory_scatter_nccl(void* input, wholememory_handle_t wholememory_handle, wholememory_matrix_description_t wholememory_desc, wholememory_env_func_t* p_env_fns, - cudaStream_t stream); + cudaStream_t stream, + int scatter_sms); wholememory_error_code_t wholememory_scatter_distributed( void* input, @@ -47,7 +49,8 @@ wholememory_error_code_t wholememory_scatter_distributed( wholememory_handle_t wholememory_handle, wholememory_matrix_description_t wholememory_desc, 
   wholememory_env_func_t* p_env_fns,
-  cudaStream_t stream);
+  cudaStream_t stream,
+  int scatter_sms);
 
 #ifdef WITH_NVSHMEM_SUPPORT
 wholememory_error_code_t wholememory_scatter_nvshmem(
@@ -58,7 +61,8 @@ wholememory_error_code_t wholememory_scatter_nvshmem(
   wholememory_handle_t wholememory_handle,
   wholememory_matrix_description_t wholememory_desc,
   wholememory_env_func_t* p_env_fns,
-  cudaStream_t stream);
+  cudaStream_t stream,
+  int scatter_sms);
 #endif
 
 } // namespace wholememory_ops
diff --git a/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu b/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu
index 407da62e8..80dc20784 100644
--- a/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu
+++ b/cpp/src/wholememory_ops/scatter_op_impl.nvshmem.cu
@@ -51,7 +51,8 @@ wholememory_error_code_t nvshmem_scatter_floating_int32_func(
   wholememory_matrix_description_t embedding_desc,
   size_t embedding_entry_count_per_rank,
   wholememory_env_func_t* p_env_fns,
-  cudaStream_t stream);
+  cudaStream_t stream,
+  int scatter_sms);
 
 wholememory_error_code_t nvshmem_scatter_floating_int64_func(
   wholememory_comm_t wm_comm,
@@ -64,7 +65,8 @@ wholememory_error_code_t nvshmem_scatter_floating_int64_func(
   wholememory_matrix_description_t embedding_desc,
   size_t embedding_entry_count_per_rank,
   wholememory_env_func_t* p_env_fns,
-  cudaStream_t stream);
+  cudaStream_t stream,
+  int scatter_sms);
 
 wholememory_error_code_t nvshmem_scatter_integer_int32_func(
   wholememory_comm_t wm_comm,
@@ -77,7 +79,8 @@ wholememory_error_code_t nvshmem_scatter_integer_int32_func(
   wholememory_matrix_description_t embedding_desc,
   size_t embedding_entry_count_per_rank,
   wholememory_env_func_t* p_env_fns,
-  cudaStream_t stream);
+  cudaStream_t stream,
+  int scatter_sms);
 
 wholememory_error_code_t nvshmem_scatter_integer_int64_func(
   wholememory_comm_t wm_comm,
@@ -90,7 +93,8 @@ wholememory_error_code_t nvshmem_scatter_integer_int64_func(
   wholememory_matrix_description_t embedding_desc,
   size_t embedding_entry_count_per_rank,
   wholememory_env_func_t* p_env_fns,
-  cudaStream_t stream);
+  cudaStream_t stream,
+  int scatter_sms);
 
 wholememory_error_code_t wholememory_scatter_nvshmem(
   void* input,
@@ -100,7 +104,8 @@ wholememory_error_code_t wholememory_scatter_nvshmem(
   wholememory_handle_t wholememory_handle,
   wholememory_matrix_description_t wholememory_desc,
   wholememory_env_func_t* p_env_fns,
-  cudaStream_t stream)
+  cudaStream_t stream,
+  int scatter_sms)
 {
   try {
     bool embedding_is_float = wholememory_dtype_is_floating_number(wholememory_desc.dtype);
@@ -165,7 +170,8 @@ wholememory_error_code_t wholememory_scatter_nvshmem(
     wholememory_matrix_description_t,
     size_t,
     wholememory_env_func_t*,
-    cudaStream_t);
+    cudaStream_t,
+    int);
 
   if (embedding_is_float) {
     if (indices_desc.dtype == WHOLEMEMORY_DT_INT) {
@@ -191,7 +197,8 @@ wholememory_error_code_t wholememory_scatter_nvshmem(
            wholememory_desc,
           embedding_entry_count_per_rank,
           p_env_fns,
-          stream);
+          stream,
+          scatter_sms);
   if (nvshmemx_buffer_unregister(temp_input_ptr) != 0) {
     WHOLEMEMORY_ERROR("nvshmemx_buffer_unregister error in wholememory_gather_nvshmem");
   }
diff --git a/cpp/src/wholememory_ops/scatter_op_impl_mapped.cu b/cpp/src/wholememory_ops/scatter_op_impl_mapped.cu
index 9adc8d664..77f570f90 100644
--- a/cpp/src/wholememory_ops/scatter_op_impl_mapped.cu
+++ b/cpp/src/wholememory_ops/scatter_op_impl_mapped.cu
@@ -30,10 +30,17 @@ wholememory_error_code_t wholememory_scatter_mapped(
   wholememory_gref_t wholememory_gref,
   wholememory_matrix_description_t wholememory_desc,
   wholememory_env_func_t* p_env_fns,
-  cudaStream_t stream)
+  cudaStream_t stream,
+  int scatter_sms)
 {
-  return scatter_func(
-    input, input_desc, indices, indices_desc, wholememory_gref, wholememory_desc, stream);
+  return scatter_func(input,
+                      input_desc,
+                      indices,
+                      indices_desc,
+                      wholememory_gref,
+                      wholememory_desc,
+                      stream,
+                      scatter_sms);
 }
 
 } // namespace wholememory_ops
diff --git a/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu b/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu
index 752ddee71..47765de17 100644
--- a/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu
+++ b/cpp/src/wholememory_ops/scatter_op_impl_nccl.cu
@@ -38,7 +38,8 @@ wholememory_error_code_t wholememory_scatter_nccl(void* input,
                                                   wholememory_handle_t wholememory_handle,
                                                   wholememory_matrix_description_t wholememory_desc,
                                                   wholememory_env_func_t* p_env_fns,
-                                                  cudaStream_t stream)
+                                                  cudaStream_t stream,
+                                                  int scatter_sms)
 {
   try {
     if (wholememory_desc.storage_offset < 0 ||
@@ -148,7 +149,8 @@ wholememory_error_code_t wholememory_scatter_nccl(void* input,
                                            recv_indices_desc,
                                            local_fake_embedding_gref,
                                            wholememory_desc,
-                                           stream));
+                                           stream,
+                                           scatter_sms));
     WM_CUDA_CHECK(cudaGetLastError());
     WM_CUDA_CHECK(cudaStreamSynchronize(stream));
   } catch (wholememory::cuda_error& wce) {
@@ -173,7 +175,8 @@ wholememory_error_code_t wholememory_scatter_distributed(
   wholememory_handle_t wholememory_handle,
   wholememory_matrix_description_t wholememory_desc,
   wholememory_env_func_t* p_env_fns,
-  cudaStream_t stream)
+  cudaStream_t stream,
+  int scatter_sms)
 {
 #ifdef WITH_NVSHMEM_SUPPORT
   if (wholememory_get_distributed_backend(wholememory_handle) == WHOLEMEMORY_DB_NVSHMEM) {
@@ -184,7 +187,8 @@ wholememory_error_code_t wholememory_scatter_distributed(
                                         wholememory_handle,
                                         wholememory_desc,
                                         p_env_fns,
-                                        stream);
+                                        stream,
+                                        scatter_sms);
   }
 #endif
 
@@ -195,6 +199,7 @@ wholememory_error_code_t wholememory_scatter_distributed(
                                    wholememory_handle,
                                    wholememory_desc,
                                    p_env_fns,
-                                   stream);
+                                   stream,
+                                   scatter_sms);
 }
 } // namespace wholememory_ops
diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
index 263fbd62f..77d86ffdb 100644
--- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
+++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
@@ -607,7 +608,8 @@ cdef extern from "wholememory/embedding.h":
         wholememory_memory_type_t memory_type,
         wholememory_memory_location_t memory_location,
         wholememory_embedding_optimizer_t optimizer,
-        wholememory_embedding_cache_policy_t cache_policy)
+        wholememory_embedding_cache_policy_t cache_policy,
+        int user_defined_sms)
 
     cdef wholememory_error_code_t wholememory_destroy_embedding(
         wholememory_embedding_t wholememory_embedding)
@@ -770,7 +771,8 @@ cdef class PyWholeMemoryEmbedding:
                          WholeMemoryMemoryType memory_type,
                          WholeMemoryMemoryLocation memory_location,
                          WholeMemoryOptimizer optimizer,
-                         WholeMemoryCachePolicy cache_policy):
+                         WholeMemoryCachePolicy cache_policy,
+                         int user_defined_sms):
         self.memory_type = memory_type
         self.memory_location = memory_location
         check_wholememory_error_code(wholememory_create_embedding(&self.wm_embedding,
@@ -779,7 +781,8 @@ cdef class PyWholeMemoryEmbedding:
                                                                    self.memory_type,
                                                                    self.memory_location,
                                                                    optimizer.wm_optimizer,
-                                                                   cache_policy.cache_policy))
+                                                                   cache_policy.cache_policy,
+                                                                   user_defined_sms))
 
     def destroy_embedding(self):
         check_wholememory_error_code(wholememory_destroy_embedding(self.wm_embedding))
@@ -824,14 +827,16 @@ def create_embedding(PyWholeMemoryTensorDescription tensor_desc,
                      WholeMemoryMemoryType memory_type,
                      WholeMemoryMemoryLocation memory_location,
                      WholeMemoryOptimizer optimizer,
-                     WholeMemoryCachePolicy cache_policy):
+                     WholeMemoryCachePolicy cache_policy,
+                     int user_defined_sms):
     wm_embedding = PyWholeMemoryEmbedding()
     wm_embedding.create_embedding(tensor_desc,
                                   comm,
                                   memory_type,
                                   memory_location,
                                   optimizer,
-                                  cache_policy)
+                                  cache_policy,
+                                  user_defined_sms)
     return wm_embedding
 
 cpdef void EmbeddingGatherForward(PyWholeMemoryEmbedding wm_embedding,
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py
index ce59e7b54..47d3a9909 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py
@@ -406,6 +406,7 @@ def create_embedding(
     optimizer: Union[WholeMemoryOptimizer, None] = None,
     cache_policy: Union[WholeMemoryCachePolicy, None] = None,
     random_init: bool = False,
+    gather_sms: int = -1,
 ):
     r"""
     Create embedding
@@ -416,6 +417,7 @@ def create_embedding(
     :param sizes: size of the embedding, must be 2D
     :param optimizer: optimizer
     :param cache_policy: cache policy
+    :param gather_sms: the number of SMs used in gather process
     :return: WholeMemoryEmbedding
     """
     if optimizer is None:
@@ -446,6 +448,7 @@ def create_embedding(
             str_to_wmb_wholememory_location(memory_location),
             wmb_optimizer,
             wmb_cache_policy,
+            user_defined_sms=gather_sms,
         ),
         optimizer,
         cache_policy,
@@ -472,6 +475,7 @@ def create_embedding_from_filelist(
     *,
     optimizer: Union[WholeMemoryOptimizer, None] = None,
     cache_policy: Union[WholeMemoryCachePolicy, None] = None,
+    gather_sms: int = -1,
 ):
     r"""
     Create embedding from file list
@@ -483,6 +487,7 @@ def create_embedding_from_filelist(
     :param last_dim_size: size of last dim
     :param optimizer: optimizer
     :param cache_policy: cache policy
+    :param gather_sms: the number of SMs used in gather process
     :return:
     """
     if isinstance(filelist, str):
@@ -508,6 +513,7 @@ def create_embedding_from_filelist(
         [total_entry_count, last_dim_size],
         optimizer=optimizer,
         cache_policy=cache_policy,
+        gather_sms=gather_sms
     )
     wm_embedding.get_embedding_tensor().from_filelist(filelist)
     return wm_embedding
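The hunks above surface the new SM-count knob as `gather_sms` in the PyTorch-level API (and as `user_defined_sms` in the Cython binding). A minimal usage sketch, assuming the existing `create_embedding` positional arguments (communicator, memory type, memory location, dtype, sizes) and the usual `pylibwholegraph.torch` helpers; the communicator setup below is illustrative rather than taken from this patch:

```python
# Sketch only: passing the new gather_sms knob through the torch-level API.
# Assumes the WholeMemory communicator has already been initialized for this
# process (for example via the init helpers used in the pylibwholegraph examples).
import torch
import pylibwholegraph.torch as wgth

comm = wgth.get_global_communicator()  # assumed helper; substitute your own communicator setup

embedding = wgth.create_embedding(
    comm,
    "distributed",      # memory_type
    "cpu",              # memory_location
    torch.float32,      # dtype
    [1_000_000, 128],   # sizes: [entry_count, embedding_dim], must be 2D
    cache_policy=None,
    gather_sms=32,      # new knob: cap gather kernels at 32 SMs; -1 (the default) keeps the library's choice
)
```

Leaving `gather_sms` at its default of `-1` should preserve the previous launch behavior; the value is forwarded to the C++ layer as `user_defined_sms` through `wholememory_create_embedding`, as the binding hunks above show.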
From cecd3ff9e8e194c6a2cd842c56a9fc94e781c23c Mon Sep 17 00:00:00 2001
From: Ray Douglass
Date: Mon, 12 Feb 2024 15:45:24 -0500
Subject: [PATCH 16/16] Update Changelog [skip ci]

---
 CHANGELOG.md | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f6df02332..f721af24f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,28 @@
+# wholegraph 24.02.00 (12 Feb 2024)
+
+## 🐛 Bug Fixes
+
+- Revert "Exclude tests from builds ([#127](https://github.com/rapidsai/wholegraph/pull/127))" ([#130](https://github.com/rapidsai/wholegraph/pull/130)) [@raydouglass](https://github.com/raydouglass)
+- Exclude tests from builds ([#127](https://github.com/rapidsai/wholegraph/pull/127)) [@vyasr](https://github.com/vyasr)
+- fix a bug for embedding optimizer, which leads to undefined behavior ([#108](https://github.com/rapidsai/wholegraph/pull/108)) [@linhu-nv](https://github.com/linhu-nv)
+- fix inferencesample option ([#107](https://github.com/rapidsai/wholegraph/pull/107)) [@chuangz0](https://github.com/chuangz0)
+
+## 🚀 New Features
+
+- allow users to control gather/scatter sms ([#124](https://github.com/rapidsai/wholegraph/pull/124)) [@linhu-nv](https://github.com/linhu-nv)
+
+## 🛠️ Improvements
+
+- Logging level ([#123](https://github.com/rapidsai/wholegraph/pull/123)) [@linhu-nv](https://github.com/linhu-nv)
+- Fix pip dependencies ([#118](https://github.com/rapidsai/wholegraph/pull/118)) [@trxcllnt](https://github.com/trxcllnt)
+- Remove usages of rapids-env-update ([#117](https://github.com/rapidsai/wholegraph/pull/117)) [@KyleFromNVIDIA](https://github.com/KyleFromNVIDIA)
+- refactor CUDA versions in dependencies.yaml ([#115](https://github.com/rapidsai/wholegraph/pull/115)) [@jameslamb](https://github.com/jameslamb)
+- Don't overwrite wholegraph_ROOT if provided ([#114](https://github.com/rapidsai/wholegraph/pull/114)) [@vyasr](https://github.com/vyasr)
+- added Direct IO support for WholeMemory loading ([#113](https://github.com/rapidsai/wholegraph/pull/113)) [@dongxuy04](https://github.com/dongxuy04)
+- Align versions for cudnn, clang-tools, cython, and doxygen with the rest of RAPIDS. ([#112](https://github.com/rapidsai/wholegraph/pull/112)) [@bdice](https://github.com/bdice)
+- Reset WholeGraph communicators during the finalize call ([#111](https://github.com/rapidsai/wholegraph/pull/111)) [@chang-l](https://github.com/chang-l)
+- Forward-merge branch-23.12 to branch-24.02 ([#102](https://github.com/rapidsai/wholegraph/pull/102)) [@bdice](https://github.com/bdice)
+
 # wholegraph 23.12.00 (6 Dec 2023)
 
 ## 🐛 Bug Fixes