diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 47951783ba..9c22edf74c 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -80,7 +80,6 @@ jobs: with: build_type: pull-request enable_check_symbols: true - symbol_exclusions: raft_cutlass conda-python-build: needs: conda-cpp-build secrets: inherit diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2bee8a3d1d..92020f6a76 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -23,7 +23,6 @@ jobs: date: ${{ inputs.date }} sha: ${{ inputs.sha }} enable_check_symbols: true - symbol_exclusions: raft_cutlass conda-cpp-tests: secrets: inherit uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.12 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d8ccf92ce5..e3b3c8c440 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -93,7 +93,10 @@ repos: - id: codespell additional_dependencies: [tomli] args: ["--toml", "pyproject.toml"] - exclude: (?x)^(^CHANGELOG.md$) + exclude: | + (?x) + ^CHANGELOG[.]md$| + ^cpp/cmake/patches/cutlass/build-export[.]patch$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index a118becb4b..a70fed9ec8 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -35,7 +35,6 @@ function sed_runner() { sed -i.bak ''"$1"'' $2 && rm -f ${2}.bak } -sed_runner "s/set(RAPIDS_VERSION .*)/set(RAPIDS_VERSION \"${NEXT_SHORT_TAG}\")/g" cpp/template/cmake/thirdparty/fetch_rapids.cmake sed_runner 's/'"find_and_configure_ucxx(VERSION .*"'/'"find_and_configure_ucxx(VERSION ${NEXT_UCXX_SHORT_TAG_PEP440}"'/g' python/raft-dask/cmake/thirdparty/get_ucxx.cmake sed_runner 's/'"branch-.*"'/'"branch-${NEXT_UCXX_SHORT_TAG_PEP440}"'/g' python/raft-dask/cmake/thirdparty/get_ucxx.cmake diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4ed9529a36..780f6f8581 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -268,6 +268,10 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu src/raft_runtime/random/rmat_rectangular_generator_int_double.cu src/raft_runtime/random/rmat_rectangular_generator_int_float.cu + src/raft_runtime/solver/lanczos_solver_int64_double.cu + src/raft_runtime/solver/lanczos_solver_int64_float.cu + src/raft_runtime/solver/lanczos_solver_int_double.cu + src/raft_runtime/solver/lanczos_solver_int_float.cu ) set_target_properties( raft_objs diff --git a/cpp/cmake/patches/cutlass/build-export.patch b/cpp/cmake/patches/cutlass/build-export.patch new file mode 100644 index 0000000000..a6423e9c08 --- /dev/null +++ b/cpp/cmake/patches/cutlass/build-export.patch @@ -0,0 +1,27 @@ +From e0a9597946257a01ae8444200f836ee51d5597ba Mon Sep 17 00:00:00 2001 +From: Kyle Edwards +Date: Wed, 20 Nov 2024 16:37:38 -0500 +Subject: [PATCH] Remove erroneous include directories + +These directories are left over from when CuTe was a separate +CMake project. Remove them. +--- + CMakeLists.txt | 2 -- + 1 file changed, 2 deletions(-) + +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 7419bdf5e..545384d82 100755 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -665,8 +665,6 @@ target_include_directories( + $ + $ + $ +- $ +- $ + ) + + # Mark CTK headers as system to supress warnings from them +-- +2.34.1 + diff --git a/cpp/cmake/patches/cutlass_override.json b/cpp/cmake/patches/cutlass_override.json new file mode 100644 index 0000000000..7bf818987f --- /dev/null +++ b/cpp/cmake/patches/cutlass_override.json @@ -0,0 +1,16 @@ +{ + "packages" : { + "cutlass" : { + "version": "3.5.1", + "git_url": "https://github.com/NVIDIA/cutlass.git", + "git_tag": "v${version}", + "patches" : [ + { + "file" : "${current_json_dir}/cutlass/build-export.patch", + "issue" : "Fix build directory export", + "fixed_in" : "" + } + ] + } + } +} diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 0123c4b07a..d5bdd4632f 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -13,7 +13,9 @@ # ============================================================================= function(find_and_configure_cutlass) - set(oneValueArgs VERSION REPOSITORY PINNED_TAG) + set(options) + set(oneValueArgs) + set(multiValueArgs) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) # if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) @@ -34,13 +36,22 @@ function(find_and_configure_cutlass) set(CUDART_LIBRARY "${CUDA_cudart_static_LIBRARY}" CACHE FILEPATH "fixing cutlass cmake code" FORCE) endif() + include("${rapids-cmake-dir}/cpm/package_override.cmake") + rapids_cpm_package_override("${CMAKE_CURRENT_FUNCTION_LIST_DIR}/../patches/cutlass_override.json") + + include("${rapids-cmake-dir}/cpm/detail/package_details.cmake") + rapids_cpm_package_details(cutlass version repository tag shallow exclude) + + include("${rapids-cmake-dir}/cpm/detail/generate_patch_command.cmake") + rapids_cpm_generate_patch_command(cutlass ${version} patch_command) + rapids_cpm_find( - NvidiaCutlass ${PKG_VERSION} + NvidiaCutlass ${version} GLOBAL_TARGETS nvidia::cutlass::cutlass CPM_ARGS - GIT_REPOSITORY ${PKG_REPOSITORY} - GIT_TAG ${PKG_PINNED_TAG} - GIT_SHALLOW TRUE + GIT_REPOSITORY ${repository} + GIT_TAG ${tag} + GIT_SHALLOW ${shallow} ${patch_command} OPTIONS "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" ) @@ -79,14 +90,4 @@ function(find_and_configure_cutlass) ) endfunction() -if(NOT RAFT_CUTLASS_GIT_TAG) - set(RAFT_CUTLASS_GIT_TAG v2.10.0) -endif() - -if(NOT RAFT_CUTLASS_GIT_REPOSITORY) - set(RAFT_CUTLASS_GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git) -endif() - -find_and_configure_cutlass( - VERSION 2.10.0 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG} -) +find_and_configure_cutlass() diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index b4e9f8d1d7..c5241e831b 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -210,7 +210,7 @@ auto constexpr make_device_strided_matrix_view(ElementType* ptr, constexpr auto is_row_major = std::is_same_v; constexpr auto is_col_major = std::is_same_v; - assert(is_row_major || is_col_major); + static_assert(is_row_major || is_col_major, "Unsupported layout policy for strided matrix view"); IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows); diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h index 2b2c04b9d3..f6dea987e5 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -61,6 +61,7 @@ class PairwiseDistanceEpilogueElementwise { using ElementT = ElementT_; static int const kElementsPerAccess = ElementsPerAccess; static int const kCount = kElementsPerAccess; + static bool const kIsSingleSource = true; using DistanceOp = DistanceOp_; using FinalOp = FinalOp_; diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh deleted file mode 100644 index 263bbcea81..0000000000 --- a/cpp/include/raft/distance/fused_distance_nn-ext.cuh +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::KeyValuePair -#include // raft::resources -#include // include initialize and reduce operations -#include // RAFT_EXPLICIT - -#include // int64_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft { -namespace distance { - -template -void fusedDistanceNNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - bool isRowMajor, - raft::distance::DistanceType metric, - float metric_arg, - cudaStream_t stream) RAFT_EXPLICIT; - -} // namespace distance -} // namespace raft - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ - extern template void raft::distance::fusedDistanceNNMinReduce( \ - OutT * min, \ - const DataT* x, \ - const DataT* y, \ - const DataT* xn, \ - const DataT* yn, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - bool sqrt, \ - bool initOutBuffer, \ - bool isRowMajor, \ - raft::distance::DistanceType metric, \ - float metric_arg, \ - cudaStream_t stream) - -instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); - -// We can't have comma's in the macro expansion, so we use the COMMA macro: -#define COMMA , - -instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(float, - raft::KeyValuePair, - int64_t); - -#undef COMMA - -#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/include/raft/distance/fused_distance_nn.cuh b/cpp/include/raft/distance/fused_distance_nn.cuh index 04c42e49a1..25b1ae01ea 100755 --- a/cpp/include/raft/distance/fused_distance_nn.cuh +++ b/cpp/include/raft/distance/fused_distance_nn.cuh @@ -15,10 +15,4 @@ */ #pragma once -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "fused_distance_nn-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "fused_distance_nn-ext.cuh" -#endif +#include "fused_distance_nn-inl.cuh" \ No newline at end of file diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh deleted file mode 100644 index 4800f2e3cf..0000000000 --- a/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -// The explicit instantiation of raft::linalg::detail::coalescedReduction is not -// forced because there would be too many instances. Instead, we cover the most -// common instantiations with extern template instantiations below. - -#define instantiate_raft_linalg_detail_coalescedReduction( \ - InType, OutType, IdxType, MainLambda, ReduceLambda, FinalLambda) \ - extern template void raft::linalg::detail::coalescedReduction(OutType* dots, \ - const InType* data, \ - IdxType D, \ - IdxType N, \ - OutType init, \ - cudaStream_t stream, \ - bool inplace, \ - MainLambda main_op, \ - ReduceLambda reduce_op, \ - FinalLambda final_op) - -instantiate_raft_linalg_detail_coalescedReduction( - double, double, int, raft::identity_op, raft::min_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - double, double, int, raft::sq_op, raft::add_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - double, double, int, raft::sq_op, raft::add_op, raft::sqrt_op); -instantiate_raft_linalg_detail_coalescedReduction( - double, double, int, raft::abs_op, raft::add_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - double, double, int, raft::abs_op, raft::max_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, size_t, raft::abs_op, raft::add_op, raft::sqrt_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, int, raft::abs_op, raft::add_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, int, raft::identity_op, raft::add_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, int, raft::identity_op, raft::min_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, int, raft::sq_op, raft::add_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, int, raft::sq_op, raft::add_op, raft::sqrt_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, long, raft::sq_op, raft::add_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, size_t, raft::identity_op, raft::add_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, size_t, raft::sq_op, raft::add_op, raft::identity_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, size_t, raft::abs_op, raft::max_op, raft::sqrt_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, size_t, raft::sq_op, raft::add_op, raft::sqrt_op); -instantiate_raft_linalg_detail_coalescedReduction( - float, float, unsigned int, raft::sq_op, raft::add_op, raft::identity_op); - -#undef instantiate_raft_linalg_detail_coalescedReduction diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh index 3e6b17978b..d24c2a7444 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh @@ -16,11 +16,4 @@ #pragma once -// Always include inline definitions of coalesced reduction, because we do not -// force explicit instantion. #include "coalesced_reduction-inl.cuh" - -// Do include the extern template instantiations when possible. -#ifdef RAFT_COMPILED -#include "coalesced_reduction-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh deleted file mode 100644 index 35f4f0e1c9..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ /dev/null @@ -1,418 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include // none_cagra_sample_filter -#include // RAFT_EXPLICIT - -#include - -namespace raft::neighbors::cagra::detail { -namespace multi_cta_search { - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -template -void select_and_run( - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, - const uint32_t num_queries, - const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, - uint32_t* const num_executed_iterations, - uint32_t topk, - uint32_t block_size, - uint32_t result_buffer_size, - uint32_t smem_size, - int64_t hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, - uint32_t num_cta_per_query, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - raft::distance::DistanceType metric, - cudaStream_t stream) RAFT_EXPLICIT; -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void select_and_run< \ - TEAM_SIZE, \ - MAX_DATASET_DIM, \ - raft::neighbors::cagra::detail::standard_dataset_descriptor_t, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::standard_dataset_descriptor_t \ - dataset_desc, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_kernel_selection( - 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection - -#define instantiate_q_kernel_selection(TEAM_SIZE, \ - MAX_DATASET_DIM, \ - CODE_BOOK_T, \ - PQ_BITS, \ - PQ_CODE_BOOK_DIM, \ - DATA_T, \ - INDEX_T, \ - DISTANCE_T, \ - SAMPLE_FILTER_T) \ - extern template void \ - select_and_run, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t dataset_desc, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_q_kernel_selection( - 8, 128, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 16, 256, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 32, 512, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 2, - half, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 8, 128, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 16, 256, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 32, 512, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 4, - half, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_kernel_selection( - 8, 128, half, 8, 2, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 8, 128, half, 8, 4, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_kernel_selection(8, - 128, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(8, - 128, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_kernel_selection(8, - 128, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(8, - 128, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_q_kernel_selection -} // namespace multi_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh index e003907292..3dc0745e6d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh @@ -15,10 +15,4 @@ */ #pragma once -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "search_multi_cta_kernel-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "search_multi_cta_kernel-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh deleted file mode 100644 index 510219ab5d..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ /dev/null @@ -1,602 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include // RAFT_EXPLICIT - -#include - -namespace raft::neighbors::cagra::detail { -namespace single_cta_search { - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -template -void select_and_run( // raft::resources const& res, - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - uint32_t num_itopk_candidates, - uint32_t block_size, - uint32_t smem_size, - int64_t hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - raft::distance::DistanceType metric, - cudaStream_t stream) RAFT_EXPLICIT; - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void select_and_run< \ - TEAM_SIZE, \ - MAX_DATASET_DIM, \ - raft::neighbors::cagra::detail::standard_dataset_descriptor_t, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::standard_dataset_descriptor_t \ - dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_single_cta_select_and_run( - 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_select_and_run - -#define instantiate_q_single_cta_select_and_run(TEAM_SIZE, \ - MAX_DATASET_DIM, \ - CODE_BOOK_T, \ - PQ_BITS, \ - PQ_CODE_BOOK_DIM, \ - DATA_T, \ - INDEX_T, \ - DISTANCE_T, \ - SAMPLE_FILTER_T) \ - extern template void \ - select_and_run, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - half, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - half, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 1024, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 1024, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 2, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 2, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - float, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 4, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 4, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - float, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 2, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 4, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, int8_t, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, int8_t, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_q_single_cta_select_and_run - -} // namespace single_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh index 1d8fd8e30a..3e72fbf8e8 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh @@ -15,10 +15,4 @@ */ #pragma once -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "search_single_cta_kernel-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "search_single_cta_kernel-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh deleted file mode 100644 index 140a9f17c8..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT - -#include // rmm:cuda_stream_view - -#include - -#include // uintX_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::ivf_flat::detail { - -auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool; - -template -void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, - const T* queries, - const uint32_t* coarse_query_results, - const uint32_t n_queries, - const uint32_t queries_offset, - const raft::distance::DistanceType metric, - const uint32_t n_probes, - const uint32_t k, - const uint32_t max_samples, - const uint32_t* chunk_indices, - const bool select_min, - IvfSampleFilterT sample_filter, - uint32_t* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) RAFT_EXPLICIT; - -} // namespace raft::neighbors::ivf_flat::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ - T, AccT, IdxT, IvfSampleFilterT) \ - extern template void \ - raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const uint32_t queries_offset, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const uint32_t max_samples, \ - const uint32_t* chunk_indices, \ - const bool select_min, \ - IvfSampleFilterT sample_filter, \ - uint32_t* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream) - -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - half, half, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); - -#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh index 63f341dd9a..11d7da851a 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh @@ -16,10 +16,4 @@ #pragma once -#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) #include "ivf_flat_interleaved_scan-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_flat_interleaved_scan-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh deleted file mode 100644 index c14b0e810f..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT - -#include - -#include - -#include // uintX_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::ivf_flat::detail { - -template -void search(raft::resources const& handle, - const search_params& params, - const raft::neighbors::ivf_flat::index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::device_async_resource_ref mr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; - -} // namespace raft::neighbors::ivf_flat::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ - extern template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::device_async_resource_ref mr, \ - IvfSampleFilterT sample_filter) - -instantiate_raft_neighbors_ivf_flat_detail_search( - float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_search( - half, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_search( - int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_search( - uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); - -#undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index 7b03ebeab6..56e58bac27 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -15,10 +15,4 @@ */ #pragma once -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "ivf_flat_search-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_flat_search-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh deleted file mode 100644 index 5e1a9b46d6..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // RAFT_WEAK_FUNCTION -#include // raft::distance::DistanceType -#include // raft::neighbors::ivf_pq::detail::fp_8bit -#include // raft::neighbors::ivf_pq::codebook_gen -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT - -#include // rmm::cuda_stream_view - -#include // __half - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::ivf_pq::detail { - -// is_local_topk_feasible is not inline here, because we would have to define it -// here as well. That would run the risk of the definitions here and in the -// -inl.cuh header diverging. -auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) - -> bool; - -template -RAFT_KERNEL compute_similarity_kernel(uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) RAFT_EXPLICIT; - -// The signature of the kernel defined by a minimal set of template parameters -template -using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); - -template -struct selected { - compute_similarity_kernel_t kernel; - dim3 grid_dim; - dim3 block_dim; - size_t smem_size; - size_t device_lut_size; -}; - -template -void compute_similarity_run(selected s, - rmm::cuda_stream_view stream, - uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) RAFT_EXPLICIT; - -/** - * Use heuristics to choose an optimal instance of the search kernel. - * It selects among a few kernel variants (with/out using shared mem for - * lookup tables / precomputed distances) and tries to choose the block size - * to maximize kernel occupancy. - * - * @param manage_local_topk - * whether use the fused calculate+select or just calculate the distances for each - * query and probed cluster. - * - * @param locality_hint - * beyond this limit do not consider increasing the number of active blocks per SM - * would improve locality anymore. - */ -template -auto compute_similarity_select(const cudaDeviceProp& dev_props, - bool manage_local_topk, - int locality_hint, - double preferred_shmem_carveout, - uint32_t pq_bits, - uint32_t pq_dim, - uint32_t precomp_data_count, - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) - -> selected RAFT_EXPLICIT; - -} // namespace raft::neighbors::ivf_pq::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - extern template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - extern template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, - raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, - raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, - half, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - half, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - float, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh index d987c0d4ed..467d389d38 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh @@ -16,10 +16,4 @@ #pragma once -#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) #include "ivf_pq_compute_similarity-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_pq_compute_similarity-ext.cuh" -#endif diff --git a/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp index c10c0de426..97ac7c45f4 100644 --- a/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp +++ b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp @@ -30,7 +30,25 @@ namespace linalg { namespace detail { /** - * @brief create a cuSparse dense descriptor + * @brief create a cuSparse dense descriptor for a vector + * @tparam ValueType Data type of vector_view (float/double) + * @tparam IndexType Type of vector_view + * @param[in] vector_view input raft::device_vector_view + * @returns dense vector descriptor to be used by cuSparse API + */ +template +cusparseDnVecDescr_t create_descriptor(raft::device_vector_view vector_view) +{ + cusparseDnVecDescr_t descr; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednvec( + &descr, + vector_view.extent(0), + const_cast*>(vector_view.data_handle()))); + return descr; +} + +/** + * @brief create a cuSparse dense descriptor for a matrix * @tparam ValueType Data type of dense_view (float/double) * @tparam IndexType Type of dense_view * @tparam LayoutPolicy layout of dense_view diff --git a/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh b/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh deleted file mode 100644 index 01625a0ce8..0000000000 --- a/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include // RAFT_EXPLICIT - -#include // __half - -#include // uint32_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::sparse::matrix::detail { - -template -void select_k(raft::resources const& handle, - raft::device_csr_matrix_view in_val, - std::optional> in_idx, - raft::device_matrix_view out_val, - raft::device_matrix_view out_idx, - bool select_min, - bool sorted = false, - raft::matrix::SelectAlgo algo = raft::matrix::SelectAlgo::kAuto) RAFT_EXPLICIT; -} // namespace raft::sparse::matrix::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ - extern template void raft::sparse::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_sparse_matrix_detail_select_k(__half, uint32_t); -instantiate_raft_sparse_matrix_detail_select_k(__half, int64_t); -instantiate_raft_sparse_matrix_detail_select_k(float, int64_t); -instantiate_raft_sparse_matrix_detail_select_k(float, uint32_t); -instantiate_raft_sparse_matrix_detail_select_k(float, int); -instantiate_raft_sparse_matrix_detail_select_k(double, int64_t); -instantiate_raft_sparse_matrix_detail_select_k(double, uint32_t); - -#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/include/raft/sparse/matrix/detail/select_k.cuh b/cpp/include/raft/sparse/matrix/detail/select_k.cuh index 5d52b94b2f..31a4b54a94 100644 --- a/cpp/include/raft/sparse/matrix/detail/select_k.cuh +++ b/cpp/include/raft/sparse/matrix/detail/select_k.cuh @@ -15,11 +15,4 @@ */ #pragma once -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "select_k-inl.cuh" - -#endif - -#ifdef RAFT_COMPILED -#include "select_k-ext.cuh" -#endif diff --git a/cpp/include/raft/sparse/solver/detail/lanczos.cuh b/cpp/include/raft/sparse/solver/detail/lanczos.cuh index 9ecb4b729f..02a77a0d99 100644 --- a/cpp/include/raft/sparse/solver/detail/lanczos.cuh +++ b/cpp/include/raft/sparse/solver/detail/lanczos.cuh @@ -19,10 +19,43 @@ // for cmath: #define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include #include #include #include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include #include @@ -30,9 +63,17 @@ #include +#include #include +#include +#include +#include #include +#include +#include +#include +#include #include namespace raft::sparse::solver::detail { @@ -1396,4 +1437,674 @@ int computeLargestEigenvectors( return status; } +template +RAFT_KERNEL kernel_triangular_populate(T* M, const T* beta, int n) +{ + int row = blockIdx.x * blockDim.x + threadIdx.x; + + if (row < n) { + // Upper diagonal: M[row + 1, row] in column-major + if (row < n - 1) { M[(row + 1) * n + row] = beta[row]; } + + // Lower diagonal: M[row - 1, row] in column-major + if (row > 0) { M[(row - 1) * n + row] = beta[row - 1]; } + } +} + +template +RAFT_KERNEL kernel_triangular_beta_k(T* t, const T* beta_k, int k, int n) +{ + int tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (tid < k) { + // Update the k-th column: t[i, k] -> t[k * n + i] in column-major + t[tid * n + k] = beta_k[tid]; + + // Update the k-th row: t[k, j] -> t[j * n + k] in column-major + t[k * n + tid] = beta_k[tid]; + } +} + +template +RAFT_KERNEL kernel_normalize(const T* u, const T* beta, int j, int n, T* v, T* V, int size) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < size) { + if (beta[j] == 0) { + v[i] = u[i] / 1; + } else { + v[i] = u[i] / beta[j]; + } + V[i + (j + 1) * n] = v[i]; + } +} + +template +RAFT_KERNEL kernel_clamp_down(T* value, T threshold) +{ + *value = (fabs(*value) < threshold) ? 0 : *value; +} + +template +RAFT_KERNEL kernel_clamp_down_vector(T* vec, T threshold, int size) +{ + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < size) { vec[idx] = (fabs(vec[idx]) < threshold) ? 0 : vec[idx]; } +} + +template +void lanczos_solve_ritz( + raft::resources const& handle, + raft::device_matrix_view alpha, + raft::device_matrix_view beta, + std::optional> beta_k, + IndexTypeT k, + int which, + int ncv, + raft::device_matrix_view eigenvectors, + raft::device_vector_view eigenvalues) +{ + auto stream = resource::get_cuda_stream(handle); + + ValueTypeT zero = 0; + auto triangular_matrix = + raft::make_device_matrix(handle, ncv, ncv); + raft::matrix::fill(handle, triangular_matrix.view(), zero); + + raft::device_vector_view alphaVec = + raft::make_device_vector_view(alpha.data_handle(), ncv); + raft::matrix::set_diagonal(handle, alphaVec, triangular_matrix.view()); + + // raft::matrix::initializeDiagonalMatrix( + // alpha.data_handle(), triangular_matrix.data_handle(), ncv, ncv, stream); + + int blockSize = 256; + int numBlocks = (ncv + blockSize - 1) / blockSize; + kernel_triangular_populate + <<>>(triangular_matrix.data_handle(), beta.data_handle(), ncv); + + if (beta_k) { + int threadsPerBlock = 256; + int blocksPerGrid = (k + threadsPerBlock - 1) / threadsPerBlock; + kernel_triangular_beta_k<<>>( + triangular_matrix.data_handle(), beta_k.value().data_handle(), (int)k, ncv); + } + + auto triangular_matrix_view = + raft::make_device_matrix_view( + triangular_matrix.data_handle(), ncv, ncv); + + raft::linalg::eig_dc(handle, triangular_matrix_view, eigenvectors, eigenvalues); +} + +template +void lanczos_aux(raft::resources const& handle, + raft::device_csr_matrix_view A, + raft::device_matrix_view V, + raft::device_matrix_view u, + raft::device_matrix_view alpha, + raft::device_matrix_view beta, + int start_idx, + int end_idx, + int ncv, + raft::device_matrix_view v, + raft::device_matrix_view uu, + raft::device_matrix_view vv) +{ + auto stream = resource::get_cuda_stream(handle); + + IndexTypeT n = A.structure_view().get_n_rows(); + auto v_vector = raft::make_device_vector_view(v.data_handle(), n); + auto u_vector = raft::make_device_vector_view(u.data_handle(), n); + + raft::copy( + v.data_handle(), V.data_handle() + start_idx * V.stride(0), n, stream); // V(start_idx, 0) + + auto cusparse_h = resource::get_cusparse_handle(handle); + cusparseSpMatDescr_t cusparse_A = raft::sparse::linalg::detail::create_descriptor(A); + + cusparseDnVecDescr_t cusparse_v = raft::sparse::linalg::detail::create_descriptor(v_vector); + cusparseDnVecDescr_t cusparse_u = raft::sparse::linalg::detail::create_descriptor(u_vector); + + ValueTypeT one = 1; + ValueTypeT zero = 0; + size_t bufferSize; + raft::sparse::detail::cusparsespmv_buffersize(cusparse_h, + CUSPARSE_OPERATION_NON_TRANSPOSE, + &one, + cusparse_A, + cusparse_v, + &zero, + cusparse_u, + CUSPARSE_SPMV_ALG_DEFAULT, + &bufferSize, + stream); + auto cusparse_spmv_buffer = raft::make_device_vector(handle, bufferSize); + + for (int i = start_idx; i < end_idx; i++) { + raft::sparse::detail::cusparsespmv(cusparse_h, + CUSPARSE_OPERATION_NON_TRANSPOSE, + &one, + cusparse_A, + cusparse_v, + &zero, + cusparse_u, + CUSPARSE_SPMV_ALG_DEFAULT, + cusparse_spmv_buffer.data_handle(), + stream); + + auto alpha_i = + raft::make_device_scalar_view(alpha.data_handle() + i * alpha.stride(1)); // alpha(0, i) + raft::linalg::dot(handle, v_vector, u_vector, alpha_i); + + raft::matrix::fill(handle, vv, zero); + + auto cublas_h = resource::get_cublas_handle(handle); + + ValueTypeT alpha_i_host = 0; + ValueTypeT b = 0; + ValueTypeT mone = -1; + + raft::copy( + &b, beta.data_handle() + ((i - 1 + ncv) % ncv) * beta.stride(1), 1, stream); + raft::copy( + &alpha_i_host, alpha.data_handle() + i * alpha.stride(1), 1, stream); // alpha(0, i) + + raft::linalg::axpy(handle, n, &alpha_i_host, v.data_handle(), 1, vv.data_handle(), 1, stream); + raft::linalg::axpy(handle, + n, + &b, + V.data_handle() + (((i - 1 + ncv) % ncv) * V.stride(0)), + 1, + vv.data_handle(), + 1, + stream); + raft::linalg::axpy(handle, n, &mone, vv.data_handle(), 1, u.data_handle(), 1, stream); + + raft::linalg::gemv(handle, + CUBLAS_OP_T, + n, + i + 1, + &one, + V.data_handle(), + n, + u.data_handle(), + 1, + &zero, + uu.data_handle(), + 1, + stream); + + raft::linalg::gemv(handle, + CUBLAS_OP_N, + n, + i + 1, + &mone, + V.data_handle(), + n, + uu.data_handle(), + 1, + &one, + u.data_handle(), + 1, + stream); + + auto uu_i = raft::make_device_scalar_view(uu.data_handle() + uu.stride(1) * i); // uu(0, i) + raft::linalg::add(handle, make_const_mdspan(alpha_i), make_const_mdspan(uu_i), alpha_i); + + kernel_clamp_down<<<1, 1, 0, stream>>>(alpha_i.data_handle(), static_cast(1e-9)); + + auto output = raft::make_device_vector_view( + beta.data_handle() + beta.stride(1) * i, 1); + auto input = raft::make_device_matrix_view(u.data_handle(), 1, n); + raft::linalg::norm(handle, + input, + output, + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op()); + + int blockSize = 256; + int numBlocks = (n + blockSize - 1) / blockSize; + + kernel_clamp_down_vector<<>>( + u.data_handle(), static_cast(1e-7), n); + + kernel_clamp_down<<<1, 1, 0, stream>>>(beta.data_handle() + beta.stride(1) * i, + static_cast(1e-6)); + + if (i >= end_idx - 1) { break; } + + int threadsPerBlock = 256; + int blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock; + + kernel_normalize<<>>( + u.data_handle(), beta.data_handle(), i, n, v.data_handle(), V.data_handle(), n); + } +} + +template +auto lanczos_smallest( + raft::resources const& handle, + raft::device_csr_matrix_view A, + int nEigVecs, + int maxIter, + int restartIter, + ValueTypeT tol, + ValueTypeT* eigVals_dev, + ValueTypeT* eigVecs_dev, + ValueTypeT* v0, + uint64_t seed) -> int +{ + int n = A.structure_view().get_n_rows(); + int ncv = restartIter; + auto stream = resource::get_cuda_stream(handle); + + auto V = raft::make_device_matrix(handle, ncv, n); + auto V_0_view = + raft::make_device_matrix_view(V.data_handle(), 1, n); // First Row V[0] + auto v0_view = raft::make_device_matrix_view(v0, 1, n); + + auto u = raft::make_device_matrix(handle, 1, n); + auto u_vector = raft::make_device_vector_view(u.data_handle(), n); + raft::copy(u.data_handle(), v0, n, stream); + + auto cublas_h = resource::get_cublas_handle(handle); + auto v0nrm = raft::make_device_vector(handle, 1); + raft::linalg::norm(handle, + v0_view, + v0nrm.view(), + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op()); + + auto v0_vector_const = raft::make_device_vector_view(v0, n); + + raft::linalg::unary_op( + handle, v0_vector_const, V_0_view, [device_scalar = v0nrm.data_handle()] __device__(auto y) { + return y / *device_scalar; + }); + + auto alpha = raft::make_device_matrix(handle, 1, ncv); + auto beta = raft::make_device_matrix(handle, 1, ncv); + ValueTypeT zero = 0; + raft::matrix::fill(handle, alpha.view(), zero); + raft::matrix::fill(handle, beta.view(), zero); + + auto v = raft::make_device_matrix(handle, 1, n); + auto aux_uu = raft::make_device_matrix(handle, 1, ncv); + auto vv = raft::make_device_matrix(handle, 1, n); + + lanczos_aux(handle, + A, + V.view(), + u.view(), + alpha.view(), + beta.view(), + 0, + ncv, + ncv, + v.view(), + aux_uu.view(), + vv.view()); + + auto eigenvectors = + raft::make_device_matrix(handle, ncv, ncv); + auto eigenvalues = raft::make_device_vector(handle, ncv); + + lanczos_solve_ritz(handle, + alpha.view(), + beta.view(), + std::nullopt, + nEigVecs, + 0, + ncv, + eigenvectors.view(), + eigenvalues.view()); + + auto eigenvectors_k = raft::make_device_matrix_view( + eigenvectors.data_handle(), ncv, nEigVecs); + auto eigenvalues_k = + raft::make_device_vector_view(eigenvalues.data_handle(), nEigVecs); + + auto ritz_eigenvectors = + raft::make_device_matrix_view(eigVecs_dev, n, nEigVecs); + + auto V_T = + raft::make_device_matrix_view(V.data_handle(), n, ncv); + raft::linalg::gemm( + handle, V_T, eigenvectors_k, ritz_eigenvectors); + + auto s = raft::make_device_vector(handle, nEigVecs); + + auto eigenvectors_k_slice = + raft::make_device_matrix_view( + eigenvectors.data_handle(), ncv, nEigVecs); + auto S_matrix = raft::make_device_matrix_view( + s.data_handle(), 1, nEigVecs); + + raft::matrix::slice_coordinates coords(ncv - 1, 0, ncv, nEigVecs); + raft::matrix::slice(handle, make_const_mdspan(eigenvectors_k_slice), S_matrix, coords); + + auto beta_k = raft::make_device_vector(handle, nEigVecs); + raft::matrix::fill(handle, beta_k.view(), zero); + auto beta_scalar = raft::make_device_scalar_view(beta.data_handle() + + (ncv - 1) * beta.stride(1)); + + raft::linalg::axpy(handle, beta_scalar, raft::make_const_mdspan(s.view()), beta_k.view()); + + ValueTypeT res = 0; + + raft::device_vector output = + raft::make_device_vector(handle, 1); + raft::device_matrix_view input = + raft::make_device_matrix_view(beta_k.data_handle(), 1, nEigVecs); + raft::linalg::norm(handle, + input, + output.view(), + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op()); + raft::copy(&res, output.data_handle(), 1, stream); + resource::sync_stream(handle, stream); + + auto uu = raft::make_device_matrix(handle, 0, nEigVecs); + int iter = ncv; + while (res > tol && iter < maxIter) { + auto beta_view = raft::make_device_matrix_view( + beta.data_handle(), 1, nEigVecs); + raft::matrix::fill(handle, beta_view, zero); + + raft::copy(alpha.data_handle(), eigenvalues_k.data_handle(), nEigVecs, stream); + + auto x_T = + raft::make_device_matrix_view(ritz_eigenvectors.data_handle(), nEigVecs, n); + + raft::copy(V.data_handle(), x_T.data_handle(), nEigVecs * n, stream); + + ValueTypeT one = 1; + ValueTypeT mone = -1; + + // Using raft::linalg::gemv leads to Reason=7:CUBLAS_STATUS_INVALID_VALUE (issue raft#2484) + raft::linalg::detail::cublasgemv(cublas_h, + CUBLAS_OP_T, + nEigVecs, + n, + &one, + V.data_handle(), + nEigVecs, + u.data_handle(), + 1, + &zero, + uu.data_handle(), + 1, + stream); + + raft::linalg::detail::cublasgemv(cublas_h, + CUBLAS_OP_N, + nEigVecs, + n, + &mone, + V.data_handle(), + nEigVecs, + uu.data_handle(), + 1, + &one, + u.data_handle(), + 1, + stream); + + auto V_0_view = + raft::make_device_matrix_view(V.data_handle() + (nEigVecs * n), 1, n); + auto V_0_view_vector = + raft::make_device_vector_view(V_0_view.data_handle(), n); + auto unrm = raft::make_device_vector(handle, 1); + raft::linalg::norm(handle, + raft::make_const_mdspan(u.view()), + unrm.view(), + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op()); + + raft::linalg::unary_op( + handle, + raft::make_const_mdspan(u_vector), + V_0_view, + [device_scalar = unrm.data_handle()] __device__(auto y) { return y / *device_scalar; }); + + auto cusparse_h = resource::get_cusparse_handle(handle); + cusparseSpMatDescr_t cusparse_A = raft::sparse::linalg::detail::create_descriptor(A); + + cusparseDnVecDescr_t cusparse_v = + raft::sparse::linalg::detail::create_descriptor(V_0_view_vector); + cusparseDnVecDescr_t cusparse_u = raft::sparse::linalg::detail::create_descriptor(u_vector); + + ValueTypeT zero = 0; + size_t bufferSize; + raft::sparse::detail::cusparsespmv_buffersize(cusparse_h, + CUSPARSE_OPERATION_NON_TRANSPOSE, + &one, + cusparse_A, + cusparse_v, + &zero, + cusparse_u, + CUSPARSE_SPMV_ALG_DEFAULT, + &bufferSize, + stream); + auto cusparse_spmv_buffer = raft::make_device_vector(handle, bufferSize); + + raft::sparse::detail::cusparsespmv(cusparse_h, + CUSPARSE_OPERATION_NON_TRANSPOSE, + &one, + cusparse_A, + cusparse_v, + &zero, + cusparse_u, + CUSPARSE_SPMV_ALG_DEFAULT, + cusparse_spmv_buffer.data_handle(), + stream); + + auto alpha_k = raft::make_device_scalar_view(alpha.data_handle() + nEigVecs); + + raft::linalg::dot( + handle, make_const_mdspan(V_0_view_vector), make_const_mdspan(u_vector), alpha_k); + + raft::linalg::binary_op(handle, + make_const_mdspan(u_vector), + make_const_mdspan(V_0_view_vector), + u_vector, + [device_scalar_ptr = alpha_k.data_handle()] __device__( + ValueTypeT u_element, ValueTypeT V_0_element) { + return u_element - (*device_scalar_ptr) * V_0_element; + }); + + auto temp = raft::make_device_vector(handle, n); + + auto V_k = raft::make_device_matrix_view( + V.data_handle(), nEigVecs, n); + auto V_k_T = + raft::make_device_matrix(handle, n, nEigVecs); + + raft::linalg::transpose(handle, V_k, V_k_T.view()); + + ValueTypeT three = 3; + ValueTypeT two = 2; + + std::vector M = {1, 2, 3, 4, 5, 6}; + std::vector vec = {1, 1}; + + auto M_dev = raft::make_device_matrix(handle, 2, 3); + auto vec_dev = raft::make_device_vector(handle, 2); + auto out = raft::make_device_vector(handle, 3); + raft::copy(M_dev.data_handle(), M.data(), 6, stream); + raft::copy(vec_dev.data_handle(), vec.data(), 2, stream); + + raft::linalg::gemv(handle, + CUBLAS_OP_N, + three, + two, + &one, + M_dev.data_handle(), + three, + vec_dev.data_handle(), + 1, + &zero, + out.data_handle(), + 1, + stream); + + raft::linalg::gemv(handle, + CUBLAS_OP_N, + n, + nEigVecs, + &one, + V_k.data_handle(), + n, + beta_k.data_handle(), + 1, + &zero, + temp.data_handle(), + 1, + stream); + + auto one_scalar = raft::make_device_scalar(handle, 1); + raft::linalg::binary_op(handle, + make_const_mdspan(u_vector), + make_const_mdspan(temp.view()), + u_vector, + [device_scalar_ptr = one_scalar.data_handle()] __device__( + ValueTypeT u_element, ValueTypeT temp_element) { + return u_element - (*device_scalar_ptr) * temp_element; + }); + + auto output1 = raft::make_device_vector_view( + beta.data_handle() + beta.stride(1) * nEigVecs, 1); + raft::linalg::norm(handle, + raft::make_const_mdspan(u.view()), + output1, + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op()); + + auto V_kplus1 = + raft::make_device_vector_view(V.data_handle() + V.stride(0) * (nEigVecs + 1), n); + + raft::linalg::unary_op( + handle, + make_const_mdspan(u_vector), + V_kplus1, + [device_scalar = (beta.data_handle() + beta.stride(1) * nEigVecs)] __device__(auto y) { + return y / *device_scalar; + }); + + lanczos_aux(handle, + A, + V.view(), + u.view(), + alpha.view(), + beta.view(), + nEigVecs + 1, + ncv, + ncv, + v.view(), + aux_uu.view(), + vv.view()); + iter += ncv - nEigVecs; + lanczos_solve_ritz(handle, + alpha.view(), + beta.view(), + beta_k.view(), + nEigVecs, + 0, + ncv, + eigenvectors.view(), + eigenvalues.view()); + auto eigenvectors_k = raft::make_device_matrix_view( + eigenvectors.data_handle(), ncv, nEigVecs); + + auto ritz_eigenvectors = raft::make_device_matrix_view( + eigVecs_dev, n, nEigVecs); + + auto V_T = + raft::make_device_matrix_view(V.data_handle(), n, ncv); + raft::linalg::gemm( + handle, V_T, eigenvectors_k, ritz_eigenvectors); + + auto eigenvectors_k_slice = + raft::make_device_matrix_view( + eigenvectors.data_handle(), ncv, nEigVecs); + auto S_matrix = raft::make_device_matrix_view( + s.data_handle(), 1, nEigVecs); + + raft::matrix::slice_coordinates coords(ncv - 1, 0, ncv, nEigVecs); + raft::matrix::slice(handle, make_const_mdspan(eigenvectors_k_slice), S_matrix, coords); + + raft::matrix::fill(handle, beta_k.view(), zero); + + auto beta_scalar = raft::make_device_scalar_view( + beta.data_handle() + beta.stride(1) * (ncv - 1)); // &((beta.view())(0, ncv - 1)) + + raft::linalg::axpy(handle, beta_scalar, raft::make_const_mdspan(s.view()), beta_k.view()); + + raft::device_vector output2 = + raft::make_device_vector(handle, 1); + raft::device_matrix_view input2 = + raft::make_device_matrix_view(beta_k.data_handle(), 1, nEigVecs); + raft::linalg::norm(handle, + input2, + output2.view(), + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op()); + raft::copy(&res, output2.data_handle(), 1, stream); + resource::sync_stream(handle, stream); + RAFT_LOG_TRACE("Iteration %f: residual (tolerance) %d", iter, res); + } + + raft::copy(eigVals_dev, eigenvalues_k.data_handle(), nEigVecs, stream); + raft::copy(eigVecs_dev, ritz_eigenvectors.data_handle(), n * nEigVecs, stream); + + return 0; +} + +template +auto lanczos_compute_smallest_eigenvectors( + raft::resources const& handle, + lanczos_solver_config const& config, + raft::device_csr_matrix_view A, + std::optional> v0, + raft::device_vector_view eigenvalues, + raft::device_matrix_view eigenvectors) -> int +{ + if (v0.has_value()) { + return lanczos_smallest(handle, + A, + config.n_components, + config.max_iterations, + config.ncv, + config.tolerance, + eigenvalues.data_handle(), + eigenvectors.data_handle(), + v0->data_handle(), + config.seed); + } else { + // Handle the optional v0 initial Lanczos vector if nullopt is used + auto n = A.structure_view().get_n_rows(); + auto temp_v0 = raft::make_device_vector(handle, n); + raft::random::RngState rng_state(config.seed); + raft::random::uniform(handle, rng_state, temp_v0.view(), ValueTypeT{0.0}, ValueTypeT{1.0}); + return lanczos_smallest(handle, + A, + config.n_components, + config.max_iterations, + config.ncv, + config.tolerance, + eigenvalues.data_handle(), + eigenvectors.data_handle(), + temp_v0.data_handle(), + config.seed); + } +} + } // namespace raft::sparse::solver::detail diff --git a/cpp/include/raft/sparse/solver/lanczos.cuh b/cpp/include/raft/sparse/solver/lanczos.cuh index 1aa56d6ba2..fed31e6a9c 100644 --- a/cpp/include/raft/sparse/solver/lanczos.cuh +++ b/cpp/include/raft/sparse/solver/lanczos.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #pragma once #include +#include #include namespace raft::sparse::solver { @@ -27,6 +28,78 @@ namespace raft::sparse::solver { // Eigensolver // ========================================================= +/** + * @brief Find the smallest eigenpairs using lanczos solver + * @tparam index_type_t the type of data used for indexing. + * @tparam value_type_t the type of data used for weights, distances. + * @param handle the raft handle. + * @param config lanczos config used to set hyperparameters + * @param A Sparse matrix in CSR format. + * @param v0 Optional Initial lanczos vector + * @param eigenvalues output eigenvalues + * @param eigenvectors output eigenvectors + * @todo Add largest eigenvalues computation (issue #2483) + * @return Zero if successful. Otherwise non-zero. + */ +template +auto lanczos_compute_smallest_eigenvectors( + raft::resources const& handle, + lanczos_solver_config const& config, + raft::device_csr_matrix_view A, + std::optional> v0, + raft::device_vector_view eigenvalues, + raft::device_matrix_view eigenvectors) -> int +{ + return detail::lanczos_compute_smallest_eigenvectors( + handle, config, A, v0, eigenvalues, eigenvectors); +} + +/** + * @brief Find the smallest eigenpairs using lanczos solver + * @tparam index_type_t the type of data used for indexing. + * @tparam value_type_t the type of data used for weights, distances. + * @param handle the raft handle. + * @param config lanczos config used to set hyperparameters + * @param rows Vector view of the rows of the sparse matrix. + * @param cols Vector view of the cols of the sparse matrix. + * @param vals Vector view of the vals of the sparse matrix. + * @param v0 Optional Initial lanczos vector + * @param eigenvalues output eigenvalues + * @param eigenvectors output eigenvectors + * @todo Add largest eigenvalues computation (issue #2483) + * @return Zero if successful. Otherwise non-zero. + */ +template +auto lanczos_compute_smallest_eigenvectors( + raft::resources const& handle, + lanczos_solver_config const& config, + raft::device_vector_view rows, + raft::device_vector_view cols, + raft::device_vector_view vals, + std::optional> v0, + raft::device_vector_view eigenvalues, + raft::device_matrix_view eigenvectors) -> int +{ + IndexTypeT ncols = rows.extent(0) - 1; + IndexTypeT nrows = rows.extent(0) - 1; + IndexTypeT nnz = cols.extent(0); + + auto csr_structure = + raft::make_device_compressed_structure_view( + const_cast(rows.data_handle()), + const_cast(cols.data_handle()), + ncols, + nrows, + nnz); + + auto csr_matrix = + raft::make_device_csr_matrix_view( + const_cast(vals.data_handle()), csr_structure); + + return lanczos_compute_smallest_eigenvectors( + handle, config, csr_matrix, v0, eigenvalues, eigenvectors); +} + /** * @brief Compute smallest eigenvectors of symmetric matrix * Computes eigenvalues and eigenvectors that are least diff --git a/cpp/include/raft/sparse/solver/lanczos_types.hpp b/cpp/include/raft/sparse/solver/lanczos_types.hpp new file mode 100644 index 0000000000..edd5548079 --- /dev/null +++ b/cpp/include/raft/sparse/solver/lanczos_types.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::sparse::solver { + +template +struct lanczos_solver_config { + /** The number of eigenvalues and eigenvectors to compute. Must be 1 <= k < n.*/ + int n_components; + /** Maximum number of iteration. */ + int max_iterations; + /** The number of Lanczos vectors generated. Must be k + 1 < ncv < n. */ + int ncv; + /** Tolerance for residuals ``||Ax - wx||`` */ + ValueTypeT tolerance; + /** random seed */ + uint64_t seed; +}; + +} // namespace raft::sparse::solver diff --git a/cpp/include/raft/spectral/eigen_solvers.cuh b/cpp/include/raft/spectral/eigen_solvers.cuh index 4774d8b8ae..324f16ac7b 100644 --- a/cpp/include/raft/spectral/eigen_solvers.cuh +++ b/cpp/include/raft/spectral/eigen_solvers.cuh @@ -69,6 +69,7 @@ struct lanczos_solver_t { eigVals, eigVecs, config_.seed); + return iters; } diff --git a/cpp/include/raft_runtime/solver/lanczos.hpp b/cpp/include/raft_runtime/solver/lanczos.hpp new file mode 100644 index 0000000000..6c9d901bf1 --- /dev/null +++ b/cpp/include/raft_runtime/solver/lanczos.hpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +namespace raft::runtime::solver { + +/** + * @defgroup lanczos_runtime lanczos Runtime API + * @{ + */ + +#define FUNC_DECL(IndexType, ValueType) \ + void lanczos_solver( \ + const raft::resources& handle, \ + raft::sparse::solver::lanczos_solver_config config, \ + raft::device_vector_view rows, \ + raft::device_vector_view cols, \ + raft::device_vector_view vals, \ + std::optional> v0, \ + raft::device_vector_view eigenvalues, \ + raft::device_matrix_view eigenvectors) + +FUNC_DECL(int, float); +FUNC_DECL(int64_t, float); +FUNC_DECL(int, double); +FUNC_DECL(int64_t, double); + +#undef FUNC_DECL + +/** @} */ // end group lanczos_runtime + +} // namespace raft::runtime::solver diff --git a/cpp/src/raft_runtime/solver/lanczos_solver.cuh b/cpp/src/raft_runtime/solver/lanczos_solver.cuh new file mode 100644 index 0000000000..0c851ef13a --- /dev/null +++ b/cpp/src/raft_runtime/solver/lanczos_solver.cuh @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define FUNC_DEF(IndexType, ValueType) \ + void lanczos_solver( \ + const raft::resources& handle, \ + raft::sparse::solver::lanczos_solver_config config, \ + raft::device_vector_view rows, \ + raft::device_vector_view cols, \ + raft::device_vector_view vals, \ + std::optional> v0, \ + raft::device_vector_view eigenvalues, \ + raft::device_matrix_view eigenvectors) \ + { \ + raft::sparse::solver::lanczos_compute_smallest_eigenvectors( \ + handle, config, rows, cols, vals, v0, eigenvalues, eigenvectors); \ + } diff --git a/cpp/src/raft_runtime/solver/lanczos_solver_int64_double.cu b/cpp/src/raft_runtime/solver/lanczos_solver_int64_double.cu new file mode 100644 index 0000000000..f772a8a0d1 --- /dev/null +++ b/cpp/src/raft_runtime/solver/lanczos_solver_int64_double.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lanczos_solver.cuh" + +namespace raft::runtime::solver { + +FUNC_DEF(int64_t, double); + +} // namespace raft::runtime::solver diff --git a/cpp/src/raft_runtime/solver/lanczos_solver_int64_float.cu b/cpp/src/raft_runtime/solver/lanczos_solver_int64_float.cu new file mode 100644 index 0000000000..efaf3be565 --- /dev/null +++ b/cpp/src/raft_runtime/solver/lanczos_solver_int64_float.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lanczos_solver.cuh" + +namespace raft::runtime::solver { + +FUNC_DEF(int64_t, float); + +} // namespace raft::runtime::solver diff --git a/cpp/src/raft_runtime/solver/lanczos_solver_int_double.cu b/cpp/src/raft_runtime/solver/lanczos_solver_int_double.cu new file mode 100644 index 0000000000..9bbc00e78a --- /dev/null +++ b/cpp/src/raft_runtime/solver/lanczos_solver_int_double.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lanczos_solver.cuh" + +namespace raft::runtime::solver { + +FUNC_DEF(int, double); + +} // namespace raft::runtime::solver diff --git a/cpp/src/raft_runtime/solver/lanczos_solver_int_float.cu b/cpp/src/raft_runtime/solver/lanczos_solver_int_float.cu new file mode 100644 index 0000000000..316a9fb7e1 --- /dev/null +++ b/cpp/src/raft_runtime/solver/lanczos_solver_int_float.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lanczos_solver.cuh" + +namespace raft::runtime::solver { + +FUNC_DEF(int, float); + +} // namespace raft::runtime::solver diff --git a/cpp/template/cmake/thirdparty/get_raft.cmake b/cpp/template/cmake/thirdparty/get_raft.cmake deleted file mode 100644 index 4474fd2875..0000000000 --- a/cpp/template/cmake/thirdparty/get_raft.cmake +++ /dev/null @@ -1,67 +0,0 @@ -# ============================================================================= -# Copyright (c) 2023-2024, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under -# the License. - -# Use RAPIDS_VERSION from cmake/thirdparty/fetch_rapids.cmake -set(RAFT_VERSION "${RAPIDS_VERSION}") -set(RAFT_FORK "rapidsai") -set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") - -function(find_and_configure_raft) - set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY ENABLE_NVTX ENABLE_MNMG_DEPENDENCIES) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) - - set(RAFT_COMPONENTS "") - if(PKG_COMPILE_LIBRARY) - string(APPEND RAFT_COMPONENTS " compiled") - endif() - - if(PKG_ENABLE_MNMG_DEPENDENCIES) - string(APPEND RAFT_COMPONENTS " distributed") - endif() - - #----------------------------------------------------- - # Invoke CPM find_package() - #----------------------------------------------------- - # Since the RAFT_NVTX option is used by targets generated by - # find_package(RAFT_NVTX) and when building from source we want to - # make `RAFT_NVTX` a cache variable so we get consistent - # behavior - # - set(RAFT_NVTX ${PKG_ENABLE_NVTX} CACHE BOOL "Enable raft nvtx logging" FORCE) - rapids_cpm_find(raft ${PKG_VERSION} - GLOBAL_TARGETS raft::raft - BUILD_EXPORT_SET raft-template-exports - INSTALL_EXPORT_SET raft-template-exports - COMPONENTS ${RAFT_COMPONENTS} - CPM_ARGS - GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git - GIT_TAG ${PKG_PINNED_TAG} - SOURCE_SUBDIR cpp - OPTIONS - "BUILD_TESTS OFF" - "BUILD_PRIMS_BENCH OFF" - "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" - ) -endfunction() - -# Change pinned tag here to test a commit in CI -# To use a different RAFT locally, set the CMake variable -# CPM_raft_SOURCE=/path/to/local/raft -find_and_configure_raft(VERSION ${RAFT_VERSION}.00 - FORK ${RAFT_FORK} - PINNED_TAG ${RAFT_PINNED_TAG} - COMPILE_LIBRARY ON - ENABLE_MNMG_DEPENDENCIES OFF - ENABLE_NVTX OFF -) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index a387e9ce09..621ee6c160 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -232,8 +232,8 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME SOLVERS_TEST PATH linalg/eigen_solvers.cu lap/lap.cu sparse/mst.cu LIB - EXPLICIT_INSTANTIATE_ONLY + NAME SOLVERS_TEST PATH linalg/eigen_solvers.cu lap/lap.cu sparse/mst.cu + sparse/solver/lanczos.cu LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( diff --git a/cpp/test/sparse/solver/lanczos.cu b/cpp/test/sparse/solver/lanczos.cu new file mode 100644 index 0000000000..74611a1fd8 --- /dev/null +++ b/cpp/test/sparse/solver/lanczos.cu @@ -0,0 +1,445 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace raft::sparse { + +template +struct lanczos_inputs { + int n_components; + int restartiter; + int maxiter; + int conv_n_iters; + float conv_eps; + float tol; + uint64_t seed; + std::vector rows; // indptr + std::vector cols; // indices + std::vector vals; // data + std::vector expected_eigenvalues; +}; + +template +struct rmat_lanczos_inputs { + int n_components; + int restartiter; + int maxiter; + int conv_n_iters; + float conv_eps; + float tol; + uint64_t seed; + int r_scale; + int c_scale; + float sparsity; + std::vector expected_eigenvalues; +}; + +template +class rmat_lanczos_tests + : public ::testing::TestWithParam> { + public: + rmat_lanczos_tests() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + rng(params.seed), + expected_eigenvalues(raft::make_device_vector( + handle, params.n_components)), + r_scale(params.r_scale), + c_scale(params.c_scale), + sparsity(params.sparsity) + { + } + + protected: + void SetUp() override + { + raft::copy(expected_eigenvalues.data_handle(), + params.expected_eigenvalues.data(), + params.n_components, + stream); + } + + void TearDown() override {} + + void Run() + { + uint64_t n_edges = sparsity * ((long long)(1 << r_scale) * (long long)(1 << c_scale)); + uint64_t n_nodes = 1 << std::max(r_scale, c_scale); + uint64_t theta_len = std::max(r_scale, c_scale) * 4; + + auto theta = raft::make_device_vector(handle, theta_len); + raft::random::uniform(handle, rng, theta.view(), 0, 1); + + auto out = + raft::make_device_matrix(handle, n_edges * 2, 2); + auto out_src = raft::make_device_vector(handle, n_edges); + auto out_dst = raft::make_device_vector(handle, n_edges); + + raft::random::RngState rng1{params.seed}; + + raft::random::rmat_rectangular_gen(handle, + rng1, + make_const_mdspan(theta.view()), + out.view(), + out_src.view(), + out_dst.view(), + r_scale, + c_scale); + + raft::device_vector out_data = + raft::make_device_vector(handle, n_edges); + raft::matrix::fill(handle, out_data.view(), 1.0); + raft::sparse::COO coo(stream); + + raft::sparse::op::coo_sort(n_nodes, + n_nodes, + n_edges, + out_src.data_handle(), + out_dst.data_handle(), + out_data.data_handle(), + stream); + raft::sparse::op::max_duplicates(handle, + coo, + out_src.data_handle(), + out_dst.data_handle(), + out_data.data_handle(), + n_edges, + n_nodes, + n_nodes); + + raft::sparse::COO symmetric_coo(stream); + raft::sparse::linalg::symmetrize( + handle, coo.rows(), coo.cols(), coo.vals(), coo.n_rows, coo.n_cols, coo.nnz, symmetric_coo); + + raft::device_vector row_indices = + raft::make_device_vector(handle, + symmetric_coo.n_rows + 1); + raft::sparse::convert::sorted_coo_to_csr(symmetric_coo.rows(), + symmetric_coo.nnz, + row_indices.data_handle(), + symmetric_coo.n_rows + 1, + stream); + + int n_components = params.n_components; + + raft::device_vector v0 = + raft::make_device_vector(handle, symmetric_coo.n_rows); + + raft::random::uniform(handle, rng, v0.view(), 0, 1); + std::tuple stats; + + raft::device_vector eigenvalues = + raft::make_device_vector(handle, n_components); + raft::device_matrix eigenvectors = + raft::make_device_matrix( + handle, symmetric_coo.n_rows, n_components); + + raft::spectral::matrix::sparse_matrix_t const csr_m{ + handle, + row_indices.data_handle(), + symmetric_coo.cols(), + symmetric_coo.vals(), + symmetric_coo.n_rows, + symmetric_coo.nnz}; + raft::sparse::solver::lanczos_solver_config config{ + n_components, params.maxiter, params.restartiter, params.tol, rng.seed}; + + auto csr_structure = + raft::make_device_compressed_structure_view( + const_cast(row_indices.data_handle()), + const_cast(symmetric_coo.cols()), + symmetric_coo.n_rows, + symmetric_coo.n_rows, + symmetric_coo.nnz); + + auto csr_matrix = raft::make_device_csr_matrix_view( + const_cast(symmetric_coo.vals()), csr_structure); + + std::get<0>(stats) = + raft::sparse::solver::lanczos_compute_smallest_eigenvectors( + handle, + config, + csr_matrix, + std::make_optional(v0.view()), + eigenvalues.view(), + eigenvectors.view()); + + ASSERT_TRUE(raft::devArrMatch(eigenvalues.data_handle(), + expected_eigenvalues.data_handle(), + n_components, + raft::CompareApprox(1e-5), + stream)); + } + + protected: + rmat_lanczos_inputs params; + raft::resources handle; + cudaStream_t stream; + raft::random::RngState rng; + int r_scale; + int c_scale; + float sparsity; + raft::device_vector expected_eigenvalues; +}; + +template +class lanczos_tests : public ::testing::TestWithParam> { + public: + lanczos_tests() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + n(params.rows.size() - 1), + nnz(params.vals.size()), + rng(params.seed), + rows(raft::make_device_vector(handle, n + 1)), + cols(raft::make_device_vector(handle, nnz)), + vals(raft::make_device_vector(handle, nnz)), + v0(raft::make_device_vector(handle, n)), + eigenvalues(raft::make_device_vector( + handle, params.n_components)), + eigenvectors(raft::make_device_matrix( + handle, n, params.n_components)), + expected_eigenvalues( + raft::make_device_vector(handle, params.n_components)) + { + } + + protected: + void SetUp() override + { + raft::copy(rows.data_handle(), params.rows.data(), n + 1, stream); + raft::copy(cols.data_handle(), params.cols.data(), nnz, stream); + raft::copy(vals.data_handle(), params.vals.data(), nnz, stream); + raft::copy(expected_eigenvalues.data_handle(), + params.expected_eigenvalues.data(), + params.n_components, + stream); + } + + void TearDown() override {} + + void Run() + { + raft::random::uniform(handle, rng, v0.view(), 0, 1); + std::tuple stats; + + raft::sparse::solver::lanczos_solver_config config{ + params.n_components, params.maxiter, params.restartiter, params.tol, rng.seed}; + auto csr_structure = + raft::make_device_compressed_structure_view( + const_cast(rows.data_handle()), + const_cast(cols.data_handle()), + n, + n, + nnz); + + auto csr_matrix = raft::make_device_csr_matrix_view( + const_cast(vals.data_handle()), csr_structure); + + std::get<0>(stats) = + raft::sparse::solver::lanczos_compute_smallest_eigenvectors( + handle, + config, + csr_matrix, + std::make_optional(v0.view()), + eigenvalues.view(), + eigenvectors.view()); + + ASSERT_TRUE(raft::devArrMatch(eigenvalues.data_handle(), + expected_eigenvalues.data_handle(), + params.n_components, + raft::CompareApprox(1e-5), + stream)); + } + + protected: + lanczos_inputs params; + raft::resources handle; + cudaStream_t stream; + int n; + int nnz; + raft::random::RngState rng; + raft::device_vector rows; + raft::device_vector cols; + raft::device_vector vals; + raft::device_vector v0; + raft::device_vector eigenvalues; + raft::device_matrix eigenvectors; + raft::device_vector expected_eigenvalues; +}; + +// TODO: Find a way to generate and validate test data without hardcoding them (issue #2485) +const std::vector> inputsf = { + {2, + 34, + 10000, + 0, + 0, + 1e-15, + 42, + {0, 0, 0, 0, 3, 5, 6, 8, 9, 11, 16, 16, 18, 20, 23, 24, 27, + 30, 31, 33, 37, 37, 39, 41, 43, 44, 46, 46, 47, 49, 50, 50, 51, 53, + 57, 58, 59, 66, 67, 68, 69, 71, 72, 75, 78, 83, 86, 90, 93, 94, 96, + 98, 99, 101, 101, 104, 106, 108, 109, 109, 109, 109, 111, 113, 118, 120, 121, 123, + 124, 128, 132, 134, 136, 138, 139, 141, 145, 148, 151, 152, 154, 155, 157, 160, 164, + 167, 170, 170, 170, 173, 178, 179, 182, 184, 186, 191, 192, 196, 198, 198, 198}, + {44, 68, 74, 16, 36, 85, 34, 75, 61, 51, 83, 15, 33, 55, 69, 71, 18, 84, 70, 95, 71, 83, + 97, 83, 9, 36, 54, 4, 42, 46, 52, 11, 89, 31, 37, 74, 96, 36, 88, 56, 64, 68, 94, 82, + 35, 90, 50, 82, 85, 83, 19, 47, 94, 9, 44, 56, 79, 6, 25, 4, 15, 21, 52, 75, 79, 92, + 19, 72, 94, 94, 96, 80, 16, 54, 89, 46, 48, 63, 3, 33, 67, 73, 77, 46, 47, 75, 16, 43, + 45, 81, 32, 45, 68, 43, 55, 63, 27, 89, 8, 17, 36, 15, 42, 96, 9, 49, 22, 33, 77, 7, + 75, 78, 88, 43, 49, 66, 76, 91, 22, 82, 69, 63, 84, 44, 3, 23, 47, 81, 9, 65, 76, 92, + 12, 96, 9, 13, 38, 93, 44, 3, 19, 6, 36, 45, 61, 63, 69, 89, 44, 57, 94, 62, 33, 36, + 41, 46, 68, 24, 28, 64, 8, 13, 14, 29, 11, 66, 88, 5, 28, 93, 21, 62, 84, 18, 42, 50, + 76, 91, 25, 63, 89, 97, 36, 69, 72, 85, 23, 32, 39, 40, 77, 12, 19, 40, 54, 70, 13, 91}, + {0.4734894, 0.1402491, 0.7686475, 0.0416142, 0.2559651, 0.9360436, 0.7486080, 0.5206724, + 0.0374126, 0.8082515, 0.5993828, 0.4866583, 0.8907925, 0.9251201, 0.8566143, 0.9528994, + 0.4557763, 0.4907070, 0.4158074, 0.8311127, 0.9026024, 0.3103237, 0.5876446, 0.7585195, + 0.4866583, 0.4493615, 0.5909155, 0.0416142, 0.0963910, 0.6722401, 0.3468698, 0.4557763, + 0.1445242, 0.7720124, 0.9923756, 0.1227579, 0.7194629, 0.8916773, 0.4320931, 0.5840980, + 0.0216121, 0.3709223, 0.1705930, 0.8297898, 0.2409706, 0.9585592, 0.3171389, 0.0228039, + 0.4350971, 0.4939908, 0.7720124, 0.2722416, 0.1792683, 0.8907925, 0.1085757, 0.8745620, + 0.3298612, 0.7486080, 0.2409706, 0.2559651, 0.4493615, 0.8916773, 0.5540361, 0.5150571, + 0.9160119, 0.1767728, 0.9923756, 0.5717281, 0.1077409, 0.9368132, 0.6273088, 0.6616613, + 0.0963910, 0.9378265, 0.3059566, 0.3159291, 0.0449106, 0.9085807, 0.4734894, 0.1085757, + 0.2909013, 0.7787509, 0.7168902, 0.9691764, 0.2669757, 0.4389115, 0.6722401, 0.3159291, + 0.9691764, 0.7467896, 0.2722416, 0.2669757, 0.1532843, 0.0449106, 0.2023634, 0.8934466, + 0.3171389, 0.6594226, 0.8082515, 0.3468698, 0.5540361, 0.5909155, 0.9378265, 0.2909178, + 0.9251201, 0.2023634, 0.5840980, 0.8745620, 0.2624605, 0.0374126, 0.1034030, 0.3736577, + 0.3315690, 0.9085807, 0.8934466, 0.5548525, 0.2302140, 0.7827352, 0.0216121, 0.8262919, + 0.1646078, 0.5548525, 0.2658700, 0.2909013, 0.1402491, 0.3709223, 0.1532843, 0.5792196, + 0.8566143, 0.1646078, 0.0827300, 0.5810611, 0.4158074, 0.5188584, 0.9528994, 0.9026024, + 0.5717281, 0.7269946, 0.7787509, 0.7686475, 0.1227579, 0.5206724, 0.5150571, 0.4389115, + 0.1034030, 0.2302140, 0.0827300, 0.8961608, 0.7168902, 0.2624605, 0.4823034, 0.3736577, + 0.3298612, 0.9160119, 0.6616613, 0.7467896, 0.5792196, 0.8297898, 0.0228039, 0.8262919, + 0.5993828, 0.3103237, 0.7585195, 0.4939908, 0.4907070, 0.2658700, 0.0844443, 0.9360436, + 0.4350971, 0.6997072, 0.4320931, 0.3315690, 0.0844443, 0.1445242, 0.3059566, 0.6594226, + 0.8961608, 0.6498466, 0.9585592, 0.7827352, 0.6498466, 0.2812338, 0.1767728, 0.5810611, + 0.7269946, 0.6997072, 0.1705930, 0.1792683, 0.1077409, 0.9368132, 0.4823034, 0.8311127, + 0.7194629, 0.6273088, 0.2909178, 0.5188584, 0.5876446, 0.2812338}, + {-2.0369630, -1.7673520}}}; + +const std::vector> inputsd = { + {2, + 34, + 10000, + 0, + 0, + 1e-15, + 42, + {0, 0, 0, 0, 3, 5, 6, 8, 9, 11, 16, 16, 18, 20, 23, 24, 27, + 30, 31, 33, 37, 37, 39, 41, 43, 44, 46, 46, 47, 49, 50, 50, 51, 53, + 57, 58, 59, 66, 67, 68, 69, 71, 72, 75, 78, 83, 86, 90, 93, 94, 96, + 98, 99, 101, 101, 104, 106, 108, 109, 109, 109, 109, 111, 113, 118, 120, 121, 123, + 124, 128, 132, 134, 136, 138, 139, 141, 145, 148, 151, 152, 154, 155, 157, 160, 164, + 167, 170, 170, 170, 173, 178, 179, 182, 184, 186, 191, 192, 196, 198, 198, 198}, + {44, 68, 74, 16, 36, 85, 34, 75, 61, 51, 83, 15, 33, 55, 69, 71, 18, 84, 70, 95, 71, 83, + 97, 83, 9, 36, 54, 4, 42, 46, 52, 11, 89, 31, 37, 74, 96, 36, 88, 56, 64, 68, 94, 82, + 35, 90, 50, 82, 85, 83, 19, 47, 94, 9, 44, 56, 79, 6, 25, 4, 15, 21, 52, 75, 79, 92, + 19, 72, 94, 94, 96, 80, 16, 54, 89, 46, 48, 63, 3, 33, 67, 73, 77, 46, 47, 75, 16, 43, + 45, 81, 32, 45, 68, 43, 55, 63, 27, 89, 8, 17, 36, 15, 42, 96, 9, 49, 22, 33, 77, 7, + 75, 78, 88, 43, 49, 66, 76, 91, 22, 82, 69, 63, 84, 44, 3, 23, 47, 81, 9, 65, 76, 92, + 12, 96, 9, 13, 38, 93, 44, 3, 19, 6, 36, 45, 61, 63, 69, 89, 44, 57, 94, 62, 33, 36, + 41, 46, 68, 24, 28, 64, 8, 13, 14, 29, 11, 66, 88, 5, 28, 93, 21, 62, 84, 18, 42, 50, + 76, 91, 25, 63, 89, 97, 36, 69, 72, 85, 23, 32, 39, 40, 77, 12, 19, 40, 54, 70, 13, 91}, + {0.4734894, 0.1402491, 0.7686475, 0.0416142, 0.2559651, 0.9360436, 0.7486080, 0.5206724, + 0.0374126, 0.8082515, 0.5993828, 0.4866583, 0.8907925, 0.9251201, 0.8566143, 0.9528994, + 0.4557763, 0.4907070, 0.4158074, 0.8311127, 0.9026024, 0.3103237, 0.5876446, 0.7585195, + 0.4866583, 0.4493615, 0.5909155, 0.0416142, 0.0963910, 0.6722401, 0.3468698, 0.4557763, + 0.1445242, 0.7720124, 0.9923756, 0.1227579, 0.7194629, 0.8916773, 0.4320931, 0.5840980, + 0.0216121, 0.3709223, 0.1705930, 0.8297898, 0.2409706, 0.9585592, 0.3171389, 0.0228039, + 0.4350971, 0.4939908, 0.7720124, 0.2722416, 0.1792683, 0.8907925, 0.1085757, 0.8745620, + 0.3298612, 0.7486080, 0.2409706, 0.2559651, 0.4493615, 0.8916773, 0.5540361, 0.5150571, + 0.9160119, 0.1767728, 0.9923756, 0.5717281, 0.1077409, 0.9368132, 0.6273088, 0.6616613, + 0.0963910, 0.9378265, 0.3059566, 0.3159291, 0.0449106, 0.9085807, 0.4734894, 0.1085757, + 0.2909013, 0.7787509, 0.7168902, 0.9691764, 0.2669757, 0.4389115, 0.6722401, 0.3159291, + 0.9691764, 0.7467896, 0.2722416, 0.2669757, 0.1532843, 0.0449106, 0.2023634, 0.8934466, + 0.3171389, 0.6594226, 0.8082515, 0.3468698, 0.5540361, 0.5909155, 0.9378265, 0.2909178, + 0.9251201, 0.2023634, 0.5840980, 0.8745620, 0.2624605, 0.0374126, 0.1034030, 0.3736577, + 0.3315690, 0.9085807, 0.8934466, 0.5548525, 0.2302140, 0.7827352, 0.0216121, 0.8262919, + 0.1646078, 0.5548525, 0.2658700, 0.2909013, 0.1402491, 0.3709223, 0.1532843, 0.5792196, + 0.8566143, 0.1646078, 0.0827300, 0.5810611, 0.4158074, 0.5188584, 0.9528994, 0.9026024, + 0.5717281, 0.7269946, 0.7787509, 0.7686475, 0.1227579, 0.5206724, 0.5150571, 0.4389115, + 0.1034030, 0.2302140, 0.0827300, 0.8961608, 0.7168902, 0.2624605, 0.4823034, 0.3736577, + 0.3298612, 0.9160119, 0.6616613, 0.7467896, 0.5792196, 0.8297898, 0.0228039, 0.8262919, + 0.5993828, 0.3103237, 0.7585195, 0.4939908, 0.4907070, 0.2658700, 0.0844443, 0.9360436, + 0.4350971, 0.6997072, 0.4320931, 0.3315690, 0.0844443, 0.1445242, 0.3059566, 0.6594226, + 0.8961608, 0.6498466, 0.9585592, 0.7827352, 0.6498466, 0.2812338, 0.1767728, 0.5810611, + 0.7269946, 0.6997072, 0.1705930, 0.1792683, 0.1077409, 0.9368132, 0.4823034, 0.8311127, + 0.7194629, 0.6273088, 0.2909178, 0.5188584, 0.5876446, 0.2812338}, + {-2.0369630, -1.7673520}}}; + +const std::vector> rmat_inputsf = { + {50, 100, 10000, 0, 0, 1e-9, 42, 12, 12, 1, {-122.526794, -74.00686, -59.698284, -54.68617, + -49.686813, -34.02644, -32.130703, -31.26906, + -30.32097, -22.946098, -20.497862, -20.23817, + -19.269697, -18.42496, -17.675667, -17.013401, + -16.734581, -15.820215, -15.73925, -15.448187, + -15.044634, -14.692028, -14.127425, -13.967386, + -13.6237755, -13.469393, -13.181225, -12.777589, + -12.623185, -12.55508, -12.2874565, -12.053391, + -11.677346, -11.558279, -11.163732, -10.922034, + -10.7936945, -10.558049, -10.205776, -10.005316, + -9.559181, -9.491834, -9.242631, -8.883637, + -8.765364, -8.688508, -8.458255, -8.385196, + -8.217982, -8.0442095}}}; + +using LanczosTestF = lanczos_tests; +TEST_P(LanczosTestF, Result) { Run(); } + +using LanczosTestD = lanczos_tests; +TEST_P(LanczosTestD, Result) { Run(); } + +using RmatLanczosTestF = rmat_lanczos_tests; +TEST_P(RmatLanczosTestF, Result) { Run(); } + +INSTANTIATE_TEST_CASE_P(LanczosTests, LanczosTestF, ::testing::ValuesIn(inputsf)); +INSTANTIATE_TEST_CASE_P(LanczosTests, LanczosTestD, ::testing::ValuesIn(inputsd)); +INSTANTIATE_TEST_CASE_P(LanczosTests, RmatLanczosTestF, ::testing::ValuesIn(rmat_inputsf)); + +} // namespace raft::sparse diff --git a/docs/source/pylibraft_api.rst b/docs/source/pylibraft_api.rst index aaa359e646..ad7d2873d7 100644 --- a/docs/source/pylibraft_api.rst +++ b/docs/source/pylibraft_api.rst @@ -9,3 +9,4 @@ Python API pylibraft_api/common.rst pylibraft_api/random.rst + pylibraft_api/sparse.rst diff --git a/docs/source/pylibraft_api/sparse.rst b/docs/source/pylibraft_api/sparse.rst new file mode 100644 index 0000000000..b2c3f7a2b1 --- /dev/null +++ b/docs/source/pylibraft_api/sparse.rst @@ -0,0 +1,11 @@ +Sparse +====== + +This page provides pylibraft class references for the publicly-exposed elements of the `pylibraft.sparse.linalg.eigsh` package. + + +.. role:: py(code) + :language: python + :class: highlight + +.. autofunction:: pylibraft.sparse.linalg.eigsh \ No newline at end of file diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt index 9bde613720..758c1e4711 100644 --- a/python/pylibraft/CMakeLists.txt +++ b/python/pylibraft/CMakeLists.txt @@ -87,6 +87,7 @@ rapids_cython_init() add_subdirectory(pylibraft/common) add_subdirectory(pylibraft/random) +add_subdirectory(pylibraft/sparse) if(DEFINED cython_lib_dir) rapids_cython_add_rpath_entries(TARGET raft PATHS "${cython_lib_dir}") diff --git a/python/pylibraft/pylibraft/sparse/CMakeLists.txt b/python/pylibraft/pylibraft/sparse/CMakeLists.txt new file mode 100644 index 0000000000..3779fd2715 --- /dev/null +++ b/python/pylibraft/pylibraft/sparse/CMakeLists.txt @@ -0,0 +1,15 @@ +# ============================================================================= +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +add_subdirectory(linalg) diff --git a/python/pylibraft/pylibraft/sparse/__init__.py b/python/pylibraft/pylibraft/sparse/__init__.py new file mode 100644 index 0000000000..c77def5bb0 --- /dev/null +++ b/python/pylibraft/pylibraft/sparse/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pylibraft.sparse import linalg + +__all__ = ["linalg"] diff --git a/python/pylibraft/pylibraft/sparse/linalg/CMakeLists.txt b/python/pylibraft/pylibraft/sparse/linalg/CMakeLists.txt new file mode 100644 index 0000000000..ef16981644 --- /dev/null +++ b/python/pylibraft/pylibraft/sparse/linalg/CMakeLists.txt @@ -0,0 +1,27 @@ +# ============================================================================= +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +# Set the list of Cython files to build +set(cython_sources lanczos.pyx) + +# TODO: should finally be replaced with 'compiled' library to be more generic, when that is +# available +set(linked_libraries raft::raft raft::compiled) + +# Build all of the Cython targets +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX sparse_ +) diff --git a/python/pylibraft/pylibraft/sparse/linalg/__init__.pxd b/python/pylibraft/pylibraft/sparse/linalg/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/sparse/linalg/__init__.py b/python/pylibraft/pylibraft/sparse/linalg/__init__.py new file mode 100644 index 0000000000..04a8106496 --- /dev/null +++ b/python/pylibraft/pylibraft/sparse/linalg/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .lanczos import eigsh + +__all__ = ["eigsh"] diff --git a/python/pylibraft/pylibraft/sparse/linalg/cpp/__init__.pxd b/python/pylibraft/pylibraft/sparse/linalg/cpp/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/sparse/linalg/cpp/__init__.py b/python/pylibraft/pylibraft/sparse/linalg/cpp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/sparse/linalg/lanczos.pyx b/python/pylibraft/pylibraft/sparse/linalg/lanczos.pyx new file mode 100644 index 0000000000..dc2a84b428 --- /dev/null +++ b/python/pylibraft/pylibraft/sparse/linalg/lanczos.pyx @@ -0,0 +1,277 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import cupy as cp +import numpy as np + +from cython.operator cimport dereference as deref +from libc.stdint cimport int64_t, uint32_t, uint64_t, uintptr_t + +from pylibraft.common import Handle, cai_wrapper, device_ndarray +from pylibraft.common.handle import auto_sync_handle + +from libcpp cimport bool + +from pylibraft.common.cpp.mdspan cimport ( + col_major, + device_matrix_view, + device_vector_view, + make_device_matrix_view, + make_device_vector_view, + row_major, +) +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.random.cpp.rng_state cimport RngState + + +cdef extern from "raft/sparse/solver/lanczos_types.hpp" \ + namespace "raft::sparse::solver" nogil: + + cdef cppclass lanczos_solver_config[ValueTypeT]: + int n_components + int max_iterations + int ncv + ValueTypeT tolerance + uint64_t seed + +cdef lanczos_solver_config[float] config_float +cdef lanczos_solver_config[double] config_double + +cdef extern from "raft_runtime/solver/lanczos.hpp" \ + namespace "raft::runtime::solver" nogil: + + cdef void lanczos_solver( + const device_resources &handle, + lanczos_solver_config[double] config, + device_vector_view[int64_t, uint32_t] rows, + device_vector_view[int64_t, uint32_t] cols, + device_vector_view[double, uint32_t] vals, + optional[device_vector_view[double, uint32_t]] v0, + device_vector_view[double, uint32_t] eigenvalues, + device_matrix_view[double, uint32_t, col_major] eigenvectors) except + + + cdef void lanczos_solver( + const device_resources &handle, + lanczos_solver_config[float] config, + device_vector_view[int64_t, uint32_t] rows, + device_vector_view[int64_t, uint32_t] cols, + device_vector_view[float, uint32_t] vals, + optional[device_vector_view[float, uint32_t]] v0, + device_vector_view[float, uint32_t] eigenvalues, + device_matrix_view[float, uint32_t, col_major] eigenvectors) except + + + cdef void lanczos_solver( + const device_resources &handle, + lanczos_solver_config[double] config, + device_vector_view[int, uint32_t] rows, + device_vector_view[int, uint32_t] cols, + device_vector_view[double, uint32_t] vals, + optional[device_vector_view[double, uint32_t]] v0, + device_vector_view[double, uint32_t] eigenvalues, + device_matrix_view[double, uint32_t, col_major] eigenvectors) except + + + cdef void lanczos_solver( + const device_resources &handle, + lanczos_solver_config[float] config, + device_vector_view[int, uint32_t] rows, + device_vector_view[int, uint32_t] cols, + device_vector_view[float, uint32_t] vals, + optional[device_vector_view[float, uint32_t]] v0, + device_vector_view[float, uint32_t] eigenvalues, + device_matrix_view[float, uint32_t, col_major] eigenvectors) except + + + +@auto_sync_handle +def eigsh(A, k=6, v0=None, ncv=None, maxiter=None, + tol=0, seed=None, handle=None): + """ + Find ``k`` eigenvalues and eigenvectors of the real symmetric square + matrix or complex Hermitian matrix ``A``. + + Solves ``Ax = wx``, the standard eigenvalue problem for ``w`` eigenvalues + with corresponding eigenvectors ``x``. + + Args: + a (spmatrix): A symmetric square sparse CSR matrix with + dimension ``(n, n)``. ``a`` must be of type + :class:`cupyx.scipy.sparse._csr.csr_matrix` + k (int): The number of eigenvalues and eigenvectors to compute. Must be + ``1 <= k < n``. + v0 (ndarray): Starting vector for iteration. If ``None``, a random + unit vector is used. + ncv (int): The number of Lanczos vectors generated. Must be + ``k + 1 < ncv < n``. If ``None``, default value is used. + maxiter (int): Maximum number of Lanczos update iterations. + If ``None``, default value is used. + tol (float): Tolerance for residuals ``||Ax - wx||``. If ``0``, machine + precision is used. + + Returns: + tuple: + It returns ``w`` and ``x`` + where ``w`` is eigenvalues and ``x`` is eigenvectors. + + .. seealso:: + :func:`scipy.sparse.linalg.eigsh` + :func:`cupyx.scipy.sparse.linalg.eigsh` + + .. note:: + This function uses the thick-restart Lanczos methods + (https://sdm.lbl.gov/~kewu/ps/trlan.html). + + """ + + if A is None: + raise Exception("'A' cannot be None!") + + rows = A.indptr + cols = A.indices + vals = A.data + + rows = cai_wrapper(rows) + cols = cai_wrapper(cols) + vals = cai_wrapper(vals) + + IndexType = rows.dtype + ValueType = vals.dtype + + N = A.shape[0] + n = N + nnz = A.nnz + + rows_ptr = rows.data + cols_ptr = cols.data + vals_ptr = vals.data + cdef optional[device_vector_view[double, uint32_t]] d_v0 + cdef optional[device_vector_view[float, uint32_t]] f_v0 + + if ncv is None: + ncv = min(n, max(2*k + 1, 20)) + else: + ncv = min(max(ncv, k + 2), n - 1) + + seed = seed if seed is not None else 42 + if maxiter is None: + maxiter = 10 * n + if tol == 0: + tol = np.finfo(ValueType).eps + + eigenvectors = device_ndarray.empty((N, k), dtype=ValueType, order='F') + eigenvalues = device_ndarray.empty((k), dtype=ValueType, order='F') + + eigenvectors_cai = cai_wrapper(eigenvectors) + eigenvalues_cai = cai_wrapper(eigenvalues) + + eigenvectors_ptr = eigenvectors_cai.data + eigenvalues_ptr = eigenvalues_cai.data + + handle = handle if handle is not None else Handle() + cdef device_resources *h = handle.getHandle() + + if IndexType == np.int32 and ValueType == np.float32: + config_float.n_components = k + config_float.max_iterations = maxiter + config_float.ncv = ncv + config_float.tolerance = tol + config_float.seed = seed + if v0 is not None: + v0 = cai_wrapper(v0) + v0_ptr = v0.data + f_v0 = make_device_vector_view(v0_ptr, N) + lanczos_solver( + deref(h), + config_float, + make_device_vector_view(rows_ptr, (N + 1)), + make_device_vector_view(cols_ptr, nnz), + make_device_vector_view(vals_ptr, nnz), + f_v0, + make_device_vector_view(eigenvalues_ptr, k), + make_device_matrix_view[float, uint32_t, col_major]( + eigenvectors_ptr, N, k), + ) + elif IndexType == np.int64 and ValueType == np.float32: + config_float.n_components = k + config_float.max_iterations = maxiter + config_float.ncv = ncv + config_float.tolerance = tol + config_float.seed = seed + if v0 is not None: + v0 = cai_wrapper(v0) + v0_ptr = v0.data + f_v0 = make_device_vector_view(v0_ptr, N) + lanczos_solver( + deref(h), + config_float, + make_device_vector_view(rows_ptr, (N + 1)), + make_device_vector_view(cols_ptr, nnz), + make_device_vector_view(vals_ptr, nnz), + f_v0, + make_device_vector_view(eigenvalues_ptr, k), + make_device_matrix_view[float, uint32_t, col_major]( + eigenvectors_ptr, N, k), + ) + elif IndexType == np.int32 and ValueType == np.float64: + config_double.n_components = k + config_double.max_iterations = maxiter + config_double.ncv = ncv + config_double.tolerance = tol + config_double.seed = seed + if v0 is not None: + v0 = cai_wrapper(v0) + v0_ptr = v0.data + d_v0 = make_device_vector_view(v0_ptr, N) + lanczos_solver( + deref(h), + config_double, + make_device_vector_view(rows_ptr, (N + 1)), + make_device_vector_view(cols_ptr, nnz), + make_device_vector_view(vals_ptr, nnz), + d_v0, + make_device_vector_view(eigenvalues_ptr, k), + make_device_matrix_view[double, uint32_t, col_major]( + eigenvectors_ptr, N, k), + ) + elif IndexType == np.int64 and ValueType == np.float64: + config_double.n_components = k + config_double.max_iterations = maxiter + config_double.ncv = ncv + config_double.tolerance = tol + config_double.seed = seed + if v0 is not None: + v0 = cai_wrapper(v0) + v0_ptr = v0.data + d_v0 = make_device_vector_view(v0_ptr, N) + lanczos_solver( + deref(h), + config_double, + make_device_vector_view(rows_ptr, (N + 1)), + make_device_vector_view(cols_ptr, nnz), + make_device_vector_view(vals_ptr, nnz), + d_v0, + make_device_vector_view(eigenvalues_ptr, k), + make_device_matrix_view[double, uint32_t, col_major]( + eigenvectors_ptr, N, k), + ) + else: + raise ValueError("dtype IndexType=%s and ValueType=%s not supported" % + (IndexType, ValueType)) + + return (cp.asarray(eigenvalues), cp.asarray(eigenvectors)) diff --git a/python/pylibraft/pylibraft/test/test_sparse.py b/python/pylibraft/pylibraft/test/test_sparse.py new file mode 100644 index 0000000000..10b261d322 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_sparse.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import cupy +import cupyx.scipy.sparse.linalg # NOQA +import numpy +import pytest +from cupyx.scipy import sparse + +from pylibraft.sparse.linalg import eigsh + + +def shaped_random( + shape, xp=cupy, dtype=numpy.float32, scale=10, seed=0, order="C" +): + """ + Returns an array filled with random values. + + Args + ---- + shape(tuple): Shape of returned ndarray. + xp(numpy or cupy): Array module to use. + dtype(dtype): Dtype of returned ndarray. + scale(float): Scaling factor of elements. + seed(int): Random seed. + + Returns + ------- + numpy.ndarray or cupy.ndarray: The array with + given shape, array module, + + If ``dtype`` is ``numpy.bool_``, the elements are + independently drawn from ``True`` and ``False`` + with same probabilities. + Otherwise, the array is filled with samples + independently and identically drawn + from uniform distribution over :math:`[0, scale)` + with specified dtype. + """ + numpy.random.seed(seed) + dtype = numpy.dtype(dtype) + if dtype == "?": + a = numpy.random.randint(2, size=shape) + elif dtype.kind == "c": + a = numpy.random.rand(*shape) + 1j * numpy.random.rand(*shape) + a *= scale + else: + a = numpy.random.rand(*shape) * scale + return xp.asarray(a, dtype=dtype, order=order) + + +class TestEigsh: + n = 30 + density = 0.33 + tol = {numpy.float32: 1e-5, numpy.complex64: 1e-5, "default": 1e-12} + res_tol = {"f": 1e-5, "d": 1e-12} + return_eigenvectors = True + + def _make_matrix(self, dtype, xp): + shape = (self.n, self.n) + a = shaped_random(shape, xp, dtype=dtype) + mask = shaped_random(shape, xp, dtype="f", scale=1) + a[mask > self.density] = 0 + a = a * a.conj().T + return a + + def _test_eigsh(self, a, k, xp, sp): + expected_ret = sp.linalg.eigsh( + a, k=k, return_eigenvectors=self.return_eigenvectors + ) + actual_ret = eigsh(a, k=k) + if self.return_eigenvectors: + w, x = actual_ret + exp_w, _ = expected_ret + # Check the residuals to see if eigenvectors are correct. + ax_xw = a @ x - xp.multiply(x, w.reshape(1, k)) + res = xp.linalg.norm(ax_xw) / xp.linalg.norm(w) + tol = self.res_tol[numpy.dtype(a.dtype).char.lower()] + assert res < tol + else: + w = actual_ret + exp_w = expected_ret + w = xp.sort(w) + cupy.allclose(w, exp_w, rtol=tol, atol=tol) + + @pytest.mark.parametrize("format", ["csr"]) # , 'csc', 'coo']) + @pytest.mark.parametrize("k", [3, 6, 12]) + @pytest.mark.parametrize("dtype", ["f", "d"]) + def test_sparse(self, format, k, dtype, xp=cupy, sp=sparse): + if format == "csc": + pytest.xfail("may be buggy") # trans=True + + a = self._make_matrix(dtype, xp) + a = sp.coo_matrix(a).asformat(format) + return self._test_eigsh(a, k, xp, sp) + + def test_invalid(self): + xp, sp = cupy, sparse + a = xp.diag(xp.ones((self.n,), dtype="f")) + with pytest.raises(ValueError): + sp.linalg.eigsh(xp.ones((2, 1), dtype="f")) + with pytest.raises(ValueError): + sp.linalg.eigsh(a, k=0) + a = xp.diag(xp.ones((self.n,), dtype="f")) + with pytest.raises(ValueError): + sp.linalg.eigsh(xp.ones((1,), dtype="f")) + with pytest.raises(TypeError): + sp.linalg.eigsh(xp.ones((2, 2), dtype="i")) + with pytest.raises(ValueError): + sp.linalg.eigsh(a, k=self.n) + + def test_starting_vector(self): + # Make symmetric matrix + aux = self._make_matrix("f", cupy) + aux = sparse.coo_matrix(aux).asformat("csr") + matrix = (aux + aux.T) / 2.0 + + # Find reference eigenvector + ew, ev = eigsh(matrix, k=1) + v = ev[:, 0] + + # Obtain non-converged eigenvector from random initial guess. + ew_aux, ev_aux = eigsh(matrix, k=1, ncv=1, maxiter=0) + v_aux = cupy.copysign(ev_aux[:, 0], v) + + # Obtain eigenvector using known eigenvector as initial guess. + ew_v0, ev_v0 = eigsh(matrix, k=1, v0=v.copy(), ncv=1, maxiter=0) + v_v0 = cupy.copysign(ev_v0[:, 0], v) + + assert cupy.linalg.norm(v - v_v0) < cupy.linalg.norm(v - v_aux)