From f48e9aab593232b72f74fd79ad256ed51b997b43 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 19 Dec 2024 19:39:29 -0800 Subject: [PATCH] Add support for float16 to the python pairwise distance api (#547) Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/547 --- cpp/src/distance/pairwise_distance_c.cpp | 13 +++++++++---- python/cuvs/cuvs/distance/distance.pyx | 7 +++++-- python/cuvs/cuvs/test/test_distance.py | 13 ++++++++++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/cpp/src/distance/pairwise_distance_c.cpp b/cpp/src/distance/pairwise_distance_c.cpp index d457198a2..061adaa2c 100644 --- a/cpp/src/distance/pairwise_distance_c.cpp +++ b/cpp/src/distance/pairwise_distance_c.cpp @@ -29,7 +29,7 @@ namespace { -template +template void _pairwise_distance(cuvsResources_t res, DLManagedTensor* x_tensor, DLManagedTensor* y_tensor, @@ -40,7 +40,7 @@ void _pairwise_distance(cuvsResources_t res, auto res_ptr = reinterpret_cast(res); using mdspan_type = raft::device_matrix_view; - using distances_mdspan_type = raft::device_matrix_view; + using distances_mdspan_type = raft::device_matrix_view; auto x_mds = cuvs::core::from_dlpack(x_tensor); auto y_mds = cuvs::core::from_dlpack(y_tensor); @@ -71,9 +71,14 @@ extern "C" cuvsError_t cuvsPairwiseDistance(cuvsResources_t res, } if (x_dt.bits == 32) { - _pairwise_distance(res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); + _pairwise_distance( + res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); + } else if (x_dt.bits == 16) { + _pairwise_distance( + res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); } else if (x_dt.bits == 64) { - _pairwise_distance(res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); + _pairwise_distance( + res, x_tensor, y_tensor, distances_tensor, metric, metric_arg); } else { RAFT_FAIL("Unsupported DLtensor dtype: %d and bits: %d", x_dt.code, x_dt.bits); } diff --git a/python/cuvs/cuvs/distance/distance.pyx b/python/cuvs/cuvs/distance/distance.pyx index eb34366e4..187532bfe 100644 --- a/python/cuvs/cuvs/distance/distance.pyx +++ b/python/cuvs/cuvs/distance/distance.pyx @@ -100,7 +100,10 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0, n = y_cai.shape[0] if out is None: - out = device_ndarray.empty((m, n), dtype=y_cai.dtype) + output_dtype = y_cai.dtype + if np.issubdtype(y_cai.dtype, np.float16): + output_dtype = np.float32 + out = device_ndarray.empty((m, n), dtype=output_dtype) out_cai = wrap_array(out) x_k = x_cai.shape[1] @@ -119,7 +122,7 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0, y_dt = y_cai.dtype d_dt = out_cai.dtype - if x_dt != y_dt or x_dt != d_dt: + if x_dt != y_dt: raise ValueError("Inputs must have the same dtypes") cdef cydlpack.DLManagedTensor* x_dlpack = \ diff --git a/python/cuvs/cuvs/test/test_distance.py b/python/cuvs/cuvs/test/test_distance.py index 681217fc8..f466c2743 100644 --- a/python/cuvs/cuvs/test/test_distance.py +++ b/python/cuvs/cuvs/test/test_distance.py @@ -40,7 +40,7 @@ ], ) @pytest.mark.parametrize("inplace", [True, False]) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16]) def test_distance(n_rows, n_cols, inplace, metric, dtype): input1 = np.random.random_sample((n_rows, n_cols)) input1 = np.asarray(input1).astype(dtype) @@ -55,7 +55,10 @@ def test_distance(n_rows, n_cols, inplace, metric, dtype): norm = np.sum(input1, axis=1) input1 = (input1.T / norm).T - output = np.zeros((n_rows, n_rows), dtype=dtype) + output_dtype = dtype + if np.issubdtype(dtype, np.float16): + output_dtype = np.float32 + output = np.zeros((n_rows, n_rows), dtype=output_dtype) if metric == "inner_product": expected = np.matmul(input1, input1.T) @@ -76,4 +79,8 @@ def test_distance(n_rows, n_cols, inplace, metric, dtype): actual = output_device.copy_to_host() - assert np.allclose(expected, actual, atol=1e-3, rtol=1e-3) + tol = 1e-3 + if np.issubdtype(dtype, np.float16): + tol = 1e-1 + + assert np.allclose(expected, actual, atol=tol, rtol=tol)