Skip to content

Commit

Permalink
Add support for float16 to the python pairwise distance api (#547)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #547
  • Loading branch information
benfred authored Dec 20, 2024
1 parent 89ebf15 commit f48e9aa
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 9 deletions.
13 changes: 9 additions & 4 deletions cpp/src/distance/pairwise_distance_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

namespace {

template <typename T>
template <typename T, typename DistT>
void _pairwise_distance(cuvsResources_t res,
DLManagedTensor* x_tensor,
DLManagedTensor* y_tensor,
Expand All @@ -40,7 +40,7 @@ void _pairwise_distance(cuvsResources_t res,
auto res_ptr = reinterpret_cast<raft::resources*>(res);

using mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<T, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<DistT, int64_t, raft::row_major>;

auto x_mds = cuvs::core::from_dlpack<mdspan_type>(x_tensor);
auto y_mds = cuvs::core::from_dlpack<mdspan_type>(y_tensor);
Expand Down Expand Up @@ -71,9 +71,14 @@ extern "C" cuvsError_t cuvsPairwiseDistance(cuvsResources_t res,
}

if (x_dt.bits == 32) {
_pairwise_distance<float>(res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
_pairwise_distance<float, float>(
res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
} else if (x_dt.bits == 16) {
_pairwise_distance<half, float>(
res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
} else if (x_dt.bits == 64) {
_pairwise_distance<double>(res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
_pairwise_distance<double, double>(
res, x_tensor, y_tensor, distances_tensor, metric, metric_arg);
} else {
RAFT_FAIL("Unsupported DLtensor dtype: %d and bits: %d", x_dt.code, x_dt.bits);
}
Expand Down
7 changes: 5 additions & 2 deletions python/cuvs/cuvs/distance/distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0,
n = y_cai.shape[0]

if out is None:
out = device_ndarray.empty((m, n), dtype=y_cai.dtype)
output_dtype = y_cai.dtype
if np.issubdtype(y_cai.dtype, np.float16):
output_dtype = np.float32
out = device_ndarray.empty((m, n), dtype=output_dtype)
out_cai = wrap_array(out)

x_k = x_cai.shape[1]
Expand All @@ -119,7 +122,7 @@ def pairwise_distance(X, Y, out=None, metric="euclidean", metric_arg=2.0,
y_dt = y_cai.dtype
d_dt = out_cai.dtype

if x_dt != y_dt or x_dt != d_dt:
if x_dt != y_dt:
raise ValueError("Inputs must have the same dtypes")

cdef cydlpack.DLManagedTensor* x_dlpack = \
Expand Down
13 changes: 10 additions & 3 deletions python/cuvs/cuvs/test/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
],
)
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16])
def test_distance(n_rows, n_cols, inplace, metric, dtype):
input1 = np.random.random_sample((n_rows, n_cols))
input1 = np.asarray(input1).astype(dtype)
Expand All @@ -55,7 +55,10 @@ def test_distance(n_rows, n_cols, inplace, metric, dtype):
norm = np.sum(input1, axis=1)
input1 = (input1.T / norm).T

output = np.zeros((n_rows, n_rows), dtype=dtype)
output_dtype = dtype
if np.issubdtype(dtype, np.float16):
output_dtype = np.float32
output = np.zeros((n_rows, n_rows), dtype=output_dtype)

if metric == "inner_product":
expected = np.matmul(input1, input1.T)
Expand All @@ -76,4 +79,8 @@ def test_distance(n_rows, n_cols, inplace, metric, dtype):

actual = output_device.copy_to_host()

assert np.allclose(expected, actual, atol=1e-3, rtol=1e-3)
tol = 1e-3
if np.issubdtype(dtype, np.float16):
tol = 1e-1

assert np.allclose(expected, actual, atol=tol, rtol=tol)

0 comments on commit f48e9aa

Please sign in to comment.