Skip to content

Commit

Permalink
tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Apr 11, 2024
1 parent 890372b commit 810ddd1
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 54 deletions.
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void build_knn_graph(raft::resources const& res,
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
std::optional<float> refine_rate = std::nullopt,
std::optional<float> refine_rate = std::nullopt,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
Expand Down
30 changes: 14 additions & 16 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -308,22 +308,20 @@ void search_main(raft::resources const& res,

static_assert(std::is_same_v<DistanceT, float>,
"only float distances are supported at the moment");
if (index.metric() != distance::InnerProduct) {
float* dist_out = distances.data_handle();
const DistanceT* dist_in = distances.data_handle();
// We're converting the data from T to DistanceT during distance computation
// and divide the values by kDivisor. Here we restore the original scale.
constexpr float kScale = spatial::knn::detail::utils::config<T>::kDivisor /
spatial::knn::detail::utils::config<DistanceT>::kDivisor;
ivf::detail::postprocess_distances(dist_out,
dist_in,
index.metric(),
distances.extent(0),
distances.extent(1),
kScale,
true,
resource::get_cuda_stream(res));
}
float* dist_out = distances.data_handle();
const DistanceT* dist_in = distances.data_handle();
// We're converting the data from T to DistanceT during distance computation
// and divide the values by kDivisor. Here we restore the original scale.
constexpr float kScale = spatial::knn::detail::utils::config<T>::kDivisor /
spatial::knn::detail::utils::config<DistanceT>::kDivisor;
ivf::detail::postprocess_distances(dist_out,
dist_in,
index.metric(),
distances.extent(0),
distances.extent(1),
kScale,
true,
resource::get_cuda_stream(res));
}
/** @} */ // end group cagra

Expand Down
13 changes: 0 additions & 13 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@
namespace raft::neighbors::cagra::detail {
namespace multi_cta_search {

template <typename T>
RAFT_KERNEL negate_kernel(T* dists, uint32_t n)
{
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) { dists[tid] *= -1; }
}

template <unsigned TEAM_SIZE,
unsigned DATASET_BLOCK_DIM,
typename DATASET_DESCRIPTOR_T,
Expand Down Expand Up @@ -268,12 +261,6 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
true,
NULL,
stream);

if (this->metric == raft::distance::DistanceType::InnerProduct) {
dim3 threads(1024, 1, 1);
dim3 blocks((num_queries * topk + 1023) / 1024, 1, 1);
negate_kernel<<<blocks, threads, 0, stream>>>(topk_distances_ptr, num_queries * topk);
}
}
};

Expand Down
16 changes: 6 additions & 10 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -542,14 +542,13 @@ RAFT_KERNEL batched_memcpy_kernel(T* const dst, // [batch_size, ld_dst]
const T* const src, // [batch_size, ld_src]
const uint64_t ld_src,
const uint64_t count,
const uint64_t batch_size,
bool invert)
const uint64_t batch_size)
{
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= count * batch_size) { return; }
const auto i = tid % count;
const auto j = tid / count;
dst[i + (ld_dst * j)] = (-2 * invert + 1) * src[i + (ld_src * j)];
dst[i + (ld_dst * j)] = src[i + (ld_src * j)];
}

template <class T>
Expand All @@ -559,15 +558,14 @@ void batched_memcpy(T* const dst, // [batch_size, ld_dst]
const uint64_t ld_src,
const uint64_t count,
const uint64_t batch_size,
cudaStream_t cuda_stream,
bool invert = false)
cudaStream_t cuda_stream)
{
assert(ld_dst >= count);
assert(ld_src >= count);
constexpr uint32_t block_size = 256;
const auto grid_size = (batch_size * count + block_size - 1) / block_size;
batched_memcpy_kernel<T><<<grid_size, block_size, 0, cuda_stream>>>(
dst, ld_dst, src, ld_src, count, batch_size, invert);
batched_memcpy_kernel<T>
<<<grid_size, block_size, 0, cuda_stream>>>(dst, ld_dst, src, ld_src, count, batch_size);
}

template <class T>
Expand Down Expand Up @@ -974,15 +972,13 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
stream);
if (topk_distances_ptr) {
bool invert = this->metric == distance::DistanceType::InnerProduct;
batched_memcpy(topk_distances_ptr,
topk,
result_distances_ptr,
result_buffer_allocation_size,
topk,
num_queries,
stream,
invert);
stream);
}

if (num_executed_iterations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -781,13 +781,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
unsigned ii = i;
if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); }
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();
if (result_distances_ptr != nullptr) {
if (metric == distance::InnerProduct && result_indices_buffer[ii] != invalid_index) {
result_distances_ptr[j] = -result_distances_buffer[ii];
} else {
result_distances_ptr[j] = result_distances_buffer[ii];
}
}
if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; }
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;

result_indices_ptr[j] =
Expand Down
1 change: 0 additions & 1 deletion cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
#pragma once

#include "raft/util/cudart_utils.hpp"
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation

#include "../test_utils.cuh"
Expand Down
3 changes: 1 addition & 2 deletions cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter)

INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32,
::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs));

} // namespace raft::neighbors::cagra
3 changes: 1 addition & 2 deletions cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ TEST_P(AnnCagraFilterTestH_U32, AnnCagraFilter)

INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestH_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestH_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestH_U32,
::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestH_U32, ::testing::ValuesIn(inputs));

} // namespace raft::neighbors::cagra
3 changes: 1 addition & 2 deletions cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter)

INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestI8_U32,
::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestI8_U32, ::testing::ValuesIn(inputs));

} // namespace raft::neighbors::cagra

0 comments on commit 810ddd1

Please sign in to comment.