Skip to content

Commit

Permalink
Merge pull request #373 from rapidsai/branch-24.10
Browse files (browse the repository at this point in the history)
Forward-merge branch-24.10 into branch-24.12
  • Loading branch information
GPUtester authored Oct 2, 2024
2 parents 260bfb4 + 2fe2e88 commit 40f4a58
Show file tree
Hide file tree
Showing 25 changed files with 2,059 additions and 43 deletions.
8 changes: 8 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,22 @@ target_compile_options(
add_library(
cuvs SHARED
src/cluster/kmeans_balanced_fit_float.cu
src/cluster/kmeans_fit_mg_float.cu
src/cluster/kmeans_fit_mg_double.cu
src/cluster/kmeans_fit_double.cu
src/cluster/kmeans_fit_float.cu
src/cluster/kmeans_auto_find_k_float.cu
src/cluster/kmeans_fit_predict_double.cu
src/cluster/kmeans_fit_predict_float.cu
src/cluster/kmeans_predict_double.cu
src/cluster/kmeans_predict_float.cu
src/cluster/kmeans_balanced_fit_float.cu
src/cluster/kmeans_balanced_fit_predict_float.cu
src/cluster/kmeans_balanced_predict_float.cu
src/cluster/kmeans_balanced_fit_int8.cu
src/cluster/kmeans_balanced_fit_predict_int8.cu
src/cluster/kmeans_balanced_predict_int8.cu
src/cluster/kmeans_transform_double.cu
src/cluster/kmeans_transform_float.cu
src/cluster/single_linkage_float.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu
Expand Down Expand Up @@ -342,6 +348,8 @@ add_library(
src/distance/detail/pairwise_matrix/dispatch_russel_rao_half_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_rbf.cu
src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu
src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu
src/distance/detail/fused_distance_nn.cu
src/distance/distance.cu
src/distance/pairwise_distance.cu
Expand Down
506 changes: 485 additions & 21 deletions cpp/include/cuvs/cluster/kmeans.hpp

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions cpp/src/cluster/detail/connectivities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#pragma once

#include "../../distance/distance.cuh"
#include "./kmeans_common.cuh"
#include <cuvs/cluster/agglomerative.hpp>
#include <cuvs/distance/distance.hpp>
#include <raft/core/resource/cuda_stream.hpp>
Expand Down Expand Up @@ -153,7 +153,11 @@ void pairwise_distances(const raft::resources& handle,
// TODO: It would ultimately be nice if the MST could accept
// dense inputs directly so we don't need to double the memory
// usage to hand it a sparse array here.
distance::pairwise_distance<value_t, value_idx>(handle, X, X, data, m, m, n, metric);
auto X_view = raft::make_device_matrix_view<const value_t, value_idx>(X, m, n);

cuvs::cluster::kmeans::detail::pairwise_distance_kmeans<value_t, value_idx>(
handle, X_view, X_view, raft::make_device_matrix_view<value_t, value_idx>(data, m, m), metric);

// self-loops get max distance
auto transform_in =
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data));
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void kmeansPlusPlus(raft::resources const& handle,
// Output - pwd [n_trials x n_samples]
auto pwd = distBuffer.view();
cuvs::cluster::kmeans::detail::pairwise_distance_kmeans<DataT, IndexT>(
handle, centroidCandidates.view(), X, pwd, workspace, metric);
handle, centroidCandidates.view(), X, pwd, metric);

// Update nearest cluster distance for each centroid candidate
// Note: pwd and minDistBuf point to the same buffer, which currently holds pairwise distance values.
Expand Down Expand Up @@ -1247,7 +1247,7 @@ void kmeans_transform(raft::resources const& handle,
// calculate pairwise distance between cluster centroids and current batch
// of input dataset
pairwise_distance_kmeans<DataT, IndexT>(
handle, datasetView, centroids, pairwiseDistanceView, workspace, metric);
handle, datasetView, centroids, pairwiseDistanceView, metric);
}
}

Expand Down
31 changes: 19 additions & 12 deletions cpp/src/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ void pairwise_distance_kmeans(raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_matrix_view<DataT, IndexT> pairwiseDistance,
rmm::device_uvector<char>& workspace,
cuvs::distance::DistanceType metric)
{
auto n_samples = X.extent(0);
Expand All @@ -303,15 +302,23 @@ void pairwise_distance_kmeans(raft::resources const& handle,
ASSERT(X.extent(1) == centroids.extent(1),
"# features in dataset and centroids are different (must be same)");

cuvs::distance::pairwise_distance(handle,
X.data_handle(),
centroids.data_handle(),
pairwiseDistance.data_handle(),
n_samples,
n_clusters,
n_features,
workspace,
metric);
if (metric == cuvs::distance::DistanceType::L2Expanded) {
cuvs::distance::distance<cuvs::distance::DistanceType::L2Expanded,
DataT,
DataT,
DataT,
raft::layout_c_contiguous,
IndexT>(handle, X, centroids, pairwiseDistance);
} else if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
cuvs::distance::distance<cuvs::distance::DistanceType::L2SqrtExpanded,
DataT,
DataT,
DataT,
raft::layout_c_contiguous,
IndexT>(handle, X, centroids, pairwiseDistance);
} else {
RAFT_FAIL("kmeans requires L2Expanded or L2SqrtExpanded distance, have %i", metric);
}
}

// shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores
Expand Down Expand Up @@ -461,7 +468,7 @@ void minClusterAndDistanceCompute(
// calculate pairwise distance between current tile of cluster centroids
// and input dataset
pairwise_distance_kmeans<DataT, IndexT>(
handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric);
handle, datasetView, centroidsView, pairwiseDistanceView, metric);

// argmin reduction returning <index, value> pair
// calculates the closest centroid and the distance to the closest
Expand Down Expand Up @@ -591,7 +598,7 @@ void minClusterDistanceCompute(raft::resources const& handle,
// calculate pairwise distance between current tile of cluster centroids
// and input dataset
pairwise_distance_kmeans<DataT, IndexT>(
handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric);
handle, datasetView, centroidsView, pairwiseDistanceView, metric);

raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(),
pairwiseDistanceView.data_handle(),
Expand Down
Loading

0 comments on commit 40f4a58

Please sign in to comment.