Skip to content

Commit

Permalink
Add multigpu kmeans fit function (#348)
Browse files Browse the repository at this point in the history
Changes to support using kmeans clustering inside of cuml, so we can transition cuml off of the RAFT kmeans code

* Add a multigpu kmeans fit function
* Adds instantiations for kmeans on int64_t indices, which unfortunately also requires int64_t indices for the pairwise distance functions
* Add support for `double` precision kmeans

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #348
  • Loading branch information
benfred authored Oct 2, 2024
1 parent ce01a0b commit 2fe2e88
Show file tree
Hide file tree
Showing 25 changed files with 2,059 additions and 43 deletions.
8 changes: 8 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,22 @@ target_compile_options(
add_library(
cuvs SHARED
src/cluster/kmeans_balanced_fit_float.cu
src/cluster/kmeans_fit_mg_float.cu
src/cluster/kmeans_fit_mg_double.cu
src/cluster/kmeans_fit_double.cu
src/cluster/kmeans_fit_float.cu
src/cluster/kmeans_auto_find_k_float.cu
src/cluster/kmeans_fit_predict_double.cu
src/cluster/kmeans_fit_predict_float.cu
src/cluster/kmeans_predict_double.cu
src/cluster/kmeans_predict_float.cu
src/cluster/kmeans_balanced_fit_float.cu
src/cluster/kmeans_balanced_fit_predict_float.cu
src/cluster/kmeans_balanced_predict_float.cu
src/cluster/kmeans_balanced_fit_int8.cu
src/cluster/kmeans_balanced_fit_predict_int8.cu
src/cluster/kmeans_balanced_predict_int8.cu
src/cluster/kmeans_transform_double.cu
src/cluster/kmeans_transform_float.cu
src/cluster/single_linkage_float.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu
Expand Down Expand Up @@ -342,6 +348,8 @@ add_library(
src/distance/detail/pairwise_matrix/dispatch_russel_rao_half_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_rbf.cu
src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu
src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu
src/distance/detail/fused_distance_nn.cu
src/distance/distance.cu
src/distance/pairwise_distance.cu
Expand Down
506 changes: 485 additions & 21 deletions cpp/include/cuvs/cluster/kmeans.hpp

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions cpp/src/cluster/detail/connectivities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#pragma once

#include "../../distance/distance.cuh"
#include "./kmeans_common.cuh"
#include <cuvs/cluster/agglomerative.hpp>
#include <cuvs/distance/distance.hpp>
#include <raft/core/resource/cuda_stream.hpp>
Expand Down Expand Up @@ -153,7 +153,11 @@ void pairwise_distances(const raft::resources& handle,
// TODO: It would ultimately be nice if the MST could accept
// dense inputs directly so we don't need to double the memory
// usage to hand it a sparse array here.
distance::pairwise_distance<value_t, value_idx>(handle, X, X, data, m, m, n, metric);
auto X_view = raft::make_device_matrix_view<const value_t, value_idx>(X, m, n);

cuvs::cluster::kmeans::detail::pairwise_distance_kmeans<value_t, value_idx>(
handle, X_view, X_view, raft::make_device_matrix_view<value_t, value_idx>(data, m, m), metric);

// self-loops get max distance
auto transform_in =
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data));
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void kmeansPlusPlus(raft::resources const& handle,
// Output - pwd [n_trials x n_samples]
auto pwd = distBuffer.view();
cuvs::cluster::kmeans::detail::pairwise_distance_kmeans<DataT, IndexT>(
handle, centroidCandidates.view(), X, pwd, workspace, metric);
handle, centroidCandidates.view(), X, pwd, metric);

// Update nearest cluster distance for each centroid candidate
// Note pwd and minDistBuf points to same buffer which currently holds pairwise distance values.
Expand Down Expand Up @@ -1247,7 +1247,7 @@ void kmeans_transform(raft::resources const& handle,
// calculate pairwise distance between cluster centroids and current batch
// of input dataset
pairwise_distance_kmeans<DataT, IndexT>(
handle, datasetView, centroids, pairwiseDistanceView, workspace, metric);
handle, datasetView, centroids, pairwiseDistanceView, metric);
}
}

Expand Down
31 changes: 19 additions & 12 deletions cpp/src/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ void pairwise_distance_kmeans(raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_matrix_view<DataT, IndexT> pairwiseDistance,
rmm::device_uvector<char>& workspace,
cuvs::distance::DistanceType metric)
{
auto n_samples = X.extent(0);
Expand All @@ -303,15 +302,23 @@ void pairwise_distance_kmeans(raft::resources const& handle,
ASSERT(X.extent(1) == centroids.extent(1),
"# features in dataset and centroids are different (must be same)");

cuvs::distance::pairwise_distance(handle,
X.data_handle(),
centroids.data_handle(),
pairwiseDistance.data_handle(),
n_samples,
n_clusters,
n_features,
workspace,
metric);
if (metric == cuvs::distance::DistanceType::L2Expanded) {
cuvs::distance::distance<cuvs::distance::DistanceType::L2Expanded,
DataT,
DataT,
DataT,
raft::layout_c_contiguous,
IndexT>(handle, X, centroids, pairwiseDistance);
} else if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
cuvs::distance::distance<cuvs::distance::DistanceType::L2SqrtExpanded,
DataT,
DataT,
DataT,
raft::layout_c_contiguous,
IndexT>(handle, X, centroids, pairwiseDistance);
} else {
RAFT_FAIL("kmeans requires L2Expanded or L2SqrtExpanded distance, have %i", metric);
}
}

// shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores
Expand Down Expand Up @@ -461,7 +468,7 @@ void minClusterAndDistanceCompute(
// calculate pairwise distance between current tile of cluster centroids
// and input dataset
pairwise_distance_kmeans<DataT, IndexT>(
handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric);
handle, datasetView, centroidsView, pairwiseDistanceView, metric);

// argmin reduction returning <index, value> pair
// calculates the closest centroid and the distance to the closest
Expand Down Expand Up @@ -591,7 +598,7 @@ void minClusterDistanceCompute(raft::resources const& handle,
// calculate pairwise distance between current tile of cluster centroids
// and input dataset
pairwise_distance_kmeans<DataT, IndexT>(
handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric);
handle, datasetView, centroidsView, pairwiseDistanceView, metric);

raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(),
pairwiseDistanceView.data_handle(),
Expand Down
Loading

0 comments on commit 2fe2e88

Please sign in to comment.