Skip to content

Commit

Permalink
Add CAGRA-Q subspace dim = 4 support (#2244)
Browse files Browse the repository at this point in the history
This PR adds support for subspace dimension (pq_len) = 4 in CAGRA-Q.

Authors:
  - tsuki (https://github.com/enp1s0)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2244
  • Loading branch information
enp1s0 authored Apr 3, 2024
1 parent 6c7794f commit eabe3b0
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 17 deletions.
3 changes: 2 additions & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ void launch_vpq_search_main_core(
CagraSampleFilterT sample_filter)
{
RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2, "Only pq_len 2 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4,
"Only pq_len 2 or 4 is supported for now");
RAFT_EXPECTS(vpq_dset->dim() % vpq_dset->pq_dim() == 0,
"dim must be a multiple of pq_dim at the moment");

Expand Down
29 changes: 16 additions & 13 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
using CODE_BOOK_T = CODE_BOOK_T_;
using QUERY_T = typename dataset_descriptor_base_t<half, DISTANCE_T, INDEX_T>::QUERY_T;

static_assert(std::is_same_v<CODE_BOOK_T, half>, "Only CODE_BOOK_T = `half` is supported now");

const std::uint8_t* encoded_dataset_ptr;
const std::uint32_t encoded_dataset_dim;
const std::uint32_t n_subspace;
Expand All @@ -53,18 +55,19 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
smem_pq_code_book_ptr = reinterpret_cast<CODE_BOOK_T*>(smem_ptr);

// Copy PQ table
if constexpr (std::is_same<CODE_BOOK_T, half>::value) {
for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) {
half2 buf2;
buf2.x = pq_code_book_ptr[i];
buf2.y = pq_code_book_ptr[i + 1];
(reinterpret_cast<half2*>(smem_pq_code_book_ptr + i))[0] = buf2;
}
} else {
for (unsigned i = threadIdx.x; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x) {
// TODO: vectorize
smem_pq_code_book_ptr[i] = pq_code_book_ptr[i];
}
for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) {
half2 buf2;
buf2.x = pq_code_book_ptr[i];
buf2.y = pq_code_book_ptr[i + 1];

// Change the order of PQ code book array to reduce the
// frequency of bank conflicts.
constexpr auto num_elements_per_bank = 4 / utils::size_of<CODE_BOOK_T>();
constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank;
const auto j = i / num_elements_per_bank;
const auto smem_index =
(j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS);
reinterpret_cast<half2*>(smem_pq_code_book_ptr)[smem_index] = buf2;
}
}

Expand Down Expand Up @@ -136,7 +139,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
4 + k));
}
//
if constexpr ((std::is_same<CODE_BOOK_T, half>::value) && (PQ_LEN % 2 == 0)) {
if constexpr (PQ_LEN % 2 == 0) {
// **** Use half2 for distance computation ****
half2 norm2{0, 0};
#pragma unroll
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/vpq_dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& da
vpq_params r = params;
double n_rows = dataset.extent(0);
size_t dim = dataset.extent(1);
if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{2}); }
if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); }
if (r.pq_bits == 0) { r.pq_bits = 8; }
if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe<uint32_t>(std::sqrt(n_rows), 8); }
if (r.vq_kmeans_trainset_fraction == 0) {
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/neighbors/ann_cagra_vpq.cuh
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class AnnCagraVpqTest : public ::testing::TestWithParam<AnnCagraVpqInputs> {
resource::sync_stream(handle_);
}

const auto vpq_k = ps.k * 16;
const auto vpq_k = ps.k * 4;
{
rmm::device_uvector<DistanceT> distances_dev(vpq_k * ps.n_queries, stream_);
rmm::device_uvector<IdxT> indices_dev(vpq_k * ps.n_queries, stream_);
Expand Down Expand Up @@ -319,7 +319,7 @@ const std::vector<AnnCagraVpqInputs> vpq_inputs = raft::util::itertools::product
{1000, 10000}, // n_rows
{128, 132, 192, 256, 512, 768}, // dim
{8, 12}, // k
{2}, // pq_len
{2, 4}, // pq_len
{8}, // pq_bits
{graph_build_algo::NN_DESCENT}, // build_algo
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA}, // algo
Expand Down

0 comments on commit eabe3b0

Please sign in to comment.