Skip to content

Commit

Permalink
Move check_input_array from pylibraft
Browse files Browse the repository at this point in the history
With the changes in rapidsai/raft#2498 we no longer
have a pylibraft.neighbors module - but were still using a utility function
`_check_input_array` from it in cuvs. Move this over to cuvs to unblock ci
  • Loading branch information
benfred committed Nov 18, 2024
1 parent bb9c669 commit 7ab03ee
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ from cuvs.distance_type cimport cuvsDistanceType
from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray
from pylibraft.common.cai_wrapper import wrap_array
from pylibraft.common.interruptible import cuda_interruptible
from pylibraft.neighbors.common import _check_input_array

from cuvs.distance import DISTANCE_TYPES
from cuvs.neighbors.common import check_input_array

from cuvs.common.c_api cimport cuvsResources_t

Expand Down
3 changes: 2 additions & 1 deletion python/cuvs/cuvs/neighbors/cagra/cagra.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ from cuvs.common cimport cydlpack
from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray
from pylibraft.common.cai_wrapper import wrap_array
from pylibraft.common.interruptible import cuda_interruptible
from pylibraft.neighbors.common import _check_input_array

from cuvs.neighbors.common import check_input_array

from libc.stdint cimport (
int8_t,
Expand Down
36 changes: 36 additions & 0 deletions python/cuvs/cuvs/neighbors/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def check_input_array(cai, exp_dt, exp_rows=None, exp_cols=None):
if cai.dtype not in exp_dt:
raise TypeError("dtype %s not supported" % cai.dtype)

if not cai.c_contiguous:
raise ValueError("Row major input is expected")

if exp_cols is not None and cai.shape[1] != exp_cols:
raise ValueError(
"Incorrect number of columns, expected {} got {}".format(
exp_cols, cai.shape[1]
)
)

if exp_rows is not None and cai.shape[0] != exp_rows:
raise ValueError(
"Incorrect number of rows, expected {} , got {}".format(
exp_rows, cai.shape[0]
)
)
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/filters/filters.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ import numpy as np
from libc.stdint cimport uintptr_t

from cuvs.common cimport cydlpack
from cuvs.neighbors.common import check_input_array

from .filters cimport BITMAP, NO_FILTER, cuvsFilter

from pylibraft.common.cai_wrapper import wrap_array
from pylibraft.neighbors.common import _check_input_array


cdef class Prefilter:
Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ from libcpp.string cimport string

from cuvs.common.exceptions import check_cuvs
from cuvs.common.resources import auto_sync_resources
from cuvs.neighbors.common import check_input_array

from cuvs.common cimport cydlpack

Expand All @@ -36,7 +37,6 @@ import uuid
from pylibraft.common import auto_convert_output
from pylibraft.common.cai_wrapper import wrap_array
from pylibraft.common.interruptible import cuda_interruptible
from pylibraft.neighbors.common import _check_input_array


cdef class SearchParams:
Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ from cuvs.distance_type cimport cuvsDistanceType
from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray
from pylibraft.common.cai_wrapper import wrap_array
from pylibraft.common.interruptible import cuda_interruptible
from pylibraft.neighbors.common import _check_input_array

from cuvs.distance import DISTANCE_TYPES
from cuvs.neighbors.common import check_input_array

from libc.stdint cimport (
int8_t,
Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/ivf_pq/ivf_pq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ from cuvs.distance_type cimport cuvsDistanceType
from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray
from pylibraft.common.cai_wrapper import wrap_array
from pylibraft.common.interruptible import cuda_interruptible
from pylibraft.neighbors.common import _check_input_array

from cuvs.distance import DISTANCE_TYPES
from cuvs.neighbors.common import check_input_array

from libc.stdint cimport (
int8_t,
Expand Down
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/neighbors/refine.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ from cuvs.distance_type cimport cuvsDistanceType
from pylibraft.common import auto_convert_output, device_ndarray
from pylibraft.common.cai_wrapper import wrap_array
from pylibraft.common.interruptible import cuda_interruptible
from pylibraft.neighbors.common import _check_input_array

from cuvs.distance import DISTANCE_TYPES

from cuvs.common.c_api cimport cuvsResources_t

from cuvs.common.exceptions import check_cuvs
from cuvs.neighbors.common import check_input_array


@auto_sync_resources
Expand Down

0 comments on commit 7ab03ee

Please sign in to comment.