From 4da47520da061b6221279d94ac4bcba99a101e24 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 12 Nov 2024 12:29:08 -0500 Subject: [PATCH] Removing remaining stats --- .../stats/detail/batched/silhouette_score.cuh | 281 --------------- .../raft/stats/detail/neighborhood_recall.cuh | 115 ------- .../raft/stats/detail/silhouette_score.cuh | 324 ------------------ .../stats/detail/trustworthiness_score.cuh | 221 ------------ .../raft/stats/neighborhood_recall.cuh | 194 ----------- cpp/include/raft/stats/silhouette_score.cuh | 226 ------------ .../raft/stats/trustworthiness_score.cuh | 101 ------ 7 files changed, 1462 deletions(-) delete mode 100644 cpp/include/raft/stats/detail/batched/silhouette_score.cuh delete mode 100644 cpp/include/raft/stats/detail/neighborhood_recall.cuh delete mode 100644 cpp/include/raft/stats/detail/silhouette_score.cuh delete mode 100644 cpp/include/raft/stats/detail/trustworthiness_score.cuh delete mode 100644 cpp/include/raft/stats/neighborhood_recall.cuh delete mode 100644 cpp/include/raft/stats/silhouette_score.cuh delete mode 100644 cpp/include/raft/stats/trustworthiness_score.cuh diff --git a/cpp/include/raft/stats/detail/batched/silhouette_score.cuh b/cpp/include/raft/stats/detail/batched/silhouette_score.cuh deleted file mode 100644 index 643ef77500..0000000000 --- a/cpp/include/raft/stats/detail/batched/silhouette_score.cuh +++ /dev/null @@ -1,281 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "../silhouette_score.cuh" - -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include - -namespace raft { -namespace stats { -namespace batched { -namespace detail { - -/** - * This kernel initializes matrix b (n_rows * n_labels) - * For each label that the corresponding row is not a part of is initialized as 0 - * If the corresponding row is the only sample in its label, again 0 - * Only if the there are > 1 samples in the label, row is initialized to max - */ -template -RAFT_KERNEL fill_b_kernel(value_t* b, - const label_idx* y, - value_idx n_rows, - label_idx n_labels, - const value_idx* cluster_counts) -{ - value_idx idx = threadIdx.x + blockIdx.x * blockDim.x; - label_idx idy = threadIdx.y + blockIdx.y * blockDim.y; - - if (idx >= n_rows || idy >= n_labels) { return; } - - auto row_cluster = y[idx]; - - auto col_cluster_count = cluster_counts[idy]; - - // b for own cluster should be max value - // so that it does not interfere with min operator - // b is also max if col cluster count is 0 - // however, b is 0 if self cluster count is 1 - if (row_cluster == idy || col_cluster_count == 0) { - if (cluster_counts[row_cluster] == 1) { - b[idx * n_labels + idy] = 0; - } else { - b[idx * n_labels + idy] = std::numeric_limits::max(); - } - } else { - b[idx * n_labels + idy] = 0; - } -} - -/** - * This kernel does an elementwise sweep of chunked pairwise distance matrix - * By knowing the offsets of the chunked pairwise distance matrix in the - * global pairwise distance matrix, we are able to calculate - * intermediate values of a and b for the rows and columns present in the - * current chunked pairwise distance matrix. - */ -template -RAFT_KERNEL compute_chunked_a_b_kernel(value_t* a, - value_t* b, - value_idx row_offset, - value_idx col_offset, - const label_idx* y, - label_idx n_labels, - const value_idx* cluster_counts, - const value_t* distances, - value_idx dist_rows, - value_idx dist_cols) -{ - value_idx row_id = threadIdx.x + blockIdx.x * blockDim.x; - value_idx col_id = threadIdx.y + blockIdx.y * blockDim.y; - - // these are global offsets of current element - // in the full pairwise distance matrix - value_idx pw_row_id = row_id + row_offset; - value_idx pw_col_id = col_id + col_offset; - - if (row_id >= dist_rows || col_id >= dist_cols || pw_row_id == pw_col_id) { return; } - - auto row_cluster = y[pw_row_id]; - if (cluster_counts[row_cluster] == 1) { return; } - - auto col_cluster = y[pw_col_id]; - auto col_cluster_counts = cluster_counts[col_cluster]; - - if (col_cluster == row_cluster) { - atomicAdd(&a[pw_row_id], distances[row_id * dist_cols + col_id] / (col_cluster_counts - 1)); - } else { - atomicAdd(&b[pw_row_id * n_labels + col_cluster], - distances[row_id * dist_cols + col_id] / col_cluster_counts); - } -} - -template -rmm::device_uvector get_cluster_counts(raft::resources const& handle, - const label_idx* y, - value_idx& n_rows, - label_idx& n_labels) -{ - auto stream = resource::get_cuda_stream(handle); - - rmm::device_uvector cluster_counts(n_labels, stream); - - rmm::device_uvector workspace(1, stream); - - raft::stats::detail::countLabels(y, cluster_counts.data(), n_rows, n_labels, workspace, stream); - - return cluster_counts; -} - -template -rmm::device_uvector get_pairwise_distance(raft::resources const& handle, - const value_t* left_begin, - const value_t* right_begin, - value_idx& n_left_rows, - value_idx& n_right_rows, - value_idx& n_cols, - raft::distance::DistanceType metric, - cudaStream_t stream) -{ - rmm::device_uvector distances(n_left_rows * n_right_rows, stream); - - raft::distance::pairwise_distance( - handle, left_begin, right_begin, distances.data(), n_left_rows, n_right_rows, n_cols, metric); - - return distances; -} - -template -void compute_chunked_a_b(raft::resources const& handle, - value_t* a, - value_t* b, - value_idx& row_offset, - value_idx& col_offset, - const label_idx* y, - label_idx& n_labels, - const value_idx* cluster_counts, - const value_t* distances, - value_idx& dist_rows, - value_idx& dist_cols, - cudaStream_t stream) -{ - dim3 block_size(std::min(dist_rows, 32), std::min(dist_cols, 32)); - dim3 grid_size(raft::ceildiv(dist_rows, (value_idx)block_size.x), - raft::ceildiv(dist_cols, (value_idx)block_size.y)); - - detail::compute_chunked_a_b_kernel<<>>( - a, b, row_offset, col_offset, y, n_labels, cluster_counts, distances, dist_rows, dist_cols); -} - -template -value_t silhouette_score( - raft::resources const& handle, - const value_t* X, - value_idx n_rows, - value_idx n_cols, - const label_idx* y, - label_idx n_labels, - value_t* scores, - value_idx chunk, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded) -{ - ASSERT(n_labels >= 2 && n_labels <= (n_rows - 1), - "silhouette Score not defined for the given number of labels!"); - - rmm::device_uvector cluster_counts = get_cluster_counts(handle, y, n_rows, n_labels); - - auto stream = resource::get_cuda_stream(handle); - auto policy = resource::get_thrust_policy(handle); - - auto b_size = n_rows * n_labels; - - value_t *a_ptr, *b_ptr; - rmm::device_uvector a(0, stream); - rmm::device_uvector b(b_size, stream); - - b_ptr = b.data(); - - // since a and silhouette score per sample are same size, reusing - if (scores == nullptr || scores == NULL) { - a.resize(n_rows, stream); - a_ptr = a.data(); - } else { - a_ptr = scores; - } - - thrust::fill(policy, a_ptr, a_ptr + n_rows, 0); - - dim3 block_size(std::min(n_rows, 32), std::min(n_labels, 32)); - dim3 grid_size(raft::ceildiv(n_rows, (value_idx)block_size.x), - raft::ceildiv(n_labels, (label_idx)block_size.y)); - detail::fill_b_kernel<<>>( - b_ptr, y, n_rows, n_labels, cluster_counts.data()); - - resource::wait_stream_pool_on_stream(handle); - - auto n_iters = 0; - - for (value_idx i = 0; i < n_rows; i += chunk) { - for (value_idx j = 0; j < n_rows; j += chunk) { - ++n_iters; - - auto chunk_stream = resource::get_next_usable_stream(handle, i + chunk * j); - - const auto* left_begin = X + (i * n_cols); - const auto* right_begin = X + (j * n_cols); - - auto n_left_rows = (i + chunk) < n_rows ? chunk : (n_rows - i); - auto n_right_rows = (j + chunk) < n_rows ? chunk : (n_rows - j); - - rmm::device_uvector distances = get_pairwise_distance( - handle, left_begin, right_begin, n_left_rows, n_right_rows, n_cols, metric, chunk_stream); - - compute_chunked_a_b(handle, - a_ptr, - b_ptr, - i, - j, - y, - n_labels, - cluster_counts.data(), - distances.data(), - n_left_rows, - n_right_rows, - chunk_stream); - } - } - - resource::sync_stream_pool(handle); - - // calculating row-wise minimum in b - // this prim only supports int indices for now - raft::linalg::reduce( - b_ptr, - b_ptr, - n_labels, - n_rows, - std::numeric_limits::max(), - true, - true, - stream, - false, - raft::identity_op(), - raft::min_op()); - - // calculating the silhouette score per sample - raft::linalg::binaryOp, value_t, value_idx>( - a_ptr, a_ptr, b_ptr, n_rows, raft::stats::detail::SilOp(), stream); - - return thrust::reduce(policy, a_ptr, a_ptr + n_rows, value_t(0)) / n_rows; -} - -} // namespace detail -} // namespace batched -} // namespace stats -} // namespace raft diff --git a/cpp/include/raft/stats/detail/neighborhood_recall.cuh b/cpp/include/raft/stats/detail/neighborhood_recall.cuh deleted file mode 100644 index fe3a3f6ece..0000000000 --- a/cpp/include/raft/stats/detail/neighborhood_recall.cuh +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include - -namespace raft::stats::detail { - -template -RAFT_KERNEL neighborhood_recall( - raft::device_matrix_view indices, - raft::device_matrix_view ref_indices, - std::optional> - distances, - std::optional> - ref_distances, - raft::device_scalar_view recall_score, - DistanceValueType const eps) -{ - auto constexpr kThreadsPerBlock = 32; - IndexType const row_idx = blockIdx.x; - auto const lane_idx = threadIdx.x % kThreadsPerBlock; - - // Each warp stores a recall score computed across the columns per row - IndexType thread_recall_score = 0; - for (IndexType col_idx = lane_idx; col_idx < indices.extent(1); col_idx += kThreadsPerBlock) { - for (IndexType ref_col_idx = 0; ref_col_idx < ref_indices.extent(1); ref_col_idx++) { - if (indices(row_idx, col_idx) == ref_indices(row_idx, ref_col_idx)) { - thread_recall_score += 1; - break; - } else if (distances.has_value()) { - auto dist = distances.value()(row_idx, col_idx); - auto ref_dist = ref_distances.value()(row_idx, ref_col_idx); - DistanceValueType diff = raft::abs(dist - ref_dist); - DistanceValueType m = std::max(raft::abs(dist), raft::abs(ref_dist)); - DistanceValueType ratio = diff > eps ? diff / m : diff; - - if (ratio <= eps) { - thread_recall_score += 1; - break; - } - } - } - } - - // Reduce across a warp for row score - typedef cub::BlockReduce BlockReduce; - - __shared__ typename BlockReduce::TempStorage temp_storage; - - ScalarType row_recall_score = BlockReduce(temp_storage).Sum(thread_recall_score); - - // Reduce across all rows for global score - if (lane_idx == 0) { - cuda::atomic_ref device_recall_score{ - *recall_score.data_handle()}; - std::size_t const total_count = indices.extent(0) * indices.extent(1); - device_recall_score.fetch_add(row_recall_score / total_count); - } -} - -template -void neighborhood_recall( - raft::resources const& res, - raft::device_matrix_view indices, - raft::device_matrix_view ref_indices, - std::optional> - distances, - std::optional> - ref_distances, - raft::device_scalar_view recall_score, - DistanceValueType const eps) -{ - // One warp per row, launch a warp-width block per-row kernel - auto constexpr kThreadsPerBlock = 32; - auto const num_blocks = indices.extent(0); - - neighborhood_recall<<>>( - indices, ref_indices, distances, ref_distances, recall_score, eps); -} - -} // end namespace raft::stats::detail diff --git a/cpp/include/raft/stats/detail/silhouette_score.cuh b/cpp/include/raft/stats/detail/silhouette_score.cuh deleted file mode 100644 index 4285f84fcc..0000000000 --- a/cpp/include/raft/stats/detail/silhouette_score.cuh +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include - -#include -#include -#include - -namespace raft { -namespace stats { -namespace detail { - -/** - * @brief kernel that calculates the average intra-cluster distance for every sample data point and - * updates the cluster distance to max value - * @tparam DataT: type of the data samples - * @tparam LabelT: type of the labels - * @param sampleToClusterSumOfDistances: the pointer to the 2D array that contains the sum of - * distances from every sample to every cluster (nRows x nLabels) - * @param binCountArray: pointer to the 1D array that contains the count of samples per cluster (1 x - * nLabels) - * @param d_aArray: the pointer to the array of average intra-cluster distances for every sample in - * device memory (1 x nRows) - * @param labels: the pointer to the array containing labels for every data sample (1 x nRows) - * @param nRows: number of data samples - * @param nLabels: number of Labels - * @param MAX_VAL: DataT specific upper limit - */ -template -RAFT_KERNEL populateAKernel(DataT* sampleToClusterSumOfDistances, - DataT* binCountArray, - DataT* d_aArray, - const LabelT* labels, - int nRows, - int nLabels, - const DataT MAX_VAL) -{ - // getting the current index - int sampleIndex = threadIdx.x + blockIdx.x * blockDim.x; - - if (sampleIndex >= nRows) return; - - // sampleDistanceVector is an array that stores that particular row of the distanceMatrix - DataT* sampleToClusterSumOfDistancesVector = - &sampleToClusterSumOfDistances[sampleIndex * nLabels]; - - LabelT sampleCluster = labels[sampleIndex]; - - int sampleClusterIndex = (int)sampleCluster; - - if (binCountArray[sampleClusterIndex] - 1 <= 0) { - d_aArray[sampleIndex] = -1; - return; - - } - - else { - d_aArray[sampleIndex] = (sampleToClusterSumOfDistancesVector[sampleClusterIndex]) / - (binCountArray[sampleClusterIndex] - 1); - - // modifying the sampleDistanceVector to give sample average distance - sampleToClusterSumOfDistancesVector[sampleClusterIndex] = MAX_VAL; - } -} - -/** - * @brief function to calculate the bincounts of number of samples in every label - * @tparam DataT: type of the data samples - * @tparam LabelT: type of the labels - * @param labels: the pointer to the array containing labels for every data sample (1 x nRows) - * @param binCountArray: pointer to the 1D array that contains the count of samples per cluster (1 x - * nLabels) - * @param nRows: number of data samples - * @param nUniqueLabels: number of Labels - * @param workspace: device buffer containing workspace memory - * @param stream: the cuda stream where to launch this kernel - */ -template -void countLabels(const LabelT* labels, - DataT* binCountArray, - int nRows, - int nUniqueLabels, - rmm::device_uvector& workspace, - cudaStream_t stream) -{ - int num_levels = nUniqueLabels + 1; - LabelT lower_level = 0; - LabelT upper_level = nUniqueLabels; - size_t temp_storage_bytes = 0; - - rmm::device_uvector countArray(nUniqueLabels, stream); - - RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr, - temp_storage_bytes, - labels, - binCountArray, - num_levels, - lower_level, - upper_level, - nRows, - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(workspace.data(), - temp_storage_bytes, - labels, - binCountArray, - num_levels, - lower_level, - upper_level, - nRows, - stream)); -} - -/** - * @brief structure that defines the division Lambda for elementwise op - */ -template -struct DivOp { - HDI DataT operator()(DataT a, int b, int c) - { - if (b == 0) - return ULLONG_MAX; - else - return a / b; - } -}; - -/** - * @brief structure that defines the elementwise operation to calculate silhouette score using - * params 'a' and 'b' - */ -template -struct SilOp { - HDI DataT operator()(DataT a, DataT b) - { - if (a == 0 && b == 0 || a == b) - return 0; - else if (a == -1) - return 0; - else if (a > b) - return (b - a) / a; - else - return (b - a) / b; - } -}; - -/** - * @brief main function that returns the average silhouette score for a given set of data and its - * clusterings - * @tparam DataT: type of the data samples - * @tparam LabelT: type of the labels - * @param X_in: pointer to the input Data samples array (nRows x nCols) - * @param nRows: number of data samples - * @param nCols: number of features - * @param labels: the pointer to the array containing labels for every data sample (1 x nRows) - * @param nLabels: number of Labels - * @param silhouette_scorePerSample: pointer to the array that is optionally taken in as input and - * is populated with the silhouette score for every sample (1 x nRows) - * @param stream: the cuda stream where to launch this kernel - * @param metric: the numerical value that maps to the type of distance metric to be used in the - * calculations - */ -template -DataT silhouette_score( - raft::resources const& handle, - const DataT* X_in, - int nRows, - int nCols, - const LabelT* labels, - int nLabels, - DataT* silhouette_scorePerSample, - cudaStream_t stream, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded) -{ - ASSERT(nLabels >= 2 && nLabels <= (nRows - 1), - "silhouette Score not defined for the given number of labels!"); - - // compute the distance matrix - rmm::device_uvector distanceMatrix(nRows * nRows, stream); - rmm::device_uvector workspace(1, stream); - - raft::distance::pairwise_distance( - handle, X_in, X_in, distanceMatrix.data(), nRows, nRows, nCols, metric); - - // deciding on the array of silhouette scores for each dataPoint - rmm::device_uvector silhouette_scoreSamples(0, stream); - DataT* perSampleSilScore = nullptr; - if (silhouette_scorePerSample == nullptr) { - silhouette_scoreSamples.resize(nRows, stream); - perSampleSilScore = silhouette_scoreSamples.data(); - } else { - perSampleSilScore = silhouette_scorePerSample; - } - RAFT_CUDA_TRY(cudaMemsetAsync(perSampleSilScore, 0, nRows * sizeof(DataT), stream)); - - // getting the sample count per cluster - rmm::device_uvector binCountArray(nLabels, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(binCountArray.data(), 0, nLabels * sizeof(DataT), stream)); - countLabels(labels, binCountArray.data(), nRows, nLabels, workspace, stream); - - // calculating the sample-cluster-distance-sum-array - rmm::device_uvector sampleToClusterSumOfDistances(nRows * nLabels, stream); - RAFT_CUDA_TRY(cudaMemsetAsync( - sampleToClusterSumOfDistances.data(), 0, nRows * nLabels * sizeof(DataT), stream)); - raft::linalg::reduce_cols_by_key(distanceMatrix.data(), - labels, - sampleToClusterSumOfDistances.data(), - nRows, - nRows, - nLabels, - stream); - - // creating the a array and b array - rmm::device_uvector d_aArray(nRows, stream); - rmm::device_uvector d_bArray(nRows, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(d_aArray.data(), 0, nRows * sizeof(DataT), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_bArray.data(), 0, nRows * sizeof(DataT), stream)); - - // kernel that populates the d_aArray - // kernel configuration - dim3 numThreadsPerBlock(32, 1, 1); - dim3 numBlocks(raft::ceildiv(nRows, numThreadsPerBlock.x), 1, 1); - - // calling the kernel - populateAKernel<<>>( - sampleToClusterSumOfDistances.data(), - binCountArray.data(), - d_aArray.data(), - labels, - nRows, - nLabels, - std::numeric_limits::max()); - - // elementwise dividing by bincounts - rmm::device_uvector averageDistanceBetweenSampleAndCluster(nRows * nLabels, stream); - RAFT_CUDA_TRY(cudaMemsetAsync( - averageDistanceBetweenSampleAndCluster.data(), 0, nRows * nLabels * sizeof(DataT), stream)); - - raft::linalg::matrixVectorOp(averageDistanceBetweenSampleAndCluster.data(), - sampleToClusterSumOfDistances.data(), - binCountArray.data(), - binCountArray.data(), - nLabels, - nRows, - true, - true, - DivOp(), - stream); - - // calculating row-wise minimum - raft::linalg::reduce( - d_bArray.data(), - averageDistanceBetweenSampleAndCluster.data(), - nLabels, - nRows, - std::numeric_limits::max(), - true, - true, - stream, - false, - raft::identity_op{}, - raft::min_op{}); - - // calculating the silhouette score per sample using the d_aArray and d_bArray - raft::linalg::binaryOp>( - perSampleSilScore, d_aArray.data(), d_bArray.data(), nRows, SilOp(), stream); - - // calculating the sum of all the silhouette score - rmm::device_scalar d_avgSilhouetteScore(stream); - RAFT_CUDA_TRY(cudaMemsetAsync(d_avgSilhouetteScore.data(), 0, sizeof(DataT), stream)); - - raft::linalg::mapThenSumReduce(d_avgSilhouetteScore.data(), - nRows, - raft::identity_op(), - stream, - perSampleSilScore, - perSampleSilScore); - - DataT avgSilhouetteScore = d_avgSilhouetteScore.value(stream); - - resource::sync_stream(handle, stream); - - avgSilhouetteScore /= nRows; - - return avgSilhouetteScore; -} - -}; // namespace detail -}; // namespace stats -}; // namespace raft diff --git a/cpp/include/raft/stats/detail/trustworthiness_score.cuh b/cpp/include/raft/stats/detail/trustworthiness_score.cuh deleted file mode 100644 index 8b9e5c2cc8..0000000000 --- a/cpp/include/raft/stats/detail/trustworthiness_score.cuh +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -#include -#include - -#define N_THREADS 512 - -namespace raft { -namespace stats { -namespace detail { - -/** - * @brief Build the lookup table - * @param[out] lookup_table: Lookup table giving nearest neighbor order - * of pairwise distance calculations given sample index - * @param[in] X_ind: Sorted indexes of pairwise distance calculations of X - * @param n: Number of samples - * @param work: Number of elements to consider - */ -RAFT_KERNEL build_lookup_table(int* lookup_table, const int* X_ind, int n, int work) -{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= work) return; - - int sample_idx = i / n; - int nn_idx = i % n; - - int idx = X_ind[i]; - lookup_table[(sample_idx * n) + idx] = nn_idx; -} - -/** - * @brief Compute a the rank of trustworthiness score - * @param[out] rank: Resulting rank - * @param[out] lookup_table: Lookup table giving nearest neighbor order - * of pairwise distance calculations given sample index - * @param[in] emb_ind: Indexes of KNN on embeddings - * @param n: Number of samples - * @param n_neighbors: Number of neighbors considered by trustworthiness score - * @param work: Batch to consider (to do it at once use n * n_neighbors) - */ -template -RAFT_KERNEL compute_rank(double* rank, - const int* lookup_table, - const knn_index_t* emb_ind, - int n, - int n_neighbors, - int work) -{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= work) return; - - int sample_idx = i / n_neighbors; - - knn_index_t emb_nn_ind = emb_ind[i]; - - int r = lookup_table[(sample_idx * n) + emb_nn_ind]; - int tmp = r - n_neighbors + 1; - if (tmp > 0) raft::myAtomicAdd(rank, tmp); -} - -/** - * @brief Compute a kNN and returns the indices of the nearest neighbors - * @param h Raft handle - * @param[in] input Input matrix containing the dataset - * @param n Number of samples - * @param d Number of features - * @param n_neighbors number of neighbors - * @param[out] indices KNN indexes - * @param[out] distances KNN distances - */ -template -void run_knn(const raft::resources& h, - math_t* input, - int n, - int d, - int n_neighbors, - int64_t* indices, - math_t* distances) -{ - std::vector ptrs(1); - std::vector sizes(1); - ptrs[0] = input; - sizes[0] = n; - - raft::spatial::knn::brute_force_knn(h, - ptrs, - sizes, - d, - input, - n, - indices, - distances, - n_neighbors, - true, - true, - nullptr, - distance_type); -} - -/** - * @brief Compute the trustworthiness score - * @param h Raft handle - * @param X[in]: Data in original dimension - * @param X_embedded[in]: Data in target dimension (embedding) - * @param n: Number of samples - * @param m: Number of features in high/original dimension - * @param d: Number of features in low/embedded dimension - * @param n_neighbors Number of neighbors considered by trustworthiness score - * @param batchSize Batch size - * @return Trustworthiness score - */ -template -double trustworthiness_score(const raft::resources& h, - const math_t* X, - math_t* X_embedded, - int n, - int m, - int d, - int n_neighbors, - int batchSize = 512) -{ - cudaStream_t stream = resource::get_cuda_stream(h); - - const int KNN_ALLOC = n * (n_neighbors + 1); - rmm::device_uvector emb_ind(KNN_ALLOC, stream); - rmm::device_uvector emb_dist(KNN_ALLOC, stream); - - run_knn(h, X_embedded, n, d, n_neighbors + 1, emb_ind.data(), emb_dist.data()); - - const int PAIRWISE_ALLOC = batchSize * n; - rmm::device_uvector X_ind(PAIRWISE_ALLOC, stream); - rmm::device_uvector X_dist(PAIRWISE_ALLOC, stream); - rmm::device_uvector lookup_table(PAIRWISE_ALLOC, stream); - - double t = 0.0; - rmm::device_scalar t_dbuf(stream); - - int toDo = n; - while (toDo > 0) { - int curBatchSize = min(toDo, batchSize); - - // Takes at most batchSize vectors at a time - raft::distance::pairwise_distance( - h, &X[(n - toDo) * m], X, X_dist.data(), curBatchSize, n, m, distance_type); - - size_t colSortWorkspaceSize = 0; - bool bAllocWorkspace = false; - - raft::matrix::sort_cols_per_row(X_dist.data(), - X_ind.data(), - curBatchSize, - n, - bAllocWorkspace, - nullptr, - colSortWorkspaceSize, - stream); - - if (bAllocWorkspace) { - rmm::device_uvector sortColsWorkspace(colSortWorkspaceSize, stream); - - raft::matrix::sort_cols_per_row(X_dist.data(), - X_ind.data(), - curBatchSize, - n, - bAllocWorkspace, - sortColsWorkspace.data(), - colSortWorkspaceSize, - stream); - } - - int work = curBatchSize * n; - int n_blocks = raft::ceildiv(work, N_THREADS); - build_lookup_table<<>>( - lookup_table.data(), X_ind.data(), n, work); - - RAFT_CUDA_TRY(cudaMemsetAsync(t_dbuf.data(), 0, sizeof(double), stream)); - - work = curBatchSize * (n_neighbors + 1); - n_blocks = raft::ceildiv(work, N_THREADS); - compute_rank<<>>( - t_dbuf.data(), - lookup_table.data(), - &emb_ind.data()[(n - toDo) * (n_neighbors + 1)], - n, - n_neighbors + 1, - work); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - t += t_dbuf.value(stream); - - toDo -= curBatchSize; - } - - t = 1.0 - ((2.0 / ((n * n_neighbors) * ((2.0 * n) - (3.0 * n_neighbors) - 1.0))) * t); - - return t; -} - -} // namespace detail -} // namespace stats -} // namespace raft diff --git a/cpp/include/raft/stats/neighborhood_recall.cuh b/cpp/include/raft/stats/neighborhood_recall.cuh deleted file mode 100644 index e082bc87b4..0000000000 --- a/cpp/include/raft/stats/neighborhood_recall.cuh +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "detail/neighborhood_recall.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::stats { - -/** - * @defgroup stats_neighborhood_recall Neighborhood Recall Score - * @{ - */ - -/** - * @brief Calculate Neighborhood Recall score on the device for indices, distances computed by any - * Nearest Neighbors Algorithm against reference indices, distances. Recall score is calculated by - * comparing the total number of matching indices and dividing that value by the total size of the - * indices matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices - * could be considered a match if abs(dist, ref_dist) < eps. - * - * Usage example: - * @code{.cpp} - * raft::device_resources res; - * // assume D rows and N column dataset - * auto k = 64; - * auto indices = raft::make_device_matrix(res, D, k); - * auto distances = raft::make_device_matrix(res, D, k); - * // run ANN algorithm of choice - * - * auto ref_indices = raft::make_device_matrix(res, D, k); - * auto ref_distances = raft::make_device_matrix(res, D, k); - * // run brute-force KNN for reference - * - * auto scalar = 0.0f; - * auto recall_score = raft::make_device_scalar(res, scalar); - * - * raft::stats::neighborhood_recall(res, - raft::make_const_mdspan(indices.view()), - raft::make_const_mdspan(ref_indices.view()), - recall_score.view(), - raft::make_const_mdspan(distances.view()), - raft::make_const_mdspan(ref_distances.view())); - * @endcode - * - * @tparam IndicesValueType data-type of the indices - * @tparam IndexType data-type to index all matrices - * @tparam ScalarType data-type to store recall score - * @tparam DistanceValueType data-type of the distances - * @param res raft::resources object to manage resources - * @param[in] indices raft::device_matrix_view indices of neighbors - * @param[in] ref_indices raft::device_matrix_view reference indices of neighbors - * @param[out] recall_score raft::device_scalar_view output recall score - * @param[in] distances (optional) raft::device_matrix_view distances of neighbors - * @param[in] ref_distances (optional) raft::device_matrix_view reference distances of neighbors - * @param[in] eps (optional, default = 0.001) value within which distances are considered matching - */ -template -void neighborhood_recall( - raft::resources const& res, - raft::device_matrix_view indices, - raft::device_matrix_view ref_indices, - raft::device_scalar_view recall_score, - std::optional> - distances = std::nullopt, - std::optional> - ref_distances = std::nullopt, - std::optional> eps = std::nullopt) -{ - RAFT_EXPECTS(indices.extent(0) == ref_indices.extent(0), - "The number of rows in indices and reference indices should be equal"); - RAFT_EXPECTS(indices.extent(1) == ref_indices.extent(1), - "The number of columns in indices and reference indices should be equal"); - - if (distances.has_value() or ref_distances.has_value()) { - RAFT_EXPECTS(distances.has_value() and ref_distances.has_value(), - "Both distances and reference distances should have values"); - - RAFT_EXPECTS(distances.value().extent(0) == ref_distances.value().extent(0), - "The number of rows in distances and reference distances should be equal"); - RAFT_EXPECTS(distances.value().extent(1) == ref_distances.value().extent(1), - "The number of columns in indices and reference indices should be equal"); - - RAFT_EXPECTS(indices.extent(0) == distances.value().extent(0), - "The number of rows in indices and distances should be equal"); - RAFT_EXPECTS(indices.extent(1) == distances.value().extent(1), - "The number of columns in indices and distances should be equal"); - } - - DistanceValueType eps_val = 0.001; - if (eps.has_value()) { eps_val = *eps.value().data_handle(); } - - detail::neighborhood_recall( - res, indices, ref_indices, distances, ref_distances, recall_score, eps_val); -} - -/** - * @brief Calculate Neighborhood Recall score on the host for indices, distances computed by any - * Nearest Neighbors Algorithm against reference indices, distances. Recall score is calculated by - * comparing the total number of matching indices and dividing that value by the total size of the - * indices matrix of dimensions (D, k). If distance matrices are provided, then non-matching indices - * could be considered a match if abs(dist, ref_dist) < eps. - * - * Usage example: - * @code{.cpp} - * raft::device_resources res; - * // assume D rows and N column dataset - * auto k = 64; - * auto indices = raft::make_device_matrix(res, D, k); - * auto distances = raft::make_device_matrix(res, D, k); - * // run ANN algorithm of choice - * - * auto ref_indices = raft::make_device_matrix(res, D, k); - * auto ref_distances = raft::make_device_matrix(res, D, k); - * // run brute-force KNN for reference - * - * auto scalar = 0.0f; - * auto recall_score = raft::make_host_scalar(scalar); - * - * raft::stats::neighborhood_recall(res, - raft::make_const_mdspan(indices.view()), - raft::make_const_mdspan(ref_indices.view()), - recall_score.view(), - raft::make_const_mdspan(distances.view()), - raft::make_const_mdspan(ref_distances.view())); - * @endcode - * - * @tparam IndicesValueType data-type of the indices - * @tparam IndexType data-type to index all matrices - * @tparam ScalarType data-type to store recall score - * @tparam DistanceValueType data-type of the distances - * @param res raft::resources object to manage resources - * @param[in] indices raft::device_matrix_view indices of neighbors - * @param[in] ref_indices raft::device_matrix_view reference indices of neighbors - * @param[out] recall_score raft::host_scalar_view output recall score - * @param[in] distances (optional) raft::device_matrix_view distances of neighbors - * @param[in] ref_distances (optional) raft::device_matrix_view reference distances of neighbors - * @param[in] eps (optional, default = 0.001) value within which distances are considered matching - */ -template -void neighborhood_recall( - raft::resources const& res, - raft::device_matrix_view indices, - raft::device_matrix_view ref_indices, - raft::host_scalar_view recall_score, - std::optional> - distances = std::nullopt, - std::optional> - ref_distances = std::nullopt, - std::optional> eps = std::nullopt) -{ - auto recall_score_d = raft::make_device_scalar(res, *recall_score.data_handle()); - neighborhood_recall( - res, indices, ref_indices, recall_score_d.view(), distances, ref_distances, eps); - raft::update_host(recall_score.data_handle(), - recall_score_d.data_handle(), - 1, - raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); -} - -/** @} */ // end group stats_recall - -} // end namespace raft::stats diff --git a/cpp/include/raft/stats/silhouette_score.cuh b/cpp/include/raft/stats/silhouette_score.cuh deleted file mode 100644 index 23eef84604..0000000000 --- a/cpp/include/raft/stats/silhouette_score.cuh +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef __SILHOUETTE_SCORE_H -#define __SILHOUETTE_SCORE_H - -#pragma once - -#include -#include -#include -#include - -namespace raft { -namespace stats { - -/** - * @brief main function that returns the average silhouette score for a given set of data and its - * clusterings - * @tparam DataT: type of the data samples - * @tparam LabelT: type of the labels - * @param handle: raft handle for managing expensive resources - * @param X_in: pointer to the input Data samples array (nRows x nCols) - * @param nRows: number of data samples - * @param nCols: number of features - * @param labels: the pointer to the array containing labels for every data sample (1 x nRows) - * @param nLabels: number of Labels - * @param silhouette_scorePerSample: pointer to the array that is optionally taken in as input and - * is populated with the silhouette score for every sample (1 x nRows) - * @param stream: the cuda stream where to launch this kernel - * @param metric: the numerical value that maps to the type of distance metric to be used in the - * calculations - */ -template -DataT silhouette_score( - raft::resources const& handle, - DataT* X_in, - int nRows, - int nCols, - LabelT* labels, - int nLabels, - DataT* silhouette_scorePerSample, - cudaStream_t stream, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded) -{ - return detail::silhouette_score( - handle, X_in, nRows, nCols, labels, nLabels, silhouette_scorePerSample, stream, metric); -} - -template -value_t silhouette_score_batched( - raft::resources const& handle, - value_t* X, - value_idx n_rows, - value_idx n_cols, - label_idx* y, - label_idx n_labels, - value_t* scores, - value_idx chunk, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded) -{ - return batched::detail::silhouette_score( - handle, X, n_rows, n_cols, y, n_labels, scores, chunk, metric); -} - -/** - * @defgroup stats_silhouette_score Silhouette Score - * @{ - */ - -/** - * @brief main function that returns the average silhouette score for a given set of data and its - * clusterings - * @tparam value_t: type of the data samples - * @tparam label_t: type of the labels - * @tparam idx_t index type - * @param[in] handle: raft handle for managing expensive resources - * @param[in] X_in: input matrix Data in row-major format (nRows x nCols) - * @param[in] labels: the pointer to the array containing labels for every data sample (length: - * nRows) - * @param[out] silhouette_score_per_sample: optional array populated with the silhouette score - * for every sample (length: nRows) - * @param[in] n_unique_labels: number of unique labels in the labels array - * @param[in] metric: the numerical value that maps to the type of distance metric to be used in - * the calculations - * @return: The silhouette score. - */ -template -value_t silhouette_score( - raft::resources const& handle, - raft::device_matrix_view X_in, - raft::device_vector_view labels, - std::optional> silhouette_score_per_sample, - idx_t n_unique_labels, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded) -{ - RAFT_EXPECTS(labels.extent(0) == X_in.extent(0), "Size mismatch between labels and data"); - - value_t* silhouette_score_per_sample_ptr = nullptr; - if (silhouette_score_per_sample.has_value()) { - silhouette_score_per_sample_ptr = silhouette_score_per_sample.value().data_handle(); - RAFT_EXPECTS(silhouette_score_per_sample.value().extent(0) == X_in.extent(0), - "Size mismatch between silhouette_score_per_sample and data"); - } - return detail::silhouette_score(handle, - X_in.data_handle(), - X_in.extent(0), - X_in.extent(1), - labels.data_handle(), - n_unique_labels, - silhouette_score_per_sample_ptr, - resource::get_cuda_stream(handle), - metric); -} - -/** - * @brief function that returns the average silhouette score for a given set of data and its - * clusterings - * @tparam value_t: type of the data samples - * @tparam label_t: type of the labels - * @tparam idx_t index type - * @param[in] handle: raft handle for managing expensive resources - * @param[in] X: input matrix Data in row-major format (nRows x nCols) - * @param[in] labels: the pointer to the array containing labels for every data sample (length: - * nRows) - * @param[out] silhouette_score_per_sample: optional array populated with the silhouette score - * for every sample (length: nRows) - * @param[in] n_unique_labels: number of unique labels in the labels array - * @param[in] batch_size: number of samples per batch - * @param[in] metric: the numerical value that maps to the type of distance metric to be used in - * the calculations - * @return: The silhouette score. - */ -template -value_t silhouette_score_batched( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view labels, - std::optional> silhouette_score_per_sample, - idx_t n_unique_labels, - idx_t batch_size, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded) -{ - static_assert(std::is_integral_v, - "silhouette_score_batched: The index type " - "of each mdspan argument must be an integral type."); - static_assert(std::is_integral_v, - "silhouette_score_batched: The label type must be an integral type."); - RAFT_EXPECTS(labels.extent(0) == X.extent(0), "Size mismatch between labels and data"); - - value_t* scores_ptr = nullptr; - if (silhouette_score_per_sample.has_value()) { - scores_ptr = silhouette_score_per_sample.value().data_handle(); - RAFT_EXPECTS(silhouette_score_per_sample.value().extent(0) == X.extent(0), - "Size mismatch between silhouette_score_per_sample and data"); - } - return batched::detail::silhouette_score(handle, - X.data_handle(), - X.extent(0), - X.extent(1), - labels.data_handle(), - n_unique_labels, - scores_ptr, - batch_size, - metric); -} - -/** @} */ // end group stats_silhouette_score - -/** - * @brief Overload of `silhouette_score` to help the - * compiler find the above overload, in case users pass in - * `std::nullopt` for the optional arguments. - * - * Please see above for documentation of `silhouette_score`. - */ -template -value_t silhouette_score( - raft::resources const& handle, - raft::device_matrix_view X_in, - raft::device_vector_view labels, - std::nullopt_t silhouette_score_per_sample, - idx_t n_unique_labels, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded) -{ - std::optional> opt_scores = silhouette_score_per_sample; - return silhouette_score(handle, X_in, labels, opt_scores, n_unique_labels, metric); -} - -/** - * @brief Overload of `silhouette_score_batched` to help the - * compiler find the above overload, in case users pass in - * `std::nullopt` for the optional arguments. - * - * Please see above for documentation of `silhouette_score_batched`. - */ -template -value_t silhouette_score_batched( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view labels, - std::nullopt_t silhouette_score_per_sample, - idx_t n_unique_labels, - idx_t batch_size, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded) -{ - std::optional> opt_scores = silhouette_score_per_sample; - return silhouette_score_batched( - handle, X, labels, opt_scores, n_unique_labels, batch_size, metric); -} -}; // namespace stats -}; // namespace raft - -#endif \ No newline at end of file diff --git a/cpp/include/raft/stats/trustworthiness_score.cuh b/cpp/include/raft/stats/trustworthiness_score.cuh deleted file mode 100644 index 3f4464f4d3..0000000000 --- a/cpp/include/raft/stats/trustworthiness_score.cuh +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __TRUSTWORTHINESS_SCORE_H -#define __TRUSTWORTHINESS_SCORE_H - -#pragma once -#include -#include -#include - -namespace raft { -namespace stats { - -/** - * @brief Compute the trustworthiness score - * @param[in] h: raft handle - * @param[in] X: Data in original dimension - * @param[in] X_embedded: Data in target dimension (embedding) - * @param[in] n: Number of samples - * @param[in] m: Number of features in high/original dimension - * @param[in] d: Number of features in low/embedded dimension - * @param[in] n_neighbors Number of neighbors considered by trustworthiness score - * @param[in] batchSize Batch size - * @return[out] Trustworthiness score - */ -template -double trustworthiness_score(const raft::resources& h, - const math_t* X, - math_t* X_embedded, - int n, - int m, - int d, - int n_neighbors, - int batchSize = 512) -{ - return detail::trustworthiness_score( - h, X, X_embedded, n, m, d, n_neighbors, batchSize); -} - -/** - * @defgroup stats_trustworthiness Trustworthiness - * @{ - */ - -/** - * @brief Compute the trustworthiness score - * @tparam value_t the data type - * @tparam idx_t Integer type used to for addressing - * @param[in] handle the raft handle - * @param[in] X: Data in original dimension - * @param[in] X_embedded: Data in target dimension (embedding) - * @param[in] n_neighbors Number of neighbors considered by trustworthiness score - * @param[in] batch_size Batch size - * @return Trustworthiness score - * @note The constness of the data in X_embedded is currently casted away and the data is slightly - * modified. - */ -template -double trustworthiness_score( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view X_embedded, - int n_neighbors, - int batch_size = 512) -{ - RAFT_EXPECTS(X.extent(0) == X_embedded.extent(0), "Size mismatch between X and X_embedded"); - RAFT_EXPECTS(std::is_integral_v && X.extent(0) <= std::numeric_limits::max(), - "Index type not supported"); - - // TODO: Change the underlying implementation to remove the need to const_cast X_embedded. - return detail::trustworthiness_score( - handle, - X.data_handle(), - const_cast(X_embedded.data_handle()), - X.extent(0), - X.extent(1), - X_embedded.extent(1), - n_neighbors, - batch_size); -} - -/** @} */ // end group stats_trustworthiness - -} // namespace stats -} // namespace raft - -#endif \ No newline at end of file