From 5b136f8d0fc497f34891c371c58a6b65c2b11f36 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 19 Feb 2024 10:49:02 +0000 Subject: [PATCH] add c++/python test, bugfixes --- .../raft/neighbors/detail/ivf_common.cuh | 65 ++++++++++++------- .../detail/ivf_flat_interleaved_scan-inl.cuh | 8 ++- .../neighbors/detail/ivf_flat_search-inl.cuh | 18 ++--- cpp/test/neighbors/ann_ivf_flat.cuh | 17 +++++ .../pylibraft/pylibraft/test/test_ivf_flat.py | 10 ++- 5 files changed, 79 insertions(+), 39 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_common.cuh b/cpp/include/raft/neighbors/detail/ivf_common.cuh index e6f2e09340..c258c05ea5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_common.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_common.cuh @@ -30,9 +30,10 @@ namespace raft::neighbors::ivf::detail { template constexpr static IdxT kOutOfBoundsRecord = std::numeric_limits::max(); -template +template struct dummy_block_sort_t { - using queue_t = matrix::detail::select::warpsort::warp_sort_distributed; + using queue_t = + matrix::detail::select::warpsort::warp_sort_distributed; template __device__ dummy_block_sort_t(int k, Args...){}; }; @@ -215,36 +216,52 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk] float scaling_factor, rmm::cuda_stream_view stream) { - size_t len = size_t(n_queries) * size_t(topk); + constexpr bool needs_cast = !std::is_same::value; + size_t len = size_t(n_queries) * size_t(topk); switch (metric) { case distance::DistanceType::L2Unexpanded: case distance::DistanceType::L2Expanded: { - linalg::unaryOp( - out, - in, - len, - raft::compose_op(raft::mul_const_op{scaling_factor * scaling_factor}, - raft::cast_op{}), - stream); + if (scaling_factor != 0) { + linalg::unaryOp( + out, + in, + len, + raft::compose_op(raft::mul_const_op{scaling_factor * scaling_factor}, + raft::cast_op{}), + stream); + } else if (needs_cast) { + linalg::unaryOp(out, in, len, raft::cast_op{}, stream); + } } break; case distance::DistanceType::L2SqrtUnexpanded: case distance::DistanceType::L2SqrtExpanded: { - linalg::unaryOp(out, - in, - len, - raft::compose_op{raft::mul_const_op{scaling_factor}, - raft::sqrt_op{}, - raft::cast_op{}}, - stream); + if (scaling_factor != 0) { + linalg::unaryOp(out, + in, + len, + raft::compose_op{raft::mul_const_op{scaling_factor}, + raft::sqrt_op{}, + raft::cast_op{}}, + stream); + } else if (needs_cast) { + linalg::unaryOp( + out, in, len, raft::compose_op{raft::sqrt_op{}, raft::cast_op{}}, stream); + } else { + linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); + } } break; case distance::DistanceType::InnerProduct: { - linalg::unaryOp( - out, - in, - len, - raft::compose_op(raft::mul_const_op{-scaling_factor * scaling_factor}, - raft::cast_op{}), - stream); + if (scaling_factor != 0) { + linalg::unaryOp( + out, + in, + len, + raft::compose_op(raft::mul_const_op{-scaling_factor * scaling_factor}, + raft::cast_op{}), + stream); + } else if (needs_cast) { + linalg::unaryOp(out, in, len, raft::cast_op{}, stream); + } } break; default: RAFT_FAIL("Unexpected metric."); } diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index b44ca470ae..060a130d86 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -642,8 +642,9 @@ struct flat_block_sort { }; template -struct flat_block_sort<0, Ascending, T, IdxT> : ivf::detail::dummy_block_sort_t { - using type = ivf::detail::dummy_block_sort_t; +struct flat_block_sort<0, Ascending, T, IdxT> + : ivf::detail::dummy_block_sort_t { + using type = ivf::detail::dummy_block_sort_t; }; template @@ -944,7 +945,8 @@ void launch_kernel(Lambda lambda, neighbors += grid_dim_y * grid_dim_x * k; distances += grid_dim_y * grid_dim_x * k; } else { - neighbors += grid_dim_y * max_samples; + distances += grid_dim_y * max_samples; + chunk_indices += grid_dim_y * n_probes; } coarse_index += grid_dim_y * n_probes; } diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 7a51cee67b..1087b6a8d0 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -227,7 +227,7 @@ void search_impl(raft::resources const& handle, refined_distances_dev.data(), refined_indices_dev.data(), n_queries, - manage_local_topk ? k * grid_dim_x : max_samples, + manage_local_topk ? (k * grid_dim_x) : max_samples, k, distances, neighbors, @@ -236,7 +236,7 @@ void search_impl(raft::resources const& handle, if (!manage_local_topk) { // post process distances && neighbor IDs ivf::detail::postprocess_distances( - distances, distances, index.metric(), n_queries, k, 1.0, stream); + distances, distances, index.metric(), n_queries, k, 0, stream); ivf::detail::postprocess_neighbors(neighbors, neighbors, index.inds_ptrs().data_handle(), @@ -276,7 +276,8 @@ inline void search(raft::resources const& handle, uint32_t max_samples = 0; if (!manage_local_topk) { - IdxT ms = Pow2<128 / sizeof(float)>::roundUp(index.accum_sorted_sizes()(n_probes)); + IdxT ms = + Pow2<128 / sizeof(float)>::roundUp(std::max(index.accum_sorted_sizes()(n_probes), k)); RAFT_EXPECTS(ms <= IdxT(std::numeric_limits::max()), "The maximum sample size is too big."); max_samples = ms; @@ -286,14 +287,9 @@ inline void search(raft::resources const& handle, constexpr uint64_t kExpectedWsSize = 1024 * 1024 * 1024; uint64_t max_ws_size = std::min(resource::get_workspace_free_bytes(handle), kExpectedWsSize); - uint64_t ws_size_per_query = - // fixed - 4ull * (2 * n_probes + index.n_lists() + index.dim() + 1) - // fused - + manage_local_topk - ? ((sizeof(IdxT) + 4) * n_probes * k) - // non-fused - : (4ull * (max_samples + n_probes + 1)); + uint64_t ws_size_per_query = 4ull * (2 * n_probes + index.n_lists() + index.dim() + 1) + + (manage_local_topk ? ((sizeof(IdxT) + 4) * n_probes * k) + : (4ull * (max_samples + n_probes + 1))); const uint32_t max_queries = std::min(n_queries, raft::div_rounding_up_safe(max_ws_size, ws_size_per_query)); diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 7cde4b1566..26be743eec 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -608,6 +608,23 @@ const std::vector> inputs = { {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false}, + // various combinations with k>raft::matrix::detail::select::warpsort::kMaxCapacity + {1000, 10000, 16, 1024, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 10000, 2053, 512, 50, 1024, raft::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 2049, 2048, 70, 1024, raft::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 16, 4000, 100, 2048, raft::distance::DistanceType::L2SqrtExpanded, false}, + {10, 10000, 16, 4000, 100, 2048, raft::distance::DistanceType::L2SqrtExpanded, false}, + {10, 10000, 16, 4000, 120, 2048, raft::distance::DistanceType::L2SqrtExpanded, true}, + {20, 100000, 16, 257, 20, 1024, raft::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 100000, 16, 259, 20, 1024, raft::distance::DistanceType::L2Expanded, true, true}, + {10000, 131072, 8, 280, 20, 1024, raft::distance::DistanceType::InnerProduct, false}, + {100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::L2Expanded, false}, + {100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::L2SqrtExpanded, false}, + {100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::InnerProduct, false}, + {100000, 1024, 16, 300, 20, 60, raft::distance::DistanceType::L2Expanded, false}, + {100000, 1024, 16, 500, 20, 60, raft::distance::DistanceType::L2SqrtExpanded, false}, + {100000, 1024, 16, 700, 20, 60, raft::distance::DistanceType::InnerProduct, false}, + // host input data {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false, true}, {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded, false, true}, diff --git a/python/pylibraft/pylibraft/test/test_ivf_flat.py b/python/pylibraft/pylibraft/test/test_ivf_flat.py index 23140073f1..2e38dab7bc 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_flat.py +++ b/python/pylibraft/pylibraft/test/test_ivf_flat.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -297,6 +297,14 @@ def test_ivf_flat_params(params): "k": 129, "n_probes": 100, }, + { + "k": 257, + "n_probes": 100, + }, + { + "k": 4096, + "n_probes": 100, + }, ], ) def test_ivf_pq_search_params(params):