Skip to content

Commit

Permalink
fix typo, refactor cython
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Nov 23, 2023
1 parent c60ae05 commit 7f04f97
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 22 deletions.
2 changes: 1 addition & 1 deletion cpp/src/raft_runtime/neighbors/cagra_serialize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace raft::runtime::neighbors::cagra {
raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \
}; \
\
void serialize_to_hswnlib_file(raft::resources const& handle, \
void serialize_to_hnswlib_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
{ \
Expand Down
9 changes: 8 additions & 1 deletion python/pylibraft/pylibraft/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,11 @@

from .refine import refine

__all__ = ["common", "refine", "brute_force", "ivf_flat", "ivf_pq", "cagra"]
__all__ = [
"common",
"refine",
"brute_force",
"ivf_flat",
"ivf_pq",
"cagra",
]
39 changes: 39 additions & 0 deletions python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3

from libc.stdint cimport int8_t, uint8_t, uint32_t
from libcpp cimport bool
from libcpp.string cimport string

cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra


cdef class Index:
cdef readonly bool trained
cdef str active_index_type

cdef class IndexFloat(Index):
cdef c_cagra.index[float, uint32_t] * index

cdef class IndexInt8(Index):
cdef c_cagra.index[int8_t, uint32_t] * index

cdef class IndexUint8(Index):
cdef c_cagra.index[uint8_t, uint32_t] * index
3 changes: 0 additions & 3 deletions python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ cdef class Index:


cdef class IndexFloat(Index):
cdef c_cagra.index[float, uint32_t] * index

def __cinit__(self, handle=None):
if handle is None:
Expand Down Expand Up @@ -214,7 +213,6 @@ cdef class IndexFloat(Index):


cdef class IndexInt8(Index):
cdef c_cagra.index[int8_t, uint32_t] * index

def __cinit__(self, handle=None):
if handle is None:
Expand Down Expand Up @@ -279,7 +277,6 @@ cdef class IndexInt8(Index):


cdef class IndexUint8(Index):
cdef c_cagra.index[uint8_t, uint32_t] * index

def __cinit__(self, handle=None):
if handle is None:
Expand Down
13 changes: 0 additions & 13 deletions python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,3 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \
cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[int8_t, uint32_t]* index) except +

cdef class Index:
cdef readonly bool trained
cdef str active_index_type

cdef class IndexFloat(Index):
pass

cdef class IndexInt8(Index):
pass

cdef class IndexUint8(Index):
pass
14 changes: 10 additions & 4 deletions python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ from libc.stdint cimport int8_t, uint8_t, uint32_t
from libcpp.string cimport string

cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra
from pylibraft.neighbors.cagra.cagra cimport (
Index,
IndexFloat,
IndexInt8,
IndexUint8,
)

from pylibraft.common.handle import auto_sync_handle

Expand All @@ -32,7 +38,7 @@ from pylibraft.common import DeviceResources


@auto_sync_handle
def save(filename, c_cagra.Index index, handle=None):
def save(filename, Index index, handle=None):
"""
Saves the index to a file.
Expand Down Expand Up @@ -73,9 +79,9 @@ def save(filename, c_cagra.Index index, handle=None):

cdef string c_filename = filename.encode('utf-8')

cdef c_cagra.IndexFloat idx_float
cdef c_cagra.IndexInt8 idx_int8
cdef c_cagra.IndexUint8 idx_uint8
cdef IndexFloat idx_float
cdef IndexInt8 idx_int8
cdef IndexUint8 idx_uint8

cdef c_cagra.index[float, uint32_t] * c_index_float
cdef c_cagra.index[int8_t, uint32_t] * c_index_int8
Expand Down

0 comments on commit 7f04f97

Please sign in to comment.