diff --git a/python/pylibraft/pylibraft/distance/CMakeLists.txt b/python/pylibraft/pylibraft/distance/CMakeLists.txt index 14f0cc441a..2530e07a98 100644 --- a/python/pylibraft/pylibraft/distance/CMakeLists.txt +++ b/python/pylibraft/pylibraft/distance/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx) +set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx fused_distance_nn.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/distance/__init__.py b/python/pylibraft/pylibraft/distance/__init__.py index f059b5f3dd..d16ab30b2f 100644 --- a/python/pylibraft/pylibraft/distance/__init__.py +++ b/python/pylibraft/pylibraft/distance/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,12 @@ # limitations under the License. # +from .fused_distance_nn import fused_distance_nn_argmin from .fused_l2_nn import fused_l2_nn_argmin from .pairwise_distance import DISTANCE_TYPES, distance as pairwise_distance -__all__ = ["fused_l2_nn_argmin", "pairwise_distance"] +__all__ = [ + "fused_distance_nn_argmin", + "fused_l2_nn_argmin", + "pairwise_distance", +] diff --git a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx new file mode 100755 index 0000000000..f98b9e710a --- /dev/null +++ b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx @@ -0,0 +1,218 @@ +# +# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +from cython.operator cimport dereference as deref +from libc.stdint cimport uintptr_t +from libcpp cimport bool + +from .distance_type cimport DistanceType + +from pylibraft.common import ( + Handle, + auto_convert_output, + cai_wrapper, + device_ndarray, +) +from pylibraft.common.handle import auto_sync_handle + +from pylibraft.common.handle cimport device_resources + + +cdef extern from "raft_runtime/distance/fused_distance_nn.hpp" \ + namespace "raft::runtime::distance" nogil: + + void fused_distance_nn_min_arg( + const device_resources &handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + DistanceType metric, + bool isRowMajor, + float metric_arg) except + + + +DISTANCE_TYPES = { + "l2": DistanceType.L2SqrtExpanded, + "sqeuclidean": DistanceType.L2Expanded, + "euclidean": DistanceType.L2SqrtExpanded, + "l1": DistanceType.L1, + "cityblock": DistanceType.L1, + "inner_product": DistanceType.InnerProduct, + "chebyshev": DistanceType.Linf, + "canberra": DistanceType.Canberra, + "cosine": DistanceType.CosineExpanded, + "lp": DistanceType.LpUnexpanded, + "correlation": DistanceType.CorrelationExpanded, + "jaccard": DistanceType.JaccardExpanded, + "hellinger": DistanceType.HellingerExpanded, + "braycurtis": DistanceType.BrayCurtis, + "jensenshannon": DistanceType.JensenShannon, + "hamming": DistanceType.HammingUnexpanded, + "kl_divergence": DistanceType.KLDivergence, + "minkowski": DistanceType.LpUnexpanded, + "russellrao": DistanceType.RusselRaoExpanded, + "dice": DistanceType.DiceExpanded, +} + +SUPPORTED_DISTANCES = ["euclidean", "l2", "cosine", "sqeuclidean"] + + +@auto_sync_handle +@auto_convert_output +def fused_distance_nn_argmin(X, Y, out=None, sqrt=True, metric="euclidean", + handle=None): + """ + Compute the 1-nearest neighbors between X and Y using the L2 distance + + Parameters + ---------- + + X : CUDA array interface compliant matrix shape (m, k) + Y : CUDA array interface compliant matrix shape (n, k) + out : Writable CUDA array interface matrix shape (m, 1) + metric : string denoting the metric type (default="euclidean") + + {handle_docstring} + + Examples + -------- + To compute the 1-nearest neighbors argmin: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> output = fused_distance_nn_argmin(in1, in2, handle=handle) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + + The output can also be computed in-place on a preallocated + array: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> output = cp.empty((n_samples, 1), dtype=cp.int32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> fused_distance_nn_argmin(in1, in2, out=output, handle=handle) + array(...) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + """ + + x_cai = cai_wrapper(X) + y_cai = cai_wrapper(Y) + + x_dt = x_cai.dtype + y_dt = y_cai.dtype + + m = x_cai.shape[0] + n = y_cai.shape[0] + + if out is None: + output = device_ndarray.empty((m,), dtype="int32") + else: + output = out + + output_cai = cai_wrapper(output) + + x_k = x_cai.shape[1] + y_k = y_cai.shape[1] + + if x_k != y_k: + raise ValueError("Inputs must have same number of columns. " + "a=%s, b=%s" % (x_k, y_k)) + + if metric not in SUPPORTED_DISTANCES: + raise ValueError("metric %s is not supported" % metric) + + cdef DistanceType distance_type = DISTANCE_TYPES[metric] + + x_ptr = x_cai.data + y_ptr = y_cai.data + + d_ptr = output_cai.data + + handle = handle if handle is not None else Handle() + cdef device_resources *h = handle.getHandle() + + d_dt = output_cai.dtype + + x_c_contiguous = x_cai.c_contiguous + y_c_contiguous = y_cai.c_contiguous + + if x_c_contiguous != y_c_contiguous: + raise ValueError("Inputs must have matching strides") + + if not x_c_contiguous: + raise ValueError("Inputs must be C contiguous") + + if x_dt != y_dt: + raise ValueError("Inputs must have the same dtypes") + if d_dt != np.int32: + raise ValueError("Output array must be int32") + # unused arg for now. + metric_arg = 0.0 + if x_dt == np.float32: + fused_distance_nn_min_arg(deref(h), + d_ptr, + x_ptr, + y_ptr, + m, + n, + x_k, + sqrt, + distance_type, + x_c_contiguous, + metric_arg) + else: + raise ValueError("dtype %s not supported" % x_dt) + + return output diff --git a/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py new file mode 100755 index 0000000000..26aa3f5ab2 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from scipy.spatial.distance import cdist + +from pylibraft.common import DeviceResources, device_ndarray +from pylibraft.distance import fused_distance_nn_argmin + + +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("n_rows", [10, 100]) +@pytest.mark.parametrize("n_clusters", [50, 100]) +@pytest.mark.parametrize("n_cols", [128, 31]) +@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize( + "metric", + [ + "euclidean", + "cosine", + "sqeuclidean", + ], +) +def test_fused_distance_nn_minarg( + n_rows, n_cols, n_clusters, dtype, inplace, metric +): + input1 = np.random.random_sample((n_rows, n_cols)) + input1 = np.asarray(input1, order="C").astype(dtype) + + input2 = np.random.random_sample((n_clusters, n_cols)) + input2 = np.asarray(input2, order="C").astype(dtype) + + output = np.zeros((n_rows), dtype="int32") + expected = cdist(input1, input2, metric) + + expected = expected.argmin(axis=1) + + input1_device = device_ndarray(input1) + input2_device = device_ndarray(input2) + output_device = device_ndarray(output) if inplace else None + + is_sqrt = True if metric == "sqeuclidean" else False + handle = DeviceResources() + ret_output = fused_distance_nn_argmin( + input1_device, + input2_device, + output_device, + is_sqrt, + metric, + handle=handle, + ) + handle.sync() + output_device = ret_output if not inplace else output_device + actual = output_device.copy_to_host() + + assert np.allclose(expected, actual, rtol=1e-4)