diff --git a/cpp/include/raft/stats/detail/recall.cuh b/cpp/include/raft/stats/detail/neighborhood_recall.cuh similarity index 97% rename from cpp/include/raft/stats/detail/recall.cuh rename to cpp/include/raft/stats/detail/neighborhood_recall.cuh index ea43976f11..1233c0177e 100644 --- a/cpp/include/raft/stats/detail/recall.cuh +++ b/cpp/include/raft/stats/detail/neighborhood_recall.cuh @@ -37,7 +37,7 @@ template -__global__ void recall( +__global__ void neighborhood_recall( raft::device_matrix_view indices, raft::device_matrix_view ref_indices, std::optional> @@ -50,7 +50,7 @@ __global__ void recall( IndexType const row_idx = blockIdx.x; auto const lane_idx = threadIdx.x % 32; - // Each warp stores a recall score computed across the columns per lane + // Each warp stores a recall score computed across the columns per row IndexType thread_recall_score = 0; for (IndexType col_idx = lane_idx; col_idx < indices.extent(1); col_idx += 32) { for (IndexType ref_col_idx = 0; ref_col_idx < ref_indices.extent(1); ref_col_idx++) { @@ -92,7 +92,7 @@ template -void recall( +void neighborhood_recall( raft::resources const& res, raft::device_matrix_view indices, raft::device_matrix_view ref_indices, @@ -107,7 +107,7 @@ void recall( auto constexpr kNumThreads = 32; auto const num_blocks = indices.extent(0); - recall<<>>( + neighborhood_recall<<>>( indices, ref_indices, distances, ref_distances, recall_score, eps); } diff --git a/cpp/include/raft/stats/recall.cuh b/cpp/include/raft/stats/neighborhood_recall.cuh similarity index 81% rename from cpp/include/raft/stats/recall.cuh rename to cpp/include/raft/stats/neighborhood_recall.cuh index 67e7d4ac24..f53eaf508c 100644 --- a/cpp/include/raft/stats/recall.cuh +++ b/cpp/include/raft/stats/neighborhood_recall.cuh @@ -16,7 +16,7 @@ #pragma once -#include "detail/recall.cuh" +#include "detail/neighborhood_recall.cuh" #include #include @@ -32,16 +32,16 @@ namespace raft::stats { /** - * @defgroup stats_recall Recall Score + * @defgroup stats_neighborhood_recall Neighborhood Recall Score * @{ */ /** - * @brief Calculate Recall score on the device for indices, distances computed by any Nearest - * Neighbors Algorithm against reference indices, distances. Recall score is calculated by comparing - * the total number of matching indices and dividing that value by the total size of the indices - * matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices could - * be considered a match if abs(dist, ref_dist) < eps. + * @brief Calculate Neighborhood Recall score on the device for indices, distances computed by any + * Nearest Neighbors Algorithm against reference indices, distances. Recall score is calculated by + * comparing the total number of matching indices and dividing that value by the total size of the + * indices matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices + * could be considered a match if abs(dist, ref_dist) < eps. * * @tparam IndicesValueType data-type of the indices * @tparam IndexType data-type to index all matrices @@ -59,7 +59,7 @@ template -void recall( +void neighborhood_recall( raft::resources const& res, raft::device_matrix_view indices, raft::device_matrix_view ref_indices, @@ -93,15 +93,16 @@ void recall( DistanceValueType eps_val = 0.001; if (eps.has_value()) { eps_val = *eps.value().data_handle(); } - detail::recall(res, indices, ref_indices, distances, ref_distances, recall_score, eps_val); + detail::neighborhood_recall( + res, indices, ref_indices, distances, ref_distances, recall_score, eps_val); } /** - * @brief Calculate Recall score on the host for indices, distances computed by any Nearest - * Neighbors Algorithm against reference indices, distances. Recall score is calculated by comparing - * the total number of matching indices and dividing that value by the total size of the indices - * matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices could - * be considered a match if abs(dist, ref_dist) < eps. + * @brief Calculate Neighborhood Recall score on the host for indices, distances computed by any + * Nearest Neighbors Algorithm against reference indices, distances. Recall score is calculated by + * comparing the total number of matching indices and dividing that value by the total size of the + * indices matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices + * could be considered a match if abs(dist, ref_dist) < eps. * * @tparam IndicesValueType data-type of the indices * @tparam IndexType data-type to index all matrices @@ -119,7 +120,7 @@ template -void recall( +void neighborhood_recall( raft::resources const& res, raft::device_matrix_view indices, raft::device_matrix_view ref_indices, @@ -131,7 +132,8 @@ void recall( std::optional> eps = std::nullopt) { auto recall_score_d = raft::make_device_scalar(res, *recall_score.data_handle()); - recall(res, indices, ref_indices, recall_score_d.view(), distances, ref_distances, eps); + neighborhood_recall( + res, indices, ref_indices, recall_score_d.view(), distances, ref_distances, eps); raft::update_host(recall_score.data_handle(), recall_score_d.data_handle(), 1, diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0aa9790031..8487610ff0 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -420,9 +420,9 @@ if(BUILD_TESTS) test/stats/mean_center.cu test/stats/minmax.cu test/stats/mutual_info_score.cu + test/stats/neighborhood_recall.cu test/stats/r2_score.cu test/stats/rand_index.cu - test/stats/recall.cu test/stats/regression_metrics.cu test/stats/silhouette_score.cu test/stats/stddev.cu diff --git a/cpp/test/stats/recall.cu b/cpp/test/stats/neighborhood_recall.cu similarity index 73% rename from cpp/test/stats/recall.cu rename to cpp/test/stats/neighborhood_recall.cu index b766e876a3..b5ab30e6f0 100644 --- a/cpp/test/stats/recall.cu +++ b/cpp/test/stats/neighborhood_recall.cu @@ -24,24 +24,24 @@ #include -#include +#include #include #include namespace raft::stats { -struct RecallInputs { +struct NeighborhoodRecallInputs { int n_rows; int n_cols; int k; }; template -class RecallTest : public ::testing::TestWithParam { +class NeighborhoodRecallTest : public ::testing::TestWithParam { public: - RecallTest() - : ps{::testing::TestWithParam::GetParam()}, + NeighborhoodRecallTest() + : ps{::testing::TestWithParam::GetParam()}, data_1{raft::make_device_matrix(res, ps.n_rows, ps.n_cols)}, data_2{raft::make_device_matrix(res, ps.n_rows, ps.n_cols)} { @@ -113,23 +113,23 @@ class RecallTest : public ::testing::TestWithParam { // find GPU recall scores auto s1 = 0; auto indices_only_recall_scalar = raft::make_host_scalar(s1); - recall(res, - raft::make_const_mdspan(indices_1.view()), - raft::make_const_mdspan(indices_2.view()), - indices_only_recall_scalar.view()); + neighborhood_recall(res, + raft::make_const_mdspan(indices_1.view()), + raft::make_const_mdspan(indices_2.view()), + indices_only_recall_scalar.view()); auto s2 = 0; auto recall_scalar = raft::make_host_scalar(s2); DistanceT s3 = 0.001; auto eps_mda = raft::make_host_scalar(s3); - recall(res, - raft::make_const_mdspan(indices_1.view()), - raft::make_const_mdspan(indices_2.view()), - recall_scalar.view(), - raft::make_const_mdspan(distances_1.view()), - raft::make_const_mdspan(distances_2.view()), - raft::make_const_mdspan(eps_mda.view())); + neighborhood_recall(res, + raft::make_const_mdspan(indices_1.view()), + raft::make_const_mdspan(indices_2.view()), + recall_scalar.view(), + raft::make_const_mdspan(distances_1.view()), + raft::make_const_mdspan(distances_2.view()), + raft::make_const_mdspan(eps_mda.view())); // assert correctness ASSERT_TRUE(raft::match(indices_only_recall_h, @@ -159,19 +159,21 @@ class RecallTest : public ::testing::TestWithParam { private: raft::resources res; - RecallInputs ps; + NeighborhoodRecallInputs ps; raft::device_matrix data_1; raft::device_matrix data_2; }; -const std::vector inputs = - raft::util::itertools::product({10, 50, 100}, // n_rows - {80, 100}, // dim - {32, 64}); +const std::vector inputs = + raft::util::itertools::product({10, 50, 100}, // n_rows + {80, 100}, // n_cols + {32, 64}); // k -using RecallTestF_U32 = RecallTest; -TEST_P(RecallTestF_U32, AnnCagra) { this->test_recall(); } +using NeighborhoodRecallTestF_U32 = NeighborhoodRecallTest; +TEST_P(NeighborhoodRecallTestF_U32, AnnCagra) { this->test_recall(); } -INSTANTIATE_TEST_CASE_P(RecallTest, RecallTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(NeighborhoodRecallTest, + NeighborhoodRecallTestF_U32, + ::testing::ValuesIn(inputs)); } // end namespace raft::stats