diff --git a/cpp/include/cuvs/distance/distance.hpp b/cpp/include/cuvs/distance/distance.hpp index e4c45605d..42c574e58 100644 --- a/cpp/include/cuvs/distance/distance.hpp +++ b/cpp/include/cuvs/distance/distance.hpp @@ -372,6 +372,46 @@ void pairwise_distance(raft::resources const& handle, cuvs::distance::DistanceType metric, float metric_arg = 2.0f); +/** + * @brief Compute sparse pairwise distances between x and y, using the provided + * input configuration and distance function. + * + * @code{.cpp} + * #include + * #include + * #include + * + * int x_n_rows = 100000; + * int y_n_rows = 50000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * auto x = raft::make_device_csr_matrix(handle, x_n_rows, n_cols); + * auto y = raft::make_device_csr_matrix(handle, y_n_rows, n_cols); + * + * ... + * // populate data + * ... + * + * auto out = raft::make_device_matrix(handle, x_nrows, y_nrows); + * auto metric = cuvs::distance::DistanceType::L2Expanded; + * raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out, metric); + * @endcode + * + * @param[in] handle raft::resources + * @param[in] x raft::device_csr_matrix_view + * @param[in] y raft::device_csr_matrix_view + * @param[out] dist raft::device_matrix_view dense matrix + * @param[in] metric distance metric to use + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +void pairwise_distance(raft::resources const& handle, + raft::device_csr_matrix_view x, + raft::device_csr_matrix_view y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg = 2.0f); + /** @} */ // end group pairwise_distance_runtime }; // namespace cuvs::distance diff --git a/cpp/src/distance/sparse_distance.cu b/cpp/src/distance/sparse_distance.cu index 4891ca1e6..338c4e908 100644 --- a/cpp/src/distance/sparse_distance.cu +++ b/cpp/src/distance/sparse_distance.cu @@ -22,12 +22,14 @@ namespace cuvs { namespace distance { -void pairwise_distance(raft::resources const& handle, - raft::device_csr_matrix_view x, - raft::device_csr_matrix_view y, - raft::device_matrix_view dist, - cuvs::distance::DistanceType metric, - float metric_arg) +template +void pairwise_distance( + raft::resources const& handle, + raft::device_csr_matrix_view x, + raft::device_csr_matrix_view y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg = 2.0f) { auto x_structure = x.structure_view(); auto y_structure = y.structure_view(); @@ -42,22 +44,42 @@ void pairwise_distance(raft::resources const& handle, "Number of columns in output must be equal to " "number of rows in Y"); - detail::sparse::distances_config_t input_config(handle); + detail::sparse::distances_config_t input_config(handle); input_config.a_nrows = x_structure.get_n_rows(); input_config.a_ncols = x_structure.get_n_cols(); input_config.a_nnz = x_structure.get_nnz(); - input_config.a_indptr = const_cast(x_structure.get_indptr().data()); - input_config.a_indices = const_cast(x_structure.get_indices().data()); - input_config.a_data = const_cast(x.get_elements().data()); + input_config.a_indptr = const_cast(x_structure.get_indptr().data()); + input_config.a_indices = const_cast(x_structure.get_indices().data()); + input_config.a_data = const_cast(x.get_elements().data()); input_config.b_nrows = y_structure.get_n_rows(); input_config.b_ncols = y_structure.get_n_cols(); input_config.b_nnz = y_structure.get_nnz(); - input_config.b_indptr = const_cast(y_structure.get_indptr().data()); - input_config.b_indices = const_cast(y_structure.get_indices().data()); - input_config.b_data = const_cast(y.get_elements().data()); + input_config.b_indptr = const_cast(y_structure.get_indptr().data()); + input_config.b_indices = const_cast(y_structure.get_indices().data()); + input_config.b_data = const_cast(y.get_elements().data()); pairwiseDistance(dist.data_handle(), input_config, metric, metric_arg); } + +void pairwise_distance(raft::resources const& handle, + raft::device_csr_matrix_view x, + raft::device_csr_matrix_view y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg) +{ + pairwise_distance(handle, x, y, dist, metric, metric_arg); +} + +void pairwise_distance(raft::resources const& handle, + raft::device_csr_matrix_view x, + raft::device_csr_matrix_view y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg) +{ + pairwise_distance(handle, x, y, dist, metric, metric_arg); +} } // namespace distance } // namespace cuvs