Skip to content

Commit

Permalink
Fix Q desc
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Mar 11, 2024
1 parent 2389194 commit 423a1ae
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 185 deletions.
3 changes: 2 additions & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ void search_main(raft::resources const& res,
search_main_core<dataset_desc_t, CagraSampleFilterT>(
res, params, dataset_desc, graph_internal, queries, neighbors, distances, sample_filter);
} else { // CAGRA Q
using dataset_desc_t = cagra_q_dataset_descriptor_t<half, 8, 4, 0, DistanceT, Internal_IdxT, 0>;
using dataset_desc_t =
cagra_q_dataset_descriptor_t<T, half, 8, 4, 0, DistanceT, Internal_IdxT, 0>;
dataset_desc_t dataset_desc(nullptr, 0, 0, nullptr, 0, nullptr, 0, 0, 0);

search_main_core<dataset_desc_t, CagraSampleFilterT>(
Expand Down
40 changes: 21 additions & 19 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ template <unsigned TEAM_SIZE,
_RAFT_DEVICE void compute_distance_to_random_nodes(
INDEX_T* const result_indices_ptr, // [num_pickup]
DISTANCE_T* const result_distances_ptr, // [num_pickup]
const float* const query_buffer,
const typename DATASET_DESCRIPTOR_T::QUERY_T* const query_buffer,
const DATASET_DESCRIPTOR_T& dataset_desc,
const std::size_t num_pickup,
const unsigned num_distilation,
Expand Down Expand Up @@ -142,21 +142,22 @@ template <unsigned TEAM_SIZE,
class DATASET_DESCRIPTOR_T,
class DISTANCE_T,
class INDEX_T>
_RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_indices_ptr,
DISTANCE_T* const result_child_distances_ptr,
// query
const float* const query_buffer,
// [dataset_dim, dataset_size]
const DATASET_DESCRIPTOR_T& dataset_desc,
// [knn_k, dataset_size]
const INDEX_T* const knn_graph,
const std::uint32_t knn_k,
// hashmap
INDEX_T* const visited_hashmap_ptr,
const std::uint32_t hash_bitlen,
const INDEX_T* const parent_indices,
const INDEX_T* const internal_topk_list,
const std::uint32_t search_width)
_RAFT_DEVICE void compute_distance_to_child_nodes(
INDEX_T* const result_child_indices_ptr,
DISTANCE_T* const result_child_distances_ptr,
// query
const typename DATASET_DESCRIPTOR_T::QUERY_T* const query_buffer,
// [dataset_dim, dataset_size]
const DATASET_DESCRIPTOR_T& dataset_desc,
// [knn_k, dataset_size]
const INDEX_T* const knn_graph,
const std::uint32_t knn_k,
// hashmap
INDEX_T* const visited_hashmap_ptr,
const std::uint32_t hash_bitlen,
const INDEX_T* const parent_indices,
const INDEX_T* const internal_topk_list,
const std::uint32_t search_width)
{
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();
Expand Down Expand Up @@ -224,8 +225,9 @@ template <class DATA_T_,
class DISTANCE_T = float>
struct standard_dataset_descriptor_t
: public dataset_descriptor_base_t<float, DISTANCE_T, INDEX_T> {
using LOAD_T = device::LOAD_128BIT_T;
using DATA_T = DATA_T_;
using LOAD_T = device::LOAD_128BIT_T;
using DATA_T = DATA_T_;
using QUERY_T = typename dataset_descriptor_base_t<float, DISTANCE_T, INDEX_T>::QUERY_T;
static const std::uint32_t DATASET_BLOCK_DIM = DATASET_BLOCK_DIM_;
static const std::uint32_t TEAM_SIZE = TEAM_SIZE_;

Expand All @@ -245,7 +247,7 @@ struct standard_dataset_descriptor_t
static const std::uint32_t smem_buffer_size_in_byte = 0;
__device__ void set_smem_ptr(void* const){};

__device__ DISTANCE_T compute_similarity(const float* const query_ptr,
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
{
Expand Down
8 changes: 6 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@ _RAFT_HOST_DEVICE inline uint64_t xorshift64(uint64_t u)
return u * 0x2545F4914F6CDD1DULL;
}

template <class T>
template <class T, unsigned X_MAX = 1024>
_RAFT_DEVICE inline T swizzling(T x)
{
// Address swizzling reduces bank conflicts in shared memory, but increases
// the amount of operation instead.
// return x;
return x ^ (x >> 5); // "x" must be less than 1024
if constexpr (X_MAX <= 1024) {
return (x) ^ ((x) >> 5);
} else {
return (x) ^ (((x) >> 5) & 0x1f);
}
}

} // namespace device
Expand Down
32 changes: 20 additions & 12 deletions cpp/include/raft/neighbors/detail/cagra/q.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
#include "compute_distance.hpp"

namespace raft::neighbors::cagra::detail {
template <class CODE_BOOK_T,
template <class DATA_T_,
class CODE_BOOK_T_,
unsigned PQ_BITS,
unsigned PQ_CODE_BOOK_DIM,
unsigned DATASET_BLOCK_DIM_,
class DISTANCE_T,
class INDEX_T,
unsigned TEAM_SIZE_>
struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DISTANCE_T, INDEX_T> {
using LOAD_T = device::LOAD_128BIT_T;
using DATA_T = CODE_BOOK_T;
using LOAD_T = device::LOAD_128BIT_T;
using DATA_T = DATA_T_;
using CODE_BOOK_T = CODE_BOOK_T_;
using QUERY_T = typename dataset_descriptor_base_t<half, DISTANCE_T, INDEX_T>::QUERY_T;
static const std::uint32_t DATASET_BLOCK_DIM = DATASET_BLOCK_DIM_;
static const std::uint32_t TEAM_SIZE = TEAM_SIZE_;

Expand All @@ -36,27 +39,28 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
const std::uint32_t n_subspace;
const CODE_BOOK_T* const vq_code_book_ptr;
const float vq_scale;
const CODE_BOOK_T* const pq_code_book_ptr;
CODE_BOOK_T* pq_code_book_ptr;
const float pq_scale;
using dataset_descriptor_base_t<half, DISTANCE_T, INDEX_T>::size;
using dataset_descriptor_base_t<half, DISTANCE_T, INDEX_T>::dim;

// Set on device
CODE_BOOK_T* pq_codebook_buffer;
static const std::uint32_t smem_buffer_size_in_byte =
(1 << PQ_BITS) * PQ_CODE_BOOK_DIM * utils::size_of<CODE_BOOK_T>();

__device__ void set_smem_ptr(void* const smem_ptr)
{
pq_code_book_ptr = reinterpret_cast<CODE_BOOK_T*>(smem_ptr);

// Copy PQ table
}

cagra_q_dataset_descriptor_t(const std::uint8_t* const encoded_dataset_ptr,
const std::uint32_t encoded_dataset_dim,
const std::uint32_t n_subspace,
const CODE_BOOK_T* const vq_codebook_ptr,
const CODE_BOOK_T* const vq_code_book_ptr,
const float vq_scale,
const CODE_BOOK_T* const pq_codebook_ptr,
CODE_BOOK_T* const pq_code_book_ptr,
const float pq_scale,
const std::size_t size,
const std::uint32_t dim)
Expand All @@ -71,7 +75,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
{
}

__device__ DISTANCE_T compute_similarity(const half* const query_ptr,
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T node_id,
const bool valid) const
{
Expand All @@ -90,7 +94,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= n_subspace) break;
// Loading 4 x 8-bit PQ-codes using 32-bit load ops (from device memory)
pq_codes[e] = *(static_cast<std::uint32_t*>(
pq_codes[e] = *(reinterpret_cast<const std::uint32_t*>(
encoded_dataset_ptr + (static_cast<std::uint64_t>(encoded_dataset_dim) * node_id) + 4 +
k));
}
Expand Down Expand Up @@ -184,27 +188,31 @@ template <std::uint32_t DATASET_BLOCK_DIM_OUT,
std::uint32_t TEAM_SIZE_OUT,
std::uint32_t DATASET_BLOCK_DIM_IN,
std::uint32_t TEAM_SIZE_IN,
class DATA_T,
class INDEX_T,
class DISTANCE_T,
class CODE_BOOK_T,
unsigned PQ_BITS,
unsigned PQ_CODE_BOOK_DIM>
cagra_q_dataset_descriptor_t<CODE_BOOK_T,
cagra_q_dataset_descriptor_t<DATA_T,
CODE_BOOK_T,
PQ_BITS,
PQ_CODE_BOOK_DIM,
DATASET_BLOCK_DIM_OUT,
DISTANCE_T,
INDEX_T,
TEAM_SIZE_OUT>
set_compute_template_params(cagra_q_dataset_descriptor_t<CODE_BOOK_T,
set_compute_template_params(cagra_q_dataset_descriptor_t<DATA_T,
CODE_BOOK_T,
PQ_BITS,
PQ_CODE_BOOK_DIM,
DATASET_BLOCK_DIM_IN,
DISTANCE_T,
INDEX_T,
TEAM_SIZE_IN>& desc_in)
{
return cagra_q_dataset_descriptor_t<CODE_BOOK_T,
return cagra_q_dataset_descriptor_t<DATA_T,
CODE_BOOK_T,
PQ_BITS,
PQ_CODE_BOOK_DIM,
DATASET_BLOCK_DIM_OUT,
Expand Down
Loading

0 comments on commit 423a1ae

Please sign in to comment.