Skip to content

Commit

Permalink
add double instantiations for sparse pw distance
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Nov 20, 2024
1 parent 1f19a9e commit 8e32018
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 13 deletions.
40 changes: 40 additions & 0 deletions cpp/include/cuvs/distance/distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,46 @@ void pairwise_distance(raft::resources const& handle,
cuvs::distance::DistanceType metric,
float metric_arg = 2.0f);

/**
* @brief Compute sparse pairwise distances between x and y, using the provided
* input configuration and distance function.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
* #include <raft/core/device_csr_matrix.hpp>
* #include <raft/core/device_mdspan.hpp>
*
* int x_n_rows = 100000;
* int y_n_rows = 50000;
* int n_cols = 10000;
*
* raft::device_resources handle;
* auto x = raft::make_device_csr_matrix<double>(handle, x_n_rows, n_cols);
* auto y = raft::make_device_csr_matrix<double>(handle, y_n_rows, n_cols);
*
* ...
* // populate data
* ...
*
* auto out = raft::make_device_matrix<double>(handle, x_nrows, y_nrows);
* auto metric = cuvs::distance::DistanceType::L2Expanded;
* raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out, metric);
* @endcode
*
* @param[in] handle raft::resources
* @param[in] x raft::device_csr_matrix_view
* @param[in] y raft::device_csr_matrix_view
* @param[out] dist raft::device_matrix_view dense matrix
* @param[in] metric distance metric to use
* @param[in] metric_arg metric argument (used for Minkowski distance)
*/
void pairwise_distance(raft::resources const& handle,
raft::device_csr_matrix_view<const double, int, int, int> x,
raft::device_csr_matrix_view<const double, int, int, int> y,
raft::device_matrix_view<double, int, raft::row_major> dist,
cuvs::distance::DistanceType metric,
float metric_arg = 2.0f);

/** @} */ // end group pairwise_distance_runtime

}; // namespace cuvs::distance
48 changes: 35 additions & 13 deletions cpp/src/distance/sparse_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
namespace cuvs {
namespace distance {

void pairwise_distance(raft::resources const& handle,
raft::device_csr_matrix_view<const float, int, int, int> x,
raft::device_csr_matrix_view<const float, int, int, int> y,
raft::device_matrix_view<float, int, raft::row_major> dist,
cuvs::distance::DistanceType metric,
float metric_arg)
template <typename ElementType, typename IndexType>
void pairwise_distance(
raft::resources const& handle,
raft::device_csr_matrix_view<const ElementType, IndexType, IndexType, IndexType> x,
raft::device_csr_matrix_view<const ElementType, IndexType, IndexType, IndexType> y,
raft::device_matrix_view<ElementType, IndexType, raft::row_major> dist,
cuvs::distance::DistanceType metric,
float metric_arg = 2.0f)
{
auto x_structure = x.structure_view();
auto y_structure = y.structure_view();
Expand All @@ -42,22 +44,42 @@ void pairwise_distance(raft::resources const& handle,
"Number of columns in output must be equal to "
"number of rows in Y");

detail::sparse::distances_config_t<int, float> input_config(handle);
detail::sparse::distances_config_t<IndexType, ElementType> input_config(handle);
input_config.a_nrows = x_structure.get_n_rows();
input_config.a_ncols = x_structure.get_n_cols();
input_config.a_nnz = x_structure.get_nnz();
input_config.a_indptr = const_cast<int*>(x_structure.get_indptr().data());
input_config.a_indices = const_cast<int*>(x_structure.get_indices().data());
input_config.a_data = const_cast<float*>(x.get_elements().data());
input_config.a_indptr = const_cast<IndexType*>(x_structure.get_indptr().data());
input_config.a_indices = const_cast<IndexType*>(x_structure.get_indices().data());
input_config.a_data = const_cast<ElementType*>(x.get_elements().data());

input_config.b_nrows = y_structure.get_n_rows();
input_config.b_ncols = y_structure.get_n_cols();
input_config.b_nnz = y_structure.get_nnz();
input_config.b_indptr = const_cast<int*>(y_structure.get_indptr().data());
input_config.b_indices = const_cast<int*>(y_structure.get_indices().data());
input_config.b_data = const_cast<float*>(y.get_elements().data());
input_config.b_indptr = const_cast<IndexType*>(y_structure.get_indptr().data());
input_config.b_indices = const_cast<IndexType*>(y_structure.get_indices().data());
input_config.b_data = const_cast<ElementType*>(y.get_elements().data());

pairwiseDistance(dist.data_handle(), input_config, metric, metric_arg);
}

void pairwise_distance(raft::resources const& handle,
raft::device_csr_matrix_view<const float, int, int, int> x,
raft::device_csr_matrix_view<const float, int, int, int> y,
raft::device_matrix_view<float, int, raft::row_major> dist,
cuvs::distance::DistanceType metric,
float metric_arg)
{
pairwise_distance<float, int>(handle, x, y, dist, metric, metric_arg);
}

void pairwise_distance(raft::resources const& handle,
raft::device_csr_matrix_view<const double, int, int, int> x,
raft::device_csr_matrix_view<const double, int, int, int> y,
raft::device_matrix_view<double, int, raft::row_major> dist,
cuvs::distance::DistanceType metric,
float metric_arg)
{
pairwise_distance<double, int>(handle, x, y, dist, metric, metric_arg);
}
} // namespace distance
} // namespace cuvs

0 comments on commit 8e32018

Please sign in to comment.