Skip to content

Commit

Permalink
rename recall to neighborhood_recall
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Sep 27, 2023
1 parent c228a36 commit 65d41cd
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ template <typename IndicesValueType,
typename DistanceValueType,
typename IndexType,
typename ScalarType>
__global__ void recall(
__global__ void neighborhood_recall(
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> indices,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> ref_indices,
std::optional<raft::device_matrix_view<const DistanceValueType, IndexType, raft::row_major>>
Expand All @@ -50,7 +50,7 @@ __global__ void recall(
IndexType const row_idx = blockIdx.x;
auto const lane_idx = threadIdx.x % 32;

// Each warp stores a recall score computed across the columns per lane
// Each warp stores a recall score computed across the columns per row
IndexType thread_recall_score = 0;
for (IndexType col_idx = lane_idx; col_idx < indices.extent(1); col_idx += 32) {
for (IndexType ref_col_idx = 0; ref_col_idx < ref_indices.extent(1); ref_col_idx++) {
Expand Down Expand Up @@ -92,7 +92,7 @@ template <typename IndicesValueType,
typename DistanceValueType,
typename IndexType,
typename ScalarType>
void recall(
void neighborhood_recall(
raft::resources const& res,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> indices,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> ref_indices,
Expand All @@ -107,7 +107,7 @@ void recall(
auto constexpr kNumThreads = 32;
auto const num_blocks = indices.extent(0);

recall<<<num_blocks, kNumThreads>>>(
neighborhood_recall<<<num_blocks, kNumThreads>>>(
indices, ref_indices, distances, ref_distances, recall_score, eps);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#pragma once

#include "detail/recall.cuh"
#include "detail/neighborhood_recall.cuh"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
Expand All @@ -32,16 +32,16 @@
namespace raft::stats {

/**
* @defgroup stats_recall Recall Score
* @defgroup stats_neighborhood_recall Neighborhood Recall Score
* @{
*/

/**
* @brief Calculate Recall score on the device for indices, distances computed by any Nearest
* Neighbors Algorithm against reference indices, distances. Recall score is calculated by comparing
* the total number of matching indices and dividing that value by the total size of the indices
* matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices could
* be considered a match if abs(dist, ref_dist) < eps.
* @brief Calculate Neighborhood Recall score on the device for indices, distances computed by any
* Nearest Neighbors Algorithm against reference indices, distances. Recall score is calculated by
* comparing the total number of matching indices and dividing that value by the total size of the
* indices matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices
* could be considered a match if abs(dist, ref_dist) < eps.
*
* @tparam IndicesValueType data-type of the indices
* @tparam IndexType data-type to index all matrices
Expand All @@ -59,7 +59,7 @@ template <typename IndicesValueType,
typename IndexType,
typename ScalarType,
typename DistanceValueType = float>
void recall(
void neighborhood_recall(
raft::resources const& res,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> indices,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> ref_indices,
Expand Down Expand Up @@ -93,15 +93,16 @@ void recall(
DistanceValueType eps_val = 0.001;
if (eps.has_value()) { eps_val = *eps.value().data_handle(); }

detail::recall(res, indices, ref_indices, distances, ref_distances, recall_score, eps_val);
detail::neighborhood_recall(
res, indices, ref_indices, distances, ref_distances, recall_score, eps_val);
}

/**
* @brief Calculate Recall score on the host for indices, distances computed by any Nearest
* Neighbors Algorithm against reference indices, distances. Recall score is calculated by comparing
* the total number of matching indices and dividing that value by the total size of the indices
* matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices could
* be considered a match if abs(dist, ref_dist) < eps.
* @brief Calculate Neighborhood Recall score on the host for indices, distances computed by any
* Nearest Neighbors Algorithm against reference indices, distances. Recall score is calculated by
* comparing the total number of matching indices and dividing that value by the total size of the
* indices matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices
* could be considered a match if abs(dist, ref_dist) < eps.
*
* @tparam IndicesValueType data-type of the indices
* @tparam IndexType data-type to index all matrices
Expand All @@ -119,7 +120,7 @@ template <typename IndicesValueType,
typename IndexType,
typename ScalarType,
typename DistanceValueType = float>
void recall(
void neighborhood_recall(
raft::resources const& res,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> indices,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> ref_indices,
Expand All @@ -131,7 +132,8 @@ void recall(
std::optional<raft::host_scalar_view<const DistanceValueType>> eps = std::nullopt)
{
auto recall_score_d = raft::make_device_scalar(res, *recall_score.data_handle());
recall(res, indices, ref_indices, recall_score_d.view(), distances, ref_distances, eps);
neighborhood_recall(
res, indices, ref_indices, recall_score_d.view(), distances, ref_distances, eps);
raft::update_host(recall_score.data_handle(),
recall_score_d.data_handle(),
1,
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,9 @@ if(BUILD_TESTS)
test/stats/mean_center.cu
test/stats/minmax.cu
test/stats/mutual_info_score.cu
test/stats/neighborhood_recall.cu
test/stats/r2_score.cu
test/stats/rand_index.cu
test/stats/recall.cu
test/stats/regression_metrics.cu
test/stats/silhouette_score.cu
test/stats/stddev.cu
Expand Down
50 changes: 26 additions & 24 deletions cpp/test/stats/recall.cu → cpp/test/stats/neighborhood_recall.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,24 @@

#include <raft_internal/neighbors/naive_knn.cuh>

#include <raft/stats/recall.cuh>
#include <raft/stats/neighborhood_recall.cuh>
#include <raft/util/itertools.hpp>

#include <gtest/gtest.h>

namespace raft::stats {

struct RecallInputs {
struct NeighborhoodRecallInputs {
int n_rows;
int n_cols;
int k;
};

template <typename DistanceT, typename IdxT>
class RecallTest : public ::testing::TestWithParam<RecallInputs> {
class NeighborhoodRecallTest : public ::testing::TestWithParam<NeighborhoodRecallInputs> {
public:
RecallTest()
: ps{::testing::TestWithParam<RecallInputs>::GetParam()},
NeighborhoodRecallTest()
: ps{::testing::TestWithParam<NeighborhoodRecallInputs>::GetParam()},
data_1{raft::make_device_matrix<DistanceT, IdxT>(res, ps.n_rows, ps.n_cols)},
data_2{raft::make_device_matrix<DistanceT, IdxT>(res, ps.n_rows, ps.n_cols)}
{
Expand Down Expand Up @@ -113,23 +113,23 @@ class RecallTest : public ::testing::TestWithParam<RecallInputs> {
// find GPU recall scores
auto s1 = 0;
auto indices_only_recall_scalar = raft::make_host_scalar<double>(s1);
recall(res,
raft::make_const_mdspan(indices_1.view()),
raft::make_const_mdspan(indices_2.view()),
indices_only_recall_scalar.view());
neighborhood_recall(res,
raft::make_const_mdspan(indices_1.view()),
raft::make_const_mdspan(indices_2.view()),
indices_only_recall_scalar.view());

auto s2 = 0;
auto recall_scalar = raft::make_host_scalar<double>(s2);
DistanceT s3 = 0.001;
auto eps_mda = raft::make_host_scalar<DistanceT>(s3);

recall<IdxT, IdxT, double, DistanceT>(res,
raft::make_const_mdspan(indices_1.view()),
raft::make_const_mdspan(indices_2.view()),
recall_scalar.view(),
raft::make_const_mdspan(distances_1.view()),
raft::make_const_mdspan(distances_2.view()),
raft::make_const_mdspan(eps_mda.view()));
neighborhood_recall<IdxT, IdxT, double, DistanceT>(res,
raft::make_const_mdspan(indices_1.view()),
raft::make_const_mdspan(indices_2.view()),
recall_scalar.view(),
raft::make_const_mdspan(distances_1.view()),
raft::make_const_mdspan(distances_2.view()),
raft::make_const_mdspan(eps_mda.view()));

// assert correctness
ASSERT_TRUE(raft::match(indices_only_recall_h,
Expand Down Expand Up @@ -159,19 +159,21 @@ class RecallTest : public ::testing::TestWithParam<RecallInputs> {

private:
raft::resources res;
RecallInputs ps;
NeighborhoodRecallInputs ps;
raft::device_matrix<DistanceT, IdxT> data_1;
raft::device_matrix<DistanceT, IdxT> data_2;
};

const std::vector<RecallInputs> inputs =
raft::util::itertools::product<RecallInputs>({10, 50, 100}, // n_rows
{80, 100}, // dim
{32, 64});
const std::vector<NeighborhoodRecallInputs> inputs =
raft::util::itertools::product<NeighborhoodRecallInputs>({10, 50, 100}, // n_rows
{80, 100}, // n_cols
{32, 64}); // k

using RecallTestF_U32 = RecallTest<float, std::uint32_t>;
TEST_P(RecallTestF_U32, AnnCagra) { this->test_recall(); }
using NeighborhoodRecallTestF_U32 = NeighborhoodRecallTest<float, std::uint32_t>;
TEST_P(NeighborhoodRecallTestF_U32, AnnCagra) { this->test_recall(); }

INSTANTIATE_TEST_CASE_P(RecallTest, RecallTestF_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(NeighborhoodRecallTest,
NeighborhoodRecallTestF_U32,
::testing::ValuesIn(inputs));

} // end namespace raft::stats

0 comments on commit 65d41cd

Please sign in to comment.