From cda2cb8414453656d6c528f8028301d982ed25f9 Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 19 Mar 2024 20:33:13 +0100 Subject: [PATCH] Fix incorrect addressing using TxN_t --- .../neighbors/detail/cagra/compute_distance_vpq.cuh | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh index aef71f04f7..526c4fcdd9 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh @@ -131,8 +131,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t= dim) break; - vq_vals[m].load(reinterpret_cast(vq_code_book_ptr), - d + (dim * vq_code)); + vq_vals[m].load( + reinterpret_cast(vq_code_book_ptr + d + (dim * vq_code)), 0); } // Compute distance std::uint32_t pq_code = pq_codes[e]; @@ -169,8 +169,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t= dim) break; // Loading 4 x 8/16-bit VQ-values using 32/64-bit load ops (from L2$ or device memory) - vq_vals[m].load(reinterpret_cast(vq_code_book_ptr), - d + (dim * vq_code)); + vq_vals[m].load( + reinterpret_cast(vq_code_book_ptr + d + (dim * vq_code)), 0); } // Compute distance std::uint32_t pq_code = pq_codes[e]; @@ -178,8 +178,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t= dim) break; raft::TxN_t pq_vals; - pq_vals.load(reinterpret_cast(smem_pq_code_book_ptr), - (PQ_CODE_BOOK_DIM * (pq_code & 0xff))); // (from L1$ or smem) + pq_vals.load(reinterpret_cast(smem_pq_code_book_ptr + + PQ_CODE_BOOK_DIM * (pq_code & 0xff)), + 0); // (from L1$ or smem) #pragma unroll for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m++) { const std::uint32_t d1 = m + (PQ_CODE_BOOK_DIM * v);