From 7ab03ee77ba0a68c2667439ef821024cee4c23d9 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 18 Nov 2024 13:05:38 -0800 Subject: [PATCH 1/2] Move check_input_array from pylibraft With the changes in https://github.com/rapidsai/raft/pull/2498 we no longer have a pylibraft.neighbors module - but were still using a utility function `_check_input_array` from it in cuvs. Move this over to cuvs to unblock ci --- .../neighbors/brute_force/brute_force.pyx | 2 +- python/cuvs/cuvs/neighbors/cagra/cagra.pyx | 3 +- python/cuvs/cuvs/neighbors/common.py | 36 +++++++++++++++++++ .../cuvs/cuvs/neighbors/filters/filters.pyx | 2 +- python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx | 2 +- .../cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx | 2 +- python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx | 2 +- python/cuvs/cuvs/neighbors/refine.pyx | 2 +- 8 files changed, 44 insertions(+), 7 deletions(-) create mode 100644 python/cuvs/cuvs/neighbors/common.py diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx index 559302ccc..d515c3723 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx @@ -31,9 +31,9 @@ from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array from cuvs.distance import DISTANCE_TYPES +from cuvs.neighbors.common import check_input_array from cuvs.common.c_api cimport cuvsResources_t diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx index 95209dbeb..916d72280 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx @@ -32,7 +32,8 @@ from cuvs.common cimport cydlpack from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array + +from cuvs.neighbors.common import check_input_array from libc.stdint cimport ( int8_t, diff --git a/python/cuvs/cuvs/neighbors/common.py b/python/cuvs/cuvs/neighbors/common.py new file mode 100644 index 000000000..a2419ca82 --- /dev/null +++ b/python/cuvs/cuvs/neighbors/common.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None): + if cai.dtype not in exp_dt: + raise TypeError("dtype %s not supported" % cai.dtype) + + if not cai.c_contiguous: + raise ValueError("Row major input is expected") + + if exp_cols is not None and cai.shape[1] != exp_cols: + raise ValueError( + "Incorrect number of columns, expected {} got {}".format( + exp_cols, cai.shape[1] + ) + ) + + if exp_rows is not None and cai.shape[0] != exp_rows: + raise ValueError( + "Incorrect number of rows, expected {} , got {}".format( + exp_rows, cai.shape[0] + ) + ) diff --git a/python/cuvs/cuvs/neighbors/filters/filters.pyx b/python/cuvs/cuvs/neighbors/filters/filters.pyx index 3a81cb786..d0a587a26 100644 --- a/python/cuvs/cuvs/neighbors/filters/filters.pyx +++ b/python/cuvs/cuvs/neighbors/filters/filters.pyx @@ -20,11 +20,11 @@ import numpy as np from libc.stdint cimport uintptr_t from cuvs.common cimport cydlpack +from cuvs.neighbors.common import check_input_array from .filters cimport BITMAP, NO_FILTER, cuvsFilter from pylibraft.common.cai_wrapper import wrap_array -from pylibraft.neighbors.common import _check_input_array cdef class Prefilter: diff --git a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx index 018fcfef9..720b9489e 100644 --- a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx +++ b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx @@ -21,6 +21,7 @@ from libcpp.string cimport string from cuvs.common.exceptions import check_cuvs from cuvs.common.resources import auto_sync_resources +from cuvs.neighbors.common import check_input_array from cuvs.common cimport cydlpack @@ -36,7 +37,6 @@ import uuid from pylibraft.common import auto_convert_output from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array cdef class SearchParams: diff --git a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx index 25b9b2aee..2444e9ca9 100644 --- a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx @@ -31,9 +31,9 @@ from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array from cuvs.distance import DISTANCE_TYPES +from cuvs.neighbors.common import check_input_array from libc.stdint cimport ( int8_t, diff --git a/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx b/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx index 3add1df75..9c500a311 100644 --- a/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx @@ -31,9 +31,9 @@ from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array from cuvs.distance import DISTANCE_TYPES +from cuvs.neighbors.common import check_input_array from libc.stdint cimport ( int8_t, diff --git a/python/cuvs/cuvs/neighbors/refine.pyx b/python/cuvs/cuvs/neighbors/refine.pyx index 0eccc4108..acff96fb7 100644 --- a/python/cuvs/cuvs/neighbors/refine.pyx +++ b/python/cuvs/cuvs/neighbors/refine.pyx @@ -31,13 +31,13 @@ from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from pylibraft.neighbors.common import _check_input_array from cuvs.distance import DISTANCE_TYPES from cuvs.common.c_api cimport cuvsResources_t from cuvs.common.exceptions import check_cuvs +from cuvs.neighbors.common import check_input_array @auto_sync_resources From 3c9ac1aa8d4562205e21056ee3b96ba435377a7b Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 18 Nov 2024 13:09:59 -0800 Subject: [PATCH 2/2] . --- python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx | 2 +- python/cuvs/cuvs/neighbors/cagra/cagra.pyx | 2 +- python/cuvs/cuvs/neighbors/common.py | 2 +- python/cuvs/cuvs/neighbors/filters/filters.pyx | 2 +- python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx | 2 +- python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx | 2 +- python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx | 2 +- python/cuvs/cuvs/neighbors/refine.pyx | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx index d515c3723..9d1d24eae 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx @@ -33,7 +33,7 @@ from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible from cuvs.distance import DISTANCE_TYPES -from cuvs.neighbors.common import check_input_array +from cuvs.neighbors.common import _check_input_array from cuvs.common.c_api cimport cuvsResources_t diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx index 916d72280..752aef741 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx @@ -33,7 +33,7 @@ from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible -from cuvs.neighbors.common import check_input_array +from cuvs.neighbors.common import _check_input_array from libc.stdint cimport ( int8_t, diff --git a/python/cuvs/cuvs/neighbors/common.py b/python/cuvs/cuvs/neighbors/common.py index a2419ca82..c14b9f8c9 100644 --- a/python/cuvs/cuvs/neighbors/common.py +++ b/python/cuvs/cuvs/neighbors/common.py @@ -14,7 +14,7 @@ # limitations under the License. -def check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None): +def _check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None): if cai.dtype not in exp_dt: raise TypeError("dtype %s not supported" % cai.dtype) diff --git a/python/cuvs/cuvs/neighbors/filters/filters.pyx b/python/cuvs/cuvs/neighbors/filters/filters.pyx index d0a587a26..9bc2a905c 100644 --- a/python/cuvs/cuvs/neighbors/filters/filters.pyx +++ b/python/cuvs/cuvs/neighbors/filters/filters.pyx @@ -20,7 +20,7 @@ import numpy as np from libc.stdint cimport uintptr_t from cuvs.common cimport cydlpack -from cuvs.neighbors.common import check_input_array +from cuvs.neighbors.common import _check_input_array from .filters cimport BITMAP, NO_FILTER, cuvsFilter diff --git a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx index 720b9489e..bcfaf167e 100644 --- a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx +++ b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx @@ -21,7 +21,7 @@ from libcpp.string cimport string from cuvs.common.exceptions import check_cuvs from cuvs.common.resources import auto_sync_resources -from cuvs.neighbors.common import check_input_array +from cuvs.neighbors.common import _check_input_array from cuvs.common cimport cydlpack diff --git a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx index 2444e9ca9..7a169e1a0 100644 --- a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx @@ -33,7 +33,7 @@ from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible from cuvs.distance import DISTANCE_TYPES -from cuvs.neighbors.common import check_input_array +from cuvs.neighbors.common import _check_input_array from libc.stdint cimport ( int8_t, diff --git a/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx b/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx index 9c500a311..531302ee6 100644 --- a/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx @@ -33,7 +33,7 @@ from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible from cuvs.distance import DISTANCE_TYPES -from cuvs.neighbors.common import check_input_array +from cuvs.neighbors.common import _check_input_array from libc.stdint cimport ( int8_t, diff --git a/python/cuvs/cuvs/neighbors/refine.pyx b/python/cuvs/cuvs/neighbors/refine.pyx index acff96fb7..b7aa35dca 100644 --- a/python/cuvs/cuvs/neighbors/refine.pyx +++ b/python/cuvs/cuvs/neighbors/refine.pyx @@ -37,7 +37,7 @@ from cuvs.distance import DISTANCE_TYPES from cuvs.common.c_api cimport cuvsResources_t from cuvs.common.exceptions import check_cuvs -from cuvs.neighbors.common import check_input_array +from cuvs.neighbors.common import _check_input_array @auto_sync_resources