Skip to content

Commit

Permalink
Merge branch 'branch-24.06' into rhdong/select_k_csr
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong authored Apr 4, 2024
2 parents 9e8ae31 + 8a68518 commit 5924bbb
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 41 deletions.
3 changes: 2 additions & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ void launch_vpq_search_main_core(
CagraSampleFilterT sample_filter)
{
RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2, "Only pq_len 2 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4,
"Only pq_len 2 or 4 is supported for now");
RAFT_EXPECTS(vpq_dset->dim() % vpq_dset->pq_dim() == 0,
"dim must be a multiple of pq_dim at the moment");

Expand Down
29 changes: 16 additions & 13 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
using CODE_BOOK_T = CODE_BOOK_T_;
using QUERY_T = typename dataset_descriptor_base_t<half, DISTANCE_T, INDEX_T>::QUERY_T;

static_assert(std::is_same_v<CODE_BOOK_T, half>, "Only CODE_BOOK_T = `half` is supported now");

const std::uint8_t* encoded_dataset_ptr;
const std::uint32_t encoded_dataset_dim;
const std::uint32_t n_subspace;
Expand All @@ -53,18 +55,19 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
smem_pq_code_book_ptr = reinterpret_cast<CODE_BOOK_T*>(smem_ptr);

// Copy PQ table
if constexpr (std::is_same<CODE_BOOK_T, half>::value) {
for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) {
half2 buf2;
buf2.x = pq_code_book_ptr[i];
buf2.y = pq_code_book_ptr[i + 1];
(reinterpret_cast<half2*>(smem_pq_code_book_ptr + i))[0] = buf2;
}
} else {
for (unsigned i = threadIdx.x; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x) {
// TODO: vectorize
smem_pq_code_book_ptr[i] = pq_code_book_ptr[i];
}
for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) {
half2 buf2;
buf2.x = pq_code_book_ptr[i];
buf2.y = pq_code_book_ptr[i + 1];

// Change the order of PQ code book array to reduce the
// frequency of bank conflicts.
constexpr auto num_elements_per_bank = 4 / utils::size_of<CODE_BOOK_T>();
constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank;
const auto j = i / num_elements_per_bank;
const auto smem_index =
(j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS);
reinterpret_cast<half2*>(smem_pq_code_book_ptr)[smem_index] = buf2;
}
}

Expand Down Expand Up @@ -136,7 +139,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
4 + k));
}
//
if constexpr ((std::is_same<CODE_BOOK_T, half>::value) && (PQ_LEN % 2 == 0)) {
if constexpr (PQ_LEN % 2 == 0) {
// **** Use half2 for distance computation ****
half2 norm2{0, 0};
#pragma unroll
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/vpq_dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& da
vpq_params r = params;
double n_rows = dataset.extent(0);
size_t dim = dataset.extent(1);
if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{2}); }
if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); }
if (r.pq_bits == 0) { r.pq_bits = 8; }
if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe<uint32_t>(std::sqrt(n_rows), 8); }
if (r.vq_kmeans_trainset_fraction == 0) {
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/neighbors/ann_cagra_vpq.cuh
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class AnnCagraVpqTest : public ::testing::TestWithParam<AnnCagraVpqInputs> {
resource::sync_stream(handle_);
}

const auto vpq_k = ps.k * 16;
const auto vpq_k = ps.k * 4;
{
rmm::device_uvector<DistanceT> distances_dev(vpq_k * ps.n_queries, stream_);
rmm::device_uvector<IdxT> indices_dev(vpq_k * ps.n_queries, stream_);
Expand Down Expand Up @@ -319,7 +319,7 @@ const std::vector<AnnCagraVpqInputs> vpq_inputs = raft::util::itertools::product
{1000, 10000}, // n_rows
{128, 132, 192, 256, 512, 768}, // dim
{8, 12}, // k
{2}, // pq_len
{2, 4}, // pq_len
{8}, // pq_bits
{graph_build_algo::NN_DESCENT}, // build_algo
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA}, // algo
Expand Down
52 changes: 28 additions & 24 deletions notebooks/VectorSearch_QuestionRetrieval.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "eb1e81c3",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -154,7 +154,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "ee4c5cc0",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -184,7 +184,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "0a1a6307",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -249,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "ad90b4be",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -292,7 +292,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "724dcacb",
"metadata": {
"scrolled": true
Expand Down Expand Up @@ -320,7 +320,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "c27d4715",
"metadata": {},
"outputs": [
Expand All @@ -347,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "bc375518",
"metadata": {},
"outputs": [
Expand All @@ -373,7 +373,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "ab154181",
"metadata": {},
"outputs": [
Expand All @@ -399,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "2d6017ed",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -435,7 +435,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "f5cfb644",
"metadata": {},
"outputs": [
Expand All @@ -462,7 +462,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "b5694d00",
"metadata": {},
"outputs": [
Expand All @@ -489,7 +489,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "fcfc3c5b",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -528,7 +528,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "50df1f43-c580-4019-949a-06bdc7185536",
"metadata": {},
"outputs": [],
Expand All @@ -538,29 +538,29 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "091cde52-4652-4230-af2b-75c35357f833",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min 23s, sys: 2min 7s, total: 3min 31s\n",
"Wall time: 4min 43s\n"
"CPU times: user 35.3 s, sys: 4.5 s, total: 39.8 s\n",
"Wall time: 2.16 s\n"
]
}
],
"source": [
"%%time\n",
"params = cagra.IndexParams(intermediate_graph_degree=128, graph_degree=64)\n",
"params = cagra.IndexParams(intermediate_graph_degree=32, graph_degree=16, build_algo=\"nn_descent\")\n",
"cagra_index = cagra.build(params, corpus_embeddings)\n",
"search_params = cagra.SearchParams()"
"search_params = cagra.SearchParams(algo=\"multi_cta\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "df229e21-f6b6-4d6c-ad54-2724f8738934",
"metadata": {},
"outputs": [],
Expand All @@ -569,9 +569,12 @@
" # Encode the query using the bi-encoder and find potentially relevant passages\n",
" question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n",
"\n",
" start_time = time.time()\n",
" hits = cagra.search(search_params, cagra_index, question_embedding[None], top_k)\n",
" end_time = time.time()\n",
"\n",
" # Output of top-k hits\n",
" print(\"Results (after {:.3f} seconds):\".format(end_time - start_time))\n",
" print(\"Input question:\", query)\n",
" for k in range(top_k):\n",
" print(\"\\t{:.3f}\\t{}\".format(hits[0][0, k], passages[hits[1][0, k]]))"
Expand All @@ -587,19 +590,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 16 µs, sys: 25 µs, total: 41 µs\n",
"Wall time: 83.7 µs\n",
"Results (after 0.005 seconds):\n",
"Input question: Who was Grace Hopper?\n",
"\t181.649\t['Grace Hopper', 'Hopper was born in New York, USA. Hopper graduated from Vassar College in 1928 and Yale University in 1934 with a Ph.D degree in mathematics. She joined the US Navy during the World War II in 1943. She worked on computers in the Navy for 43 years. She then worked in other private industry companies after 1949. She retired from the Navy in 1986 and died on January 1, 1992.']\n",
"\t192.946\t['Leona Helmsley', 'Leona Helmsley (July 4, 1920 – August 20, 2007) was an American businesswoman. She was known for having a flamboyant personality. She had a reputation for tyrannical behavior; she was nicknamed the Queen of Mean.']\n",
"\t194.951\t['Grace Hopper', 'Grace Murray Hopper (December 9 1906 – January 1 1992) was an American computer scientist and United States Navy officer.']\n",
"\t202.192\t['Nellie Bly', 'Elizabeth Cochrane Seaman (born Elizabeth Jane Cochran; May 5, 1864 – January 27, 1922), better known by her pen name Nellie Bly, was an American journalist, novelist and inventor. She was a newspaper reporter, who worked at various jobs for exposing poor working conditions. Nellie Bly, also, fought for women\\'s right and was known for investigative reporting. She best known for her record-breaking trip around the world in 72 days, inspired by the adventure novel \"Around the World in Eighty Days\" by Jules Verne. In the 1880s, she went undercover as a mentally ill patient in a psychiatric hospital for ten days, with the report being made public in a book called \"\"Ten Days in a Mad-House\"\". She was added to the National Women\\'s Hall of Fame in 1998.']\n",
"\t205.038\t['Abbie Hoffman', 'Abbot Howard \"Abbie\" Hoffman (November 30, 1936 – April 12, 1989) was an American social and political activist.']\n"
"\t205.038\t['Abbie Hoffman', 'Abbot Howard \"Abbie\" Hoffman (November 30, 1936 – April 12, 1989) was an American social and political activist.']\n",
"CPU times: user 4.18 ms, sys: 3.88 ms, total: 8.07 ms\n",
"Wall time: 9.97 ms\n"
]
}
],
"source": [
"%time \n",
"%%time \n",
"search_raft_cagra(query=\"Who was Grace Hopper?\")"
]
}
Expand All @@ -620,7 +624,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 5924bbb

Please sign in to comment.