Skip to content

Commit

Permalink
Use TxN_t
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Mar 19, 2024
1 parent 3ff9382 commit 38a8bf2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 57 deletions.
45 changes: 4 additions & 41 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "utils.hpp"

#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/vectorized.cuh>

#include <type_traits>

Expand All @@ -36,44 +37,6 @@ _RAFT_DEVICE constexpr unsigned get_vlen()
return utils::size_of<LOAD_T>() / utils::size_of<DATA_T>();
}

template <std::uint32_t VECLEN>
struct code_book_load_t_core {
using type = void;
};
template <>
struct code_book_load_t_core<1> {
using type = std::uint8_t;
};
template <>
struct code_book_load_t_core<2> {
using type = std::uint16_t;
};
template <>
struct code_book_load_t_core<4> {
using type = std::uint32_t;
};
template <>
struct code_book_load_t_core<8> {
using type = LOAD_64BIT_T;
};
template <>
struct code_book_load_t_core<16> {
using type = LOAD_128BIT_T;
};

template <class T, std::uint32_t vlen>
struct code_book_load_t {
using type = typename code_book_load_t_core<utils::size_of<T>() * vlen>::type;
};

template <class DATA_T, unsigned VLEN>
struct data_load_t {
union {
typename code_book_load_t<DATA_T, VLEN>::type load;
DATA_T data[VLEN];
};
};

template <unsigned TEAM_SIZE,
unsigned DATASET_BLOCK_DIM,
class DATASET_DESCRIPTOR_T,
Expand Down Expand Up @@ -256,7 +219,7 @@ struct standard_dataset_descriptor_t
constexpr unsigned vlen = device::get_vlen<LOAD_T, DATA_T>();
// #include <raft/util/cuda_dev_essentials.cuh
constexpr unsigned reg_nelem = raft::ceildiv<unsigned>(DATASET_BLOCK_DIM, TEAM_SIZE * vlen);
device::data_load_t<DATA_T, vlen> dl_buff[reg_nelem];
raft::TxN_t<DATA_T, vlen> dl_buff[reg_nelem];

DISTANCE_T norm2 = 0;
if (valid) {
Expand All @@ -265,7 +228,7 @@ struct standard_dataset_descriptor_t
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
if (k >= dim) break;
dl_buff[e].load = *reinterpret_cast<const LOAD_T*>(dataset_ptr + k);
dl_buff[e].load(dataset_ptr, k);
}
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
Expand All @@ -279,7 +242,7 @@ struct standard_dataset_descriptor_t
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T diff = query_ptr[device::swizzling(kv)];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]);
norm2 += diff * diff;
}
}
Expand Down
27 changes: 11 additions & 16 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,13 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= n_subspace) break;
// Loading VQ code-book
device::data_load_t<half2, vlen / 2> vq_vals[PQ_CODE_BOOK_DIM];
using vq_vals_load_t = typename device::code_book_load_t<half2, vlen / 2>::type;
raft::TxN_t<half2, vlen / 2> vq_vals[PQ_CODE_BOOK_DIM];
#pragma unroll
for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m += 1) {
const uint32_t d = (vlen * m) + (PQ_CODE_BOOK_DIM * k);
if (d >= dim) break;
// Loading 4 x 16-bit VQ-values using 64-bit load ops (from L2$ or device memory)
vq_vals[m].load =
*(reinterpret_cast<const vq_vals_load_t*>(vq_code_book_ptr + d + (dim * vq_code)));
vq_vals[m].load(reinterpret_cast<const half2*>(vq_code_book_ptr),
d + (dim * vq_code));
}
// Compute distance
std::uint32_t pq_code = pq_codes[e];
Expand All @@ -151,7 +149,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
// Loading PQ code book in smem
diff2 -= *(reinterpret_cast<half2*>(
smem_pq_code_book_ptr + (1 << PQ_BITS) * 2 * (m / 2) + (2 * (pq_code & 0xff))));
diff2 -= vq_vals[d1 / vlen].data[(d1 % vlen) / 2];
diff2 -= vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2];
norm2 += diff2 * diff2;
}
pq_code >>= 8;
Expand All @@ -165,34 +163,31 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= n_subspace) break;
// Loading VQ code-book
typename device::data_load_t<CODE_BOOK_T, vlen>::type vq_vals[PQ_CODE_BOOK_DIM];
using vq_vals_load_t = typename device::code_book_load_t<CODE_BOOK_T, vlen>::type;
raft::TxN_t<CODE_BOOK_T, vlen> vq_vals[PQ_CODE_BOOK_DIM];
#pragma unroll
for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m++) {
const std::uint32_t d = (vlen * m) + (PQ_CODE_BOOK_DIM * k);
if (d >= dim) break;
// Loading 4 x 8/16-bit VQ-values using 32/64-bit load ops (from L2$ or device memory)
vq_vals[m].load =
*(reinterpret_cast<const vq_vals_load_t*>(vq_code_book_ptr + d + (dim * vq_code)));
vq_vals[m].load(reinterpret_cast<const half2*>(vq_code_book_ptr),
d + (dim * vq_code));
}
// Compute distance
std::uint32_t pq_code = pq_codes[e];
#pragma unroll
for (std::uint32_t v = 0; v < vlen; v++) {
if (PQ_CODE_BOOK_DIM * (v + k) >= dim) break;
typename device::data_load_t<CODE_BOOK_T, PQ_CODE_BOOK_DIM>::type pq_vals;
using pq_vals_load_t = device::code_book_load_t<CODE_BOOK_T, PQ_CODE_BOOK_DIM>;
pq_vals.load = *(reinterpret_cast<const pq_vals_load_t*>(
smem_pq_code_book_ptr +
(PQ_CODE_BOOK_DIM * (pq_code & 0xff)))); // (from L1$ or smem)
raft::TxN_t<CODE_BOOK_T, PQ_CODE_BOOK_DIM> pq_vals;
pq_vals.load(reinterpret_cast<const half2*>(smem_pq_code_book_ptr),
(PQ_CODE_BOOK_DIM * (pq_code & 0xff))); // (from L1$ or smem)
#pragma unroll
for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m++) {
const std::uint32_t d1 = m + (PQ_CODE_BOOK_DIM * v);
const std::uint32_t d = d1 + (PQ_CODE_BOOK_DIM * k);
// if (d >= dataset_dim) break;
DISTANCE_T diff = query_ptr[d]; // (from smem)
diff -= pq_scale * static_cast<float>(pq_vals.data[m]);
diff -= vq_scale * static_cast<float>(vq_vals[d1 / vlen].data[d1 % vlen]);
diff -= vq_scale * static_cast<float>(vq_vals[d1 / vlen].val.data[d1 % vlen]);
norm += diff * diff;
}
pq_code >>= 8;
Expand Down

0 comments on commit 38a8bf2

Please sign in to comment.