Skip to content

Commit

Permalink
add recall score
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Sep 27, 2023
1 parent d17376a commit 63400a4
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 0 deletions.
108 changes: 108 additions & 0 deletions cpp/include/raft/stats/detail/recall.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <atomic>
#include <cstddef>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resources.hpp>

#include <cub/cub.cuh>

#include <cuda/atomic>

#include <optional>

namespace raft::stats::detail {

template <typename IndicesValueType,
typename DistanceValueType,
typename IndexType,
typename ScalarType>
__global__ void recall(
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> indices,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> ref_indices,
std::optional<raft::device_matrix_view<const DistanceValueType, IndexType, raft::row_major>>
distances,
std::optional<raft::device_matrix_view<const DistanceValueType, IndexType, raft::row_major>>
ref_distances,
raft::device_scalar_view<ScalarType> recall_score,
DistanceValueType const threshold)
{
IndexType const row_idx = blockIdx.x;
auto const lane_idx = threadIdx.x % 32;

// Each warp stores a recall score computed across the columns per lane
IndexType thread_recall_score = 0;
for (IndexType col_idx = lane_idx; col_idx < indices.extent(1); col_idx += 32) {
for (IndexType ref_col_idx = 0; ref_col_idx < ref_indices.extent(1); ref_col_idx++) {
if (indices(row_idx, col_idx) == ref_indices(row_idx, ref_col_idx) or
((distances.has_value()) and
(raft::abs(distances.value()(row_idx, col_idx) -
ref_distances.value()(row_idx, ref_col_idx)) < threshold))) {
thread_recall_score += 1;
break;
}
}
}

// Reduce across a warp for row score
typedef cub::BlockReduce<int, 32> BlockReduce;

__shared__ typename BlockReduce::TempStorage temp_storage;

ScalarType row_recall_score = BlockReduce(temp_storage).Sum(thread_recall_score);

// Reduce across all rows for global score
if (lane_idx == 0) {
cuda::atomic_ref<ScalarType, cuda::thread_scope_device> device_recall_score{
*recall_score.data_handle()};
std::size_t const total_count = indices.extent(0) * indices.extent(1);
device_recall_score.fetch_add(row_recall_score / total_count);
}
}

template <typename IndicesValueType,
typename DistanceValueType,
typename IndexType,
typename ScalarType>
void recall(
raft::resources const& res,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> indices,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> ref_indices,
std::optional<raft::device_matrix_view<const DistanceValueType, IndexType, raft::row_major>>
distances,
std::optional<raft::device_matrix_view<const DistanceValueType, IndexType, raft::row_major>>
ref_distances,
raft::device_scalar_view<ScalarType> recall_score,
DistanceValueType const threshold)
{
// One warp per row, launch a warp-width block per-row kernel
auto constexpr kNumThreads = 32;
auto const num_blocks = indices.extent(0);

std::cout << "total count: " << indices.extent(0) * indices.extent(1);

recall<<<num_blocks, kNumThreads>>>(
indices, ref_indices, distances, ref_distances, recall_score, threshold);
}

} // end namespace raft::stats::detail
144 changes: 144 additions & 0 deletions cpp/include/raft/stats/recall.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "detail/recall.cuh"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

#include <optional>

namespace raft::stats {

/**
* @defgroup stats_recall Recall Score
* @{
*/

/**
* @brief Calculate Recall score on the device for indices, distances computed by any Nearest
* Neighbors Algorithm against reference indices, distances. Recall score is calculated by comparing
* the total number of matching indices and dividing that value by the total size of the indices
* matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices could
* be considered a match if abs(dist, ref_dist) < threshold.
*
* @tparam IndicesValueType data-type of the indices
* @tparam IndexType data-type to index all matrices
* @tparam ScalarType data-type to store recall score
* @tparam DistanceValueType data-type of the distances
* @param res raft::resources object to manage resources
* @param[in] indices raft::device_matrix_view indices of neighbors
* @param[in] ref_indices raft::device_matrix_view reference indices of neighbors
* @param[out] recall_score raft::device_scalar_view output recall score
* @param[in] distances (optional) raft::device_matrix_view distances of neighbors
* @param[in] ref_distances (optional) raft::device_matrix_view reference distances of neighbors
* @param[in] threshold (optional, default = 0.001) value for distance comparison
*/
template <typename IndicesValueType,
typename IndexType,
typename ScalarType,
typename DistanceValueType = float>
void recall(
raft::resources const& res,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> indices,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> ref_indices,
raft::device_scalar_view<ScalarType> recall_score,
std::optional<raft::device_matrix_view<const DistanceValueType, IndexType, raft::row_major>>
distances = std::nullopt,
std::optional<raft::device_matrix_view<const DistanceValueType, IndexType, raft::row_major>>
ref_distances = std::nullopt,
std::optional<raft::host_scalar_view<const DistanceValueType>> threshold = std::nullopt)
{
RAFT_EXPECTS(indices.extent(0) == ref_indices.extent(0),
"The number of rows in indices and reference indices should be equal");
RAFT_EXPECTS(indices.extent(1) == ref_indices.extent(1),
"The number of columns in indices and reference indices should be equal");

if (distances.has_value() or ref_distances.has_value()) {
RAFT_EXPECTS(distances.has_value() and ref_distances.has_value(),
"Both distances and reference distances should have values");

RAFT_EXPECTS(distances.value().extent(0) == ref_distances.value().extent(0),
"The number of rows in distances and reference distances should be equal");
RAFT_EXPECTS(distances.value().extent(1) == ref_distances.value().extent(1),
"The number of columns in indices and reference indices should be equal");

RAFT_EXPECTS(indices.extent(0) == distances.value().extent(0),
"The number of rows in indices and distances should be equal");
RAFT_EXPECTS(indices.extent(1) == distances.value().extent(1),
"The number of columns in indices and distances should be equal");
}

DistanceValueType threshold_val = 0.001;
if (threshold.has_value()) { threshold_val = *threshold.value().data_handle(); }

detail::recall(res, indices, ref_indices, distances, ref_distances, recall_score, threshold_val);
}

/**
* @brief Calculate Recall score on the host for indices, distances computed by any Nearest
* Neighbors Algorithm against reference indices, distances. Recall score is calculated by comparing
* the total number of matching indices and dividing that value by the total size of the indices
* matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices could
* be considered a match if abs(dist, ref_dist) < threshold.
*
* @tparam IndicesValueType data-type of the indices
* @tparam IndexType data-type to index all matrices
* @tparam ScalarType data-type to store recall score
* @tparam DistanceValueType data-type of the distances
* @param res raft::resources object to manage resources
* @param[in] indices raft::device_matrix_view indices of neighbors
* @param[in] ref_indices raft::device_matrix_view reference indices of neighbors
* @param[out] recall_score raft::host_scalar_view output recall score
* @param[in] distances (optional) raft::device_matrix_view distances of neighbors
* @param[in] ref_distances (optional) raft::device_matrix_view reference distances of neighbors
* @param[in] threshold (optional, default = 0.001) value for distance comparison
*/
template <typename IndicesValueType,
typename IndexType,
typename ScalarType,
typename DistanceValueType = float>
void recall(
raft::resources const& res,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> indices,
raft::device_matrix_view<const IndicesValueType, IndexType, raft::row_major> ref_indices,
raft::host_scalar_view<ScalarType> recall_score,
std::optional<raft::device_matrix_view<const DistanceValueType, IndexType, raft::row_major>>
distances = std::nullopt,
std::optional<raft::device_matrix_view<const DistanceValueType, IndexType, raft::row_major>>
ref_distances = std::nullopt,
std::optional<raft::host_scalar_view<const DistanceValueType>> threshold = std::nullopt)
{
auto recall_score_d = raft::make_device_scalar(res, *recall_score.data_handle());
recall(res, indices, ref_indices, recall_score_d.view(), distances, ref_distances, threshold);
raft::update_host(recall_score.data_handle(),
recall_score_d.data_handle(),
1,
raft::resource::get_cuda_stream(res));
raft::resource::sync_stream(res);
}

/** @} */ // end group stats_recall

} // end namespace raft::stats

0 comments on commit 63400a4

Please sign in to comment.