Skip to content

Commit

Permalink
add errror text to python exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Mar 5, 2024
1 parent 71850c1 commit 8c198e1
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 28 deletions.
2 changes: 1 addition & 1 deletion python/cuvs/cuvs/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# =============================================================================

# Set the list of Cython files to build
set(cython_sources cydlpack.pyx)
set(cython_sources cydlpack.pyx exceptions.pyx)
set(linked_libraries cuvs::cuvs cuvs_c)

# Build all of the Cython targets
Expand Down
1 change: 1 addition & 0 deletions python/cuvs/cuvs/common/c_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ cdef extern from "cuvs/core/c_api.h":
cuvsError_t cuvsResourcesCreate(cuvsResources_t* res)
cuvsError_t cuvsResourcesDestroy(cuvsResources_t res)
cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream)
const char * cuvsGetLastErrorText()
37 changes: 37 additions & 0 deletions python/cuvs/cuvs/common/exceptions.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: language_level=3

from cuvs.common.c_api cimport cuvsError_t, cuvsGetLastErrorText


class CuvsException(Exception):
pass


def get_last_error_text():
""" returns the last error description from the cuvs c-api """
cdef const char* c_err = cuvsGetLastErrorText()
if c_err is NULL:
return
cdef bytes err = c_err
return err.decode("utf8")


def check_cuvs(status: cuvsError_t):
""" Converts a status code into an exception """
if status == cuvsError_t.CUVS_ERROR:
raise CuvsException(get_last_error_text())
37 changes: 11 additions & 26 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ from cuvs.common.c_api cimport (
cuvsResourcesCreate,
)

from cuvs.common.exceptions import check_cuvs


cdef class IndexParams:
"""
Expand Down Expand Up @@ -124,19 +126,13 @@ cdef class Index:
cdef bool trained

def __cinit__(self):
cdef cuvsError_t index_create_status
index_create_status = cuvsCagraIndexCreate(&self.index)
self.trained = False

if index_create_status == cuvsError_t.CUVS_ERROR:
raise RuntimeError("Failed to create index.")
check_cuvs(cuvsCagraIndexCreate(&self.index))

def __dealloc__(self):
cdef cuvsError_t index_destroy_status
if self.index is not NULL:
index_destroy_status = cuvsCagraIndexDestroy(self.index)
if index_destroy_status == cuvsError_t.CUVS_ERROR:
raise Exception("Failed to deallocate index.")
check_cuvs(cuvsCagraIndexDestroy(self.index))

@property
def trained(self):
Expand Down Expand Up @@ -203,9 +199,7 @@ def build_index(IndexParams index_params, dataset, resources=None):
cdef cuvsResources_t res_
cdef cuvsError_t cstat

cstat = cuvsResourcesCreate(&res_)
if cstat == cuvsError_t.CUVS_ERROR:
raise RuntimeError("Error creating Device Reources.")
check_cuvs(cuvsResourcesCreate(&res_))

cdef Index idx = Index()
cdef cuvsError_t build_status
Expand All @@ -214,17 +208,13 @@ def build_index(IndexParams index_params, dataset, resources=None):
cdef cuvsCagraIndexParams* params = index_params.params

with cuda_interruptible():
build_status = cuvsCagraBuild(
check_cuvs(cuvsCagraBuild(
res_,
params,
dataset_dlpack,
idx.index
)

if build_status == cuvsError_t.CUVS_ERROR:
raise RuntimeError("Index failed to build.")
else:
idx.trained = True
))
idx.trained = True

return idx

Expand Down Expand Up @@ -451,9 +441,7 @@ def search(SearchParams search_params,
cdef cuvsResources_t res_
cdef cuvsError_t cstat

cstat = cuvsResourcesCreate(&res_)
if cstat == cuvsError_t.CUVS_ERROR:
raise RuntimeError("Error creating Device Reources.")
check_cuvs(cuvsResourcesCreate(&res_))

# todo(dgd): we can make the check of dtype a parameter of wrap_array
# in RAFT to make this a single call
Expand Down Expand Up @@ -487,16 +475,13 @@ def search(SearchParams search_params,
cydlpack.dlpack_c(distances_cai)

with cuda_interruptible():
search_status = cuvsCagraSearch(
check_cuvs(cuvsCagraSearch(
res_,
params,
index.index,
queries_dlpack,
neighbors_dlpack,
distances_dlpack
)

if search_status == cuvsError_t.CUVS_ERROR:
raise RuntimeError("Search failed.")
))

return (distances, neighbors)
2 changes: 1 addition & 1 deletion rust/cuvs/src/cagra/index_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl fmt::Debug for IndexParams {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// custom debug trait here, default value will show the pointer address
// for the inner params object which isn't that useful.
write!(f, "IndexParams {{ params: {:?} }}", unsafe { *self.0 })
write!(f, "IndexParams({:?})", unsafe { *self.0 })
}
}

Expand Down

0 comments on commit 8c198e1

Please sign in to comment.