diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index e18e82df0..78648235f 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -88,7 +88,7 @@ jobs: with: build_type: pull-request enable_check_symbols: true - symbol_exclusions: (void (thrust::|cub::)|raft_cutlass) + symbol_exclusions: (void (thrust::|cub::)) conda-python-build: needs: conda-cpp-build secrets: inherit diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 5f60c0a34..27dc99a11 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -23,7 +23,7 @@ jobs: date: ${{ inputs.date }} sha: ${{ inputs.sha }} enable_check_symbols: true - symbol_exclusions: (void (thrust::|cub::)|raft_cutlass) + symbol_exclusions: (void (thrust::|cub::)) conda-cpp-tests: secrets: inherit uses: rapidsai/shared-workflows/.github/workflows/conda-cpp-tests.yaml@branch-24.12 diff --git a/.gitignore b/.gitignore index 97eab287d..da6eb07f6 100644 --- a/.gitignore +++ b/.gitignore @@ -75,6 +75,7 @@ compile_commands.json .clangd/ # serialized ann indexes +brute_force_index cagra_index ivf_flat_index ivf_pq_index diff --git a/README.md b/README.md index 572e8d098..23759f598 100755 --- a/README.md +++ b/README.md @@ -242,7 +242,7 @@ If you are interested in contributing to the cuVS library, please read our [Cont For the interested reader, many of the accelerated implementations in cuVS are also based on research papers which can provide a lot more background. We also ask you to please cite the corresponding algorithms by referencing them in your own research. - [CAGRA: Highly Parallel Graph Construction and Approximate Nearest Neighbor Search](https://arxiv.org/abs/2308.15136) -- [Top-K Algorithms on GPU: A Comprehensive Study and New Methods](https://dl.acm.org/doi/10.1145/3581784.3607062>) +- [Top-K Algorithms on GPU: A Comprehensive Study and New Methods](https://dl.acm.org/doi/10.1145/3581784.3607062) - [Fast K-NN Graph Construction by GPU Based NN-Descent](https://dl.acm.org/doi/abs/10.1145/3459637.3482344?casa_token=O_nan1B1F5cAAAAA:QHWDEhh0wmd6UUTLY9_Gv6c3XI-5DXM9mXVaUXOYeStlpxTPmV3nKvABRfoivZAaQ3n8FWyrkWw>) - [cuSLINK: Single-linkage Agglomerative Clustering on the GPU](https://arxiv.org/abs/2306.16354) - [GPU Semiring Primitives for Sparse Neighborhood Methods](https://arxiv.org/abs/2104.06357) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 32093776c..eb2e7c7a4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -371,6 +371,7 @@ if(BUILD_SHARED_LIBS) src/distance/pairwise_distance.cu src/distance/sparse_distance.cu src/neighbors/brute_force.cu + src/neighbors/brute_force_serialize.cu src/neighbors/cagra_build_float.cu src/neighbors/cagra_build_half.cu src/neighbors/cagra_build_int8.cu diff --git a/cpp/include/cuvs/neighbors/brute_force.h b/cpp/include/cuvs/neighbors/brute_force.h index c9e172f62..33b92f11b 100644 --- a/cpp/include/cuvs/neighbors/brute_force.h +++ b/cpp/include/cuvs/neighbors/brute_force.h @@ -166,6 +166,66 @@ cuvsError_t cuvsBruteForceSearch(cuvsResources_t res, * @} */ +/** + * @defgroup bruteforce_c_serialize BRUTEFORCE C-API serialize functions + * @{ + */ +/** + * Save the index to file. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.c} + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // create an index with `cuvsBruteforceBuild` + * cuvsBruteForceSerialize(res, "/path/to/index", index); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] filename the file name for saving the index + * @param[in] index BRUTEFORCE index + * + */ +cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index); + +/** + * Load index from file. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.c} + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // Deserialize an index previously built with `cuvsBruteforceBuild` + * cuvsBruteForceIndex_t index; + * cuvsBruteForceIndexCreate(&index); + * cuvsBruteForceDeserialize(res, "/path/to/index", index); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] filename the name of the file that stores the index + * @param[out] index BRUTEFORCE index loaded disk + */ +cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index); + +/** + * @} + */ #ifdef __cplusplus } #endif diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index ba67797ee..d040e03db 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -48,6 +48,14 @@ struct index : cuvs::neighbors::index { index& operator=(index&&) = default; ~index() = default; + /** + * @brief Construct an empty index. + * + * Constructs an empty index. This index will either need to be trained with `build` + * or loaded from a saved copy with `deserialize` + */ + index(raft::resources const& handle); + /** Construct a brute force index from dataset * * Constructs a brute force index from a dataset. This lets us precompute norms for @@ -479,4 +487,239 @@ void search(raft::resources const& handle, /** * @} */ + +/** + * @defgroup bruteforce_cpp_index_serialize Bruteforce index serialize functions + * @{ + */ +/** + * Save the index to file. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = brute_force::build(...);` + * cuvs::neighbors::brute_force::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index brute force index + * @param[in] include_dataset whether to include the dataset in the serialized + * output + */ +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::brute_force::index& index, + bool include_dataset = true); +/** + * Save the index to file. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = brute_force::build(...);` + * cuvs::neighbors::brute_force::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index brute force index + * @param[in] include_dataset whether to include the dataset in the serialized + * output + * + */ +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::brute_force::index& index, + bool include_dataset = true); + +/** + * Write the index to an output stream + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cuvs::neighbors::brute_force::build(...);` + * cuvs::neighbors::brute_force::serialize(handle, os, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index brute force index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ +void serialize(raft::resources const& handle, + std::ostream& os, + const cuvs::neighbors::brute_force::index& index, + bool include_dataset = true); + +/** + * Write the index to an output stream + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cuvs::neighbors::brute_force::build(...);` + * cuvs::neighbors::brute_force::serialize(handle, os, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index brute force index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ +void serialize(raft::resources const& handle, + std::ostream& os, + const cuvs::neighbors::brute_force::index& index, + bool include_dataset = true); + +/** + * Load index from file. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = half; // data element type + * brute_force::index index(handle); + * cuvs::neighbors::brute_force::deserialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index brute force index + * + */ +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::brute_force::index* index); +/** + * Load index from file. + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * brute_force::index index(handle); + * cuvs::neighbors::brute_force::deserialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index brute force index + * + */ +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::brute_force::index* index); +/** + * Load index from input stream + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = half; // data element type + * brute_force::index index(handle); + * cuvs::neighbors::brute_force::deserialize(handle, is, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] is input stream + * @param[out] index brute force index + * + */ +void deserialize(raft::resources const& handle, + std::istream& is, + cuvs::neighbors::brute_force::index* index); +/** + * Load index from input stream + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using T = float; // data element type + * brute_force::index index(handle); + * cuvs::neighbors::brute_force::deserialize(handle, is, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] is input stream + * @param[out] index brute force index + * + */ +void deserialize(raft::resources const& handle, + std::istream& is, + cuvs::neighbors::brute_force::index* index); +/** + * @} + */ + } // namespace cuvs::neighbors::brute_force diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index e48050756..5ceb3010e 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -363,7 +363,7 @@ struct index : cuvs::neighbors::index { * // search K nearest neighbours * auto neighbors = raft::make_device_matrix(res, n_queries, k); * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); * @endcode * In the above example, we have passed a host dataset to build. The returned index will own a * device copy of the dataset and the knn_graph. In contrast, if we pass the dataset as a @@ -530,7 +530,7 @@ struct index : cuvs::neighbors::index { * // search K nearest neighbours * auto neighbors = raft::make_device_matrix(res, n_queries, k); * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); * @endcode * * @param[in] res @@ -567,7 +567,7 @@ auto build(raft::resources const& res, * // search K nearest neighbours * auto neighbors = raft::make_device_matrix(res, n_queries, k); * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); * @endcode * * @param[in] res @@ -604,7 +604,7 @@ auto build(raft::resources const& res, * // search K nearest neighbours * auto neighbors = raft::make_device_matrix(res, n_queries, k); * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); * @endcode * * @param[in] res @@ -640,7 +640,7 @@ auto build(raft::resources const& res, * // search K nearest neighbours * auto neighbors = raft::make_device_matrix(res, n_queries, k); * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); * @endcode * * @param[in] res @@ -676,7 +676,7 @@ auto build(raft::resources const& res, * // search K nearest neighbours * auto neighbors = raft::make_device_matrix(res, n_queries, k); * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); * @endcode * * @param[in] res @@ -713,7 +713,7 @@ auto build(raft::resources const& res, * // search K nearest neighbours * auto neighbors = raft::make_device_matrix(res, n_queries, k); * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); * @endcode * * @param[in] res @@ -750,7 +750,7 @@ auto build(raft::resources const& res, * // search K nearest neighbours * auto neighbors = raft::make_device_matrix(res, n_queries, k); * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); * @endcode * * @param[in] res @@ -787,7 +787,7 @@ auto build(raft::resources const& res, * // search K nearest neighbours * auto neighbors = raft::make_device_matrix(res, n_queries, k); * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); + * cagra::search(res, search_params, index, queries, neighbors.view(), distances.view()); * @endcode * * @param[in] res diff --git a/cpp/src/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/src/distance/detail/pairwise_distance_epilogue_elementwise.h index f9955334d..f4a7feaba 100644 --- a/cpp/src/distance/detail/pairwise_distance_epilogue_elementwise.h +++ b/cpp/src/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -61,6 +61,7 @@ class PairwiseDistanceEpilogueElementwise { using ElementT = ElementT_; static int const kElementsPerAccess = ElementsPerAccess; static int const kCount = kElementsPerAccess; + static bool const kIsSingleSource = true; using DistanceOp = DistanceOp_; using FinalOp = FinalOp_; diff --git a/cpp/src/neighbors/brute_force.cu b/cpp/src/neighbors/brute_force.cu index b0f87e9ac..d534676e3 100644 --- a/cpp/src/neighbors/brute_force.cu +++ b/cpp/src/neighbors/brute_force.cu @@ -21,6 +21,21 @@ #include namespace cuvs::neighbors::brute_force { + +template +index::index(raft::resources const& res) + // this constructor is just for a temporary index, for use in the deserialization + // api. all the parameters here will get replaced with loaded values - that aren't + // necessarily known ahead of time before deserialization. + // TODO: do we even need a handle here - could just construct one? + : cuvs::neighbors::index(), + metric_(cuvs::distance::DistanceType::L2Expanded), + dataset_(raft::make_device_matrix(res, 0, 0)), + norms_(std::nullopt), + metric_arg_(0) +{ +} + template index::index(raft::resources const& res, raft::host_matrix_view dataset, diff --git a/cpp/src/neighbors/brute_force_c.cpp b/cpp/src/neighbors/brute_force_c.cpp index eda79aa31..f1a8c995d 100644 --- a/cpp/src/neighbors/brute_force_c.cpp +++ b/cpp/src/neighbors/brute_force_c.cpp @@ -17,10 +17,12 @@ #include #include +#include #include #include #include +#include #include #include @@ -91,6 +93,22 @@ void _search(cuvsResources_t res, } } +template +void _serialize(cuvsResources_t res, const char* filename, cuvsBruteForceIndex index) +{ + auto res_ptr = reinterpret_cast(res); + auto index_ptr = reinterpret_cast*>(index.addr); + cuvs::neighbors::brute_force::serialize(*res_ptr, std::string(filename), *index_ptr); +} + +template +void* _deserialize(cuvsResources_t res, const char* filename) +{ + auto res_ptr = reinterpret_cast(res); + auto index = new cuvs::neighbors::brute_force::index(*res_ptr); + cuvs::neighbors::brute_force::deserialize(*res_ptr, std::string(filename), index); + return index; +} } // namespace extern "C" cuvsError_t cuvsBruteForceIndexCreate(cuvsBruteForceIndex_t* index) @@ -129,7 +147,7 @@ extern "C" cuvsError_t cuvsBruteForceBuild(cuvsResources_t res, if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { index->addr = reinterpret_cast(_build(res, dataset_tensor, metric, metric_arg)); - index->dtype.code = kDLFloat; + index->dtype = dataset.dtype; } else { RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", dataset.dtype.code, @@ -174,3 +192,38 @@ extern "C" cuvsError_t cuvsBruteForceSearch(cuvsResources_t res, } }); } + +extern "C" cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index) +{ + return cuvs::core::translate_exceptions([=] { + // read the numpy dtype from the beginning of the file + std::ifstream is(filename, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", filename); } + char dtype_string[4]; + is.read(dtype_string, 4); + auto dtype = raft::detail::numpy_serializer::parse_descr(std::string(dtype_string, 4)); + + index->dtype.bits = dtype.itemsize * 8; + if (dtype.kind == 'f' && dtype.itemsize == 4) { + index->dtype.code = kDLFloat; + index->addr = reinterpret_cast(_deserialize(res, filename)); + } else { + RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits); + } + }); +} + +extern "C" cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index) +{ + return cuvs::core::translate_exceptions([=] { + if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { + _serialize(res, filename, *index); + } else { + RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits); + } + }); +} \ No newline at end of file diff --git a/cpp/src/neighbors/brute_force_serialize.cu b/cpp/src/neighbors/brute_force_serialize.cu new file mode 100644 index 000000000..1b5b5111e --- /dev/null +++ b/cpp/src/neighbors/brute_force_serialize.cu @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +namespace cuvs::neighbors::brute_force { + +int constexpr serialization_version = 0; + +template +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset = true) +{ + RAFT_LOG_DEBUG( + "Saving brute force index, size %zu, dim %u", static_cast(index.size()), index.dim()); + + auto dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + dtype_string.resize(4); + os << dtype_string; + + raft::serialize_scalar(handle, os, serialization_version); + raft::serialize_scalar(handle, os, index.size()); + raft::serialize_scalar(handle, os, index.dim()); + raft::serialize_scalar(handle, os, index.metric()); + raft::serialize_scalar(handle, os, index.metric_arg()); + raft::serialize_scalar(handle, os, include_dataset); + if (include_dataset) { raft::serialize_mdspan(handle, os, index.dataset()); } + auto has_norms = index.has_norms(); + raft::serialize_scalar(handle, os, has_norms); + if (has_norms) { raft::serialize_mdspan(handle, os, index.norms()); } + raft::resource::sync_stream(handle); +} + +void serialize(raft::resources const& handle, + const std::string& filename, + const index& index, + bool include_dataset) +{ + auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; + RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); + serialize(handle, os, index, include_dataset); +} + +void serialize(raft::resources const& handle, + const std::string& filename, + const index& index, + bool include_dataset) +{ + auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; + RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); + serialize(handle, os, index, include_dataset); +} + +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset) +{ + serialize(handle, os, index, include_dataset); +} + +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset) +{ + serialize(handle, os, index, include_dataset); +} + +template +auto deserialize(raft::resources const& handle, std::istream& is) +{ + auto dtype_string = std::array{}; + is.read(dtype_string.data(), 4); + + auto ver = raft::deserialize_scalar(handle, is); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + std::int64_t rows = raft::deserialize_scalar(handle, is); + std::int64_t dim = raft::deserialize_scalar(handle, is); + auto metric = raft::deserialize_scalar(handle, is); + auto metric_arg = raft::deserialize_scalar(handle, is); + + auto dataset_storage = raft::make_host_matrix(std::int64_t{}, std::int64_t{}); + auto include_dataset = raft::deserialize_scalar(handle, is); + if (include_dataset) { + dataset_storage = raft::make_host_matrix(rows, dim); + raft::deserialize_mdspan(handle, is, dataset_storage.view()); + } + + auto has_norms = raft::deserialize_scalar(handle, is); + auto norms_storage = has_norms ? std::optional{raft::make_host_vector(rows)} + : std::optional>{}; + // TODO(wphicks): Use mdbuffer here when available + auto norms_storage_dev = + has_norms ? std::optional{raft::make_device_vector(handle, rows)} + : std::optional>{}; + if (has_norms) { + raft::deserialize_mdspan(handle, is, norms_storage->view()); + raft::copy(handle, norms_storage_dev->view(), norms_storage->view()); + } + + auto result = index(handle, + raft::make_const_mdspan(dataset_storage.view()), + std::move(norms_storage_dev), + metric, + metric_arg); + raft::resource::sync_stream(handle); + + return result; +} + +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::brute_force::index* index) +{ + auto is = std::ifstream{filename, std::ios::in | std::ios::binary}; + RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str()); + + *index = deserialize(handle, is); +} + +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::brute_force::index* index) +{ + auto is = std::ifstream{filename, std::ios::in | std::ios::binary}; + RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str()); + + *index = deserialize(handle, is); +} + +void deserialize(raft::resources const& handle, + std::istream& is, + cuvs::neighbors::brute_force::index* index) +{ + *index = deserialize(handle, is); +} + +void deserialize(raft::resources const& handle, + std::istream& is, + cuvs::neighbors::brute_force::index* index) +{ + *index = deserialize(handle, is); +} + +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh index c2afa4e8b..03d6e820c 100644 --- a/cpp/test/neighbors/ann_brute_force.cuh +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -114,12 +114,28 @@ class AnnBruteForceTest : public ::testing::TestWithParam(handle_); + brute_force::deserialize(handle_, std::string{"brute_force_index"}, &index_loaded); + brute_force::search(handle_, - idx, + index_loaded, search_queries_view, indices_out_view, dists_out_view, cuvs::neighbors::filtering::none_sample_filter{}); + raft::resource::sync_stream(handle_); + + ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(indices_naive_dev.data(), + indices_bruteforce_dev.data(), + distances_naive_dev.data(), + distances_bruteforce_dev.data(), + ps.num_queries, + ps.k, + 0.001f, + stream_, + true)); } } diff --git a/docs/source/c_api/neighbors_bruteforce_c.rst b/docs/source/c_api/neighbors_bruteforce_c.rst index af0356eee..a12175209 100644 --- a/docs/source/c_api/neighbors_bruteforce_c.rst +++ b/docs/source/c_api/neighbors_bruteforce_c.rst @@ -32,3 +32,11 @@ Index search :project: cuvs :members: :content-only: + +Index serialize +--------------- + +.. doxygengroup:: bruteforce_c_index_serialize + :project: cuvs + :members: + :content-only: diff --git a/docs/source/c_api/neighbors_hnsw_c.rst b/docs/source/c_api/neighbors_hnsw_c.rst index 4d83cd3e3..988e5b6f3 100644 --- a/docs/source/c_api/neighbors_hnsw_c.rst +++ b/docs/source/c_api/neighbors_hnsw_c.rst @@ -29,13 +29,13 @@ Index Index search ------------ -.. doxygengroup:: cagra_c_index_search +.. doxygengroup:: hnsw_c_index_search :project: cuvs :members: :content-only: Index serialize ------------- +--------------- .. doxygengroup:: hnsw_c_index_serialize :project: cuvs diff --git a/docs/source/c_api/neighbors_ivf_flat_c.rst b/docs/source/c_api/neighbors_ivf_flat_c.rst index 9e1ccc0d1..1254d70ef 100644 --- a/docs/source/c_api/neighbors_ivf_flat_c.rst +++ b/docs/source/c_api/neighbors_ivf_flat_c.rst @@ -48,3 +48,11 @@ Index search :project: cuvs :members: :content-only: + +Index serialize +--------------- + +.. doxygengroup:: ivf_flat_c_index_serialize + :project: cuvs + :members: + :content-only: diff --git a/docs/source/c_api/neighbors_ivf_pq_c.rst b/docs/source/c_api/neighbors_ivf_pq_c.rst index 070719609..260057b8c 100644 --- a/docs/source/c_api/neighbors_ivf_pq_c.rst +++ b/docs/source/c_api/neighbors_ivf_pq_c.rst @@ -48,3 +48,11 @@ Index search :project: cuvs :members: :content-only: + +Index serialize +--------------- + +.. doxygengroup:: ivf_pq_c_index_serialize + :project: cuvs + :members: + :content-only: diff --git a/docs/source/cpp_api/neighbors_bruteforce.rst b/docs/source/cpp_api/neighbors_bruteforce.rst index 3adcb01c5..f75e26b3c 100644 --- a/docs/source/cpp_api/neighbors_bruteforce.rst +++ b/docs/source/cpp_api/neighbors_bruteforce.rst @@ -34,3 +34,11 @@ Index search :project: cuvs :members: :content-only: + +Index serialize +--------------- + +.. doxygengroup:: bruteforce_cpp_index_serialize + :project: cuvs + :members: + :content-only: diff --git a/docs/source/python_api/neighbors_brute_force.rst b/docs/source/python_api/neighbors_brute_force.rst index 5fdc3658f..d756a6c80 100644 --- a/docs/source/python_api/neighbors_brute_force.rst +++ b/docs/source/python_api/neighbors_brute_force.rst @@ -20,3 +20,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.brute_force.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.brute_force.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.brute_force.load diff --git a/docs/source/python_api/neighbors_cagra.rst b/docs/source/python_api/neighbors_cagra.rst index 09b2e2694..e7155efb8 100644 --- a/docs/source/python_api/neighbors_cagra.rst +++ b/docs/source/python_api/neighbors_cagra.rst @@ -34,3 +34,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.cagra.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.cagra.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.cagra.load diff --git a/docs/source/python_api/neighbors_hnsw.rst b/docs/source/python_api/neighbors_hnsw.rst index 9922805b3..64fe5493b 100644 --- a/docs/source/python_api/neighbors_hnsw.rst +++ b/docs/source/python_api/neighbors_hnsw.rst @@ -28,3 +28,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.hnsw.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.hnsw.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.hnsw.load diff --git a/docs/source/python_api/neighbors_ivf_flat.rst b/docs/source/python_api/neighbors_ivf_flat.rst index 5514e5e43..f2c21e68a 100644 --- a/docs/source/python_api/neighbors_ivf_flat.rst +++ b/docs/source/python_api/neighbors_ivf_flat.rst @@ -32,3 +32,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.ivf_flat.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.ivf_flat.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.ivf_flat.load diff --git a/docs/source/python_api/neighbors_ivf_pq.rst b/docs/source/python_api/neighbors_ivf_pq.rst index e3625ba67..57668fbc3 100644 --- a/docs/source/python_api/neighbors_ivf_pq.rst +++ b/docs/source/python_api/neighbors_ivf_pq.rst @@ -32,3 +32,13 @@ Index search ############ .. autofunction:: cuvs.neighbors.ivf_pq.search + +Index save +########## + +.. autofunction:: cuvs.neighbors.ivf_pq.save + +Index load +########## + +.. autofunction:: cuvs.neighbors.ivf_pq.load diff --git a/python/cuvs/cuvs/neighbors/brute_force/__init__.py b/python/cuvs/cuvs/neighbors/brute_force/__init__.py index b88c4b464..6aa0e4bb2 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/__init__.py +++ b/python/cuvs/cuvs/neighbors/brute_force/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. -from .brute_force import Index, build, search +from .brute_force import Index, build, load, save, search -__all__ = ["Index", "build", "search"] +__all__ = ["Index", "build", "search", "save", "load"] diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd index 183827916..f1fc14ba7 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd @@ -47,3 +47,11 @@ cdef extern from "cuvs/neighbors/brute_force.h" nogil: DLManagedTensor* neighbors, DLManagedTensor* distances, cuvsFilter filter) except + + + cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res, + const char * filename, + cuvsBruteForceIndex_t index) except + + + cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res, + const char * filename, + cuvsBruteForceIndex_t index) except + diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx index 9d1d24eae..9d43bfb29 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx @@ -24,6 +24,7 @@ from cuvs.common.resources import auto_sync_resources from cython.operator cimport dereference as deref from libc.stdint cimport uint32_t from libcpp cimport bool +from libcpp.string cimport string from cuvs.common cimport cydlpack from cuvs.distance_type cimport cuvsDistanceType @@ -256,3 +257,88 @@ def search(Index index, )) return (distances, neighbors) + + +@auto_sync_resources +def save(filename, Index index, bool include_dataset=True, resources=None): + """ + Saves the index to a file. + + The serialization format can be subject to changes, therefore loading + an index saved with a previous version of cuvs is not guaranteed + to work. + + Parameters + ---------- + filename : string + Name of the file. + index : Index + Trained Brute Force index. + {resources_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from cuvs.neighbors import brute_force + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> index = brute_force.build(dataset) + >>> # Serialize and deserialize the brute_force index built + >>> brute_force.save("my_index.bin", index) + >>> index_loaded = brute_force.load("my_index.bin") + """ + cdef string c_filename = filename.encode('utf-8') + cdef cuvsResources_t res = resources.get_c_obj() + check_cuvs(cuvsBruteForceSerialize(res, + c_filename.c_str(), + index.index)) + + +@auto_sync_resources +def load(filename, resources=None): + """ + Loads index from file. + + The serialization format can be subject to changes, therefore loading + an index saved with a previous version of cuvs is not guaranteed + to work. + + + Parameters + ---------- + filename : string + Name of the file. + {resources_docstring} + + Returns + ------- + index : Index + + Examples + -------- + >>> import cupy as cp + >>> from cuvs.neighbors import brute_force + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> index = brute_force.build(dataset) + >>> # Serialize and deserialize the brute_force index built + >>> brute_force.save("my_index.bin", index) + >>> index_loaded = brute_force.load("my_index.bin") + """ + cdef Index idx = Index() + cdef cuvsResources_t res = resources.get_c_obj() + cdef string c_filename = filename.encode('utf-8') + + check_cuvs(cuvsBruteForceDeserialize( + res, + c_filename.c_str(), + idx.index + )) + idx.trained = True + return idx diff --git a/python/cuvs/cuvs/test/test_serialization.py b/python/cuvs/cuvs/test/test_serialization.py index 4ffccf121..1f4a54e87 100644 --- a/python/cuvs/cuvs/test/test_serialization.py +++ b/python/cuvs/cuvs/test/test_serialization.py @@ -17,7 +17,7 @@ import pytest from pylibraft.common import device_ndarray -from cuvs.neighbors import cagra, ivf_flat, ivf_pq +from cuvs.neighbors import brute_force, cagra, ivf_flat, ivf_pq from cuvs.test.ann_utils import generate_data @@ -35,6 +35,10 @@ def test_save_load_ivf_pq(): run_save_load(ivf_pq, np.float32) +def test_save_load_brute_force(): + run_save_load(brute_force, np.float32) + + def run_save_load(ann_module, dtype): n_rows = 10000 n_cols = 50 @@ -43,8 +47,11 @@ def run_save_load(ann_module, dtype): dataset = generate_data((n_rows, n_cols), dtype) dataset_device = device_ndarray(dataset) - build_params = ann_module.IndexParams() - index = ann_module.build(build_params, dataset_device) + if ann_module == brute_force: + index = ann_module.build(dataset_device) + else: + build_params = ann_module.IndexParams() + index = ann_module.build(build_params, dataset_device) assert index.trained filename = "my_index.bin" @@ -54,20 +61,29 @@ def run_save_load(ann_module, dtype): queries = generate_data((n_queries, n_cols), dtype) queries_device = device_ndarray(queries) - search_params = ann_module.SearchParams() k = 10 - - distance_dev, neighbors_dev = ann_module.search( - search_params, index, queries_device, k - ) + if ann_module == brute_force: + distance_dev, neighbors_dev = ann_module.search( + index, queries_device, k + ) + else: + search_params = ann_module.SearchParams() + distance_dev, neighbors_dev = ann_module.search( + search_params, index, queries_device, k + ) neighbors = neighbors_dev.copy_to_host() dist = distance_dev.copy_to_host() del index - distance_dev, neighbors_dev = ann_module.search( - search_params, loaded_index, queries_device, k - ) + if ann_module == brute_force: + distance_dev, neighbors_dev = ann_module.search( + loaded_index, queries_device, k + ) + else: + distance_dev, neighbors_dev = ann_module.search( + search_params, loaded_index, queries_device, k + ) neighbors2 = neighbors_dev.copy_to_host() dist2 = distance_dev.copy_to_host()