Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Feb 15, 2024
1 parent addb485 commit d534487
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "device_common.hpp"
#include "hashmap.hpp"
#include "raft/distance/distance_types.hpp"
#include "utils.hpp"
#include <type_traits>

Expand Down Expand Up @@ -61,7 +62,8 @@ struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, fal

__device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr,
const std::uint32_t dataset_dim,
const bool valid)
const bool valid,
raft::distance::DistanceType metric)
{
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
constexpr unsigned vlen = get_vlen<LOAD_T, DATA_T>();
Expand All @@ -87,8 +89,13 @@ struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, fal
const uint32_t kv = k + v;
// if (kv >= dataset_dim) break;
DISTANCE_T diff = query_buffer[device::swizzling(kv)];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
if (metric == raft::distance::L2Expanded) {
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
} else {
diff *= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff;
}
}
}
}
Expand Down Expand Up @@ -130,7 +137,8 @@ struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, tru

__device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr,
const std::uint32_t dataset_dim,
const bool valid)
const bool valid,
raft::distance::DistanceType metric)
{
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
constexpr unsigned vlen = get_vlen<LOAD_T, DATA_T>();
Expand All @@ -155,8 +163,13 @@ struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, tru
DISTANCE_T diff;
const unsigned ev = (vlen * e) + v;
diff = query_frags[ev];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
if (metric == raft::distance::L2Expanded) {
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
} else {
diff *= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff;
}
}
}
}
Expand Down

0 comments on commit d534487

Please sign in to comment.