-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The current python code was creating a cuvsResources_t to pass to the CAGRA apis on the build/search code. This both ignored the passed in `resources` object, and also leaked memory since it wasn't calling cuvsResourcesDestroy. Fix by exposing the cuvsResources_t object to python.
- Loading branch information
Showing
7 changed files
with
154 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# cython: language_level=3 | ||
|
||
from cuvs.common.c_api cimport cuvsResources_t | ||
|
||
|
||
cdef class Resources: | ||
cdef cuvsResources_t c_obj |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# cython: language_level=3 | ||
|
||
import functools | ||
|
||
from cuda.ccudart cimport cudaStream_t | ||
|
||
from cuvs.common.c_api cimport ( | ||
cuvsResources_t, | ||
cuvsResourcesCreate, | ||
cuvsResourcesDestroy, | ||
cuvsStreamSet, | ||
cuvsStreamSync, | ||
) | ||
|
||
from cuvs.common.exceptions import check_cuvs | ||
|
||
|
||
cdef class Resources: | ||
""" | ||
Resources is a lightweight python wrapper around the corresponding | ||
C++ class of resources exposed by RAFT's C++ interface. Refer to | ||
the header file raft/core/resources.hpp for interface level | ||
details of this struct. | ||
Parameters | ||
---------- | ||
stream : Optional stream to use for ordering CUDA instructions | ||
Examples | ||
-------- | ||
Basic usage: | ||
>>> from cuvs.common import Resources | ||
>>> handle = Resources() | ||
>>> | ||
>>> # call algos here | ||
>>> | ||
>>> # final sync of all work launched in the stream of this handle | ||
>>> handle.sync() | ||
Using a cuPy stream with cuVS Resources: | ||
>>> import cupy | ||
>>> from cuvs.common import Resources | ||
>>> | ||
>>> cupy_stream = cupy.cuda.Stream() | ||
>>> handle = Resources(stream=cupy_stream.ptr) | ||
""" | ||
|
||
def __cinit__(self, stream=None): | ||
check_cuvs(cuvsResourcesCreate(&self.c_obj)) | ||
if stream: | ||
check_cuvs(cuvsStreamSet(self.c_obj, <cudaStream_t>stream)) | ||
|
||
def sync(self): | ||
check_cuvs(cuvsStreamSync(self.c_obj)) | ||
|
||
def get_c_obj(self): | ||
""" | ||
Return the pointer to the underlying c_obj as a size_t | ||
""" | ||
return <size_t> self.c_obj | ||
|
||
def __dealloc__(self): | ||
check_cuvs(cuvsResourcesDestroy(self.c_obj)) | ||
|
||
|
||
_resources_param_string = """ | ||
resources : Optional cuVS Resource handle for reusing CUDA resources. | ||
If Resources aren't supplied, CUDA resources will be | ||
allocated inside this function and synchronized before the | ||
function exits. If resources are supplied, you will need to | ||
explicitly synchronize yourself by calling `resources.sync()` | ||
before accessing the output. | ||
""".strip() | ||
|
||
|
||
def auto_sync_resources(f): | ||
"""Decorator to automatically call sync on a cuVS Resources object when | ||
it isn't passed to a function. | ||
When a resources=None is passed to the wrapped function, this decorator | ||
will automatically create a default resources for the function, and | ||
call sync on that resources when the function exits. | ||
This will also insert the appropriate docstring for the resources parameter | ||
""" | ||
|
||
@functools.wraps(f) | ||
def wrapper(*args, resources=None, **kwargs): | ||
sync_resources = resources is None | ||
resources = resources if resources is not None else Resources() | ||
|
||
ret_value = f(*args, resources=resources, **kwargs) | ||
|
||
if sync_resources: | ||
resources.sync() | ||
|
||
return ret_value | ||
|
||
wrapper.__doc__ = wrapper.__doc__.format( | ||
resources_docstring=_resources_param_string | ||
) | ||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters