Skip to content

Commit

Permalink
Fix incorrect addressing using TxN_t
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 19, 2024
1 parent 38a8bf2 commit cda2cb8
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m += 1) {
const uint32_t d = (vlen * m) + (PQ_CODE_BOOK_DIM * k);
if (d >= dim) break;
- vq_vals[m].load(reinterpret_cast<const half2*>(vq_code_book_ptr),
- d + (dim * vq_code));
+ vq_vals[m].load(
+ reinterpret_cast<const half2*>(vq_code_book_ptr + d + (dim * vq_code)), 0);
}
// Compute distance
std::uint32_t pq_code = pq_codes[e];
Expand Down Expand Up @@ -169,17 +169,18 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
const std::uint32_t d = (vlen * m) + (PQ_CODE_BOOK_DIM * k);
if (d >= dim) break;
// Loading 4 x 8/16-bit VQ-values using 32/64-bit load ops (from L2$ or device memory)
- vq_vals[m].load(reinterpret_cast<const half2*>(vq_code_book_ptr),
- d + (dim * vq_code));
+ vq_vals[m].load(
+ reinterpret_cast<const half2*>(vq_code_book_ptr + d + (dim * vq_code)), 0);
}
// Compute distance
std::uint32_t pq_code = pq_codes[e];
#pragma unroll
for (std::uint32_t v = 0; v < vlen; v++) {
if (PQ_CODE_BOOK_DIM * (v + k) >= dim) break;
raft::TxN_t<CODE_BOOK_T, PQ_CODE_BOOK_DIM> pq_vals;
- pq_vals.load(reinterpret_cast<const half2*>(smem_pq_code_book_ptr),
- (PQ_CODE_BOOK_DIM * (pq_code & 0xff))); // (from L1$ or smem)
+ pq_vals.load(reinterpret_cast<const half2*>(smem_pq_code_book_ptr +
+ PQ_CODE_BOOK_DIM * (pq_code & 0xff)),
+ 0); // (from L1$ or smem)
#pragma unroll
for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m++) {
const std::uint32_t d1 = m + (PQ_CODE_BOOK_DIM * v);
Expand Down

0 comments on commit cda2cb8

Please sign in to comment.