Skip to content

Commit

Permalink
Add serialization API to brute-force (#461)
Browse files Browse the repository at this point in the history
I noticed it was missing while switching Milvus to cuVS

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #461
  • Loading branch information
lowener authored Nov 25, 2024
1 parent 96d98b1 commit e1359e1
Show file tree
Hide file tree
Showing 22 changed files with 767 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ compile_commands.json
.clangd/

# serialized ann indexes
brute_force_index
cagra_index
ivf_flat_index
ivf_pq_index
Expand Down
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ if(BUILD_SHARED_LIBS)
src/distance/pairwise_distance.cu
src/distance/sparse_distance.cu
src/neighbors/brute_force.cu
src/neighbors/brute_force_serialize.cu
src/neighbors/cagra_build_float.cu
src/neighbors/cagra_build_half.cu
src/neighbors/cagra_build_int8.cu
Expand Down
60 changes: 60 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,66 @@ cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
* @}
*/

/**
* @defgroup bruteforce_c_serialize BRUTEFORCE C-API serialize functions
* @{
*/
/**
* Save the index to file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.c}
* #include <cuvs/neighbors/brute_force.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // create an index with `cuvsBruteForceBuild`
* cuvsBruteForceSerialize(res, "/path/to/index", index);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the file name for saving the index
* @param[in] index BRUTEFORCE index
*
*/
cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res,
const char* filename,
cuvsBruteForceIndex_t index);

/**
* Load index from file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.c}
* #include <cuvs/neighbors/brute_force.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // Deserialize an index previously built with `cuvsBruteForceBuild`
* cuvsBruteForceIndex_t index;
* cuvsBruteForceIndexCreate(&index);
* cuvsBruteForceDeserialize(res, "/path/to/index", index);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the name of the file that stores the index
* @param[out] index BRUTEFORCE index loaded from disk
*/
cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res,
const char* filename,
cuvsBruteForceIndex_t index);

/**
* @}
*/
#ifdef __cplusplus
}
#endif
243 changes: 243 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ struct index : cuvs::neighbors::index {
index& operator=(index&&) = default;
~index() = default;

/**
* @brief Construct an empty index.
*
* Constructs an empty index. This index will either need to be trained with `build`
* or loaded from a saved copy with `deserialize`
*/
index(raft::resources const& handle);

/** Construct a brute force index from dataset
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
Expand Down Expand Up @@ -479,4 +487,239 @@ void search(raft::resources const& handle,
/**
* @}
*/

/**
* @defgroup bruteforce_cpp_index_serialize Bruteforce index serialize functions
* @{
*/
/**
* Save the index to file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, filename, index);
* @endcode
*
* @note This overload handles indexes whose data element type is `half`.
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index brute force index
* @param[in] include_dataset whether to include the dataset in the serialized
* output
*/
void serialize(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::brute_force::index<half, float>& index,
bool include_dataset = true);
/**
* Save the index to file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, filename, index);
* @endcode
*
* @note This overload handles indexes whose data element type is `float`.
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index brute force index
* @param[in] include_dataset whether to include the dataset in the serialized
* output
*
*/
void serialize(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::brute_force::index<float, float>& index,
bool include_dataset = true);

/**
* Write the index to an output stream
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cuvs::neighbors::brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, os, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index brute force index
* @param[in] include_dataset Whether or not to write out the dataset to the output stream.
*/
void serialize(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::brute_force::index<half, float>& index,
bool include_dataset = true);

/**
* Write the index to an output stream
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cuvs::neighbors::brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, os, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index brute force index
* @param[in] include_dataset Whether or not to write out the dataset to the output stream.
*/
void serialize(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::brute_force::index<float, float>& index,
bool include_dataset = true);

/**
* Load index from file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = half; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, filename, &index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<half, float>* index);
/**
* Load index from file.
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = float; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, filename, &index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<float, float>* index);
/**
* Load index from input stream
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = half; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, is, &index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] is input stream
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<half, float>* index);
/**
* Load index from input stream
* The serialization format can be subject to changes, therefore loading
* an index saved with a previous version of cuvs is not guaranteed
* to work.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = float; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, is, &index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] is input stream
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<float, float>* index);
/**
* @}
*/

} // namespace cuvs::neighbors::brute_force
15 changes: 15 additions & 0 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@
#include <raft/core/copy.hpp>

namespace cuvs::neighbors::brute_force {

/**
 * @brief Construct an empty placeholder index.
 *
 * Initializes every member to a default value: the L2Expanded metric, a 0 x 0
 * dataset matrix, no norms, and a metric argument of 0.
 *
 * @param[in] res raft resources handle, passed to `raft::make_device_matrix`
 *                to create the empty dataset matrix
 */
template <typename T, typename DistT>
index<T, DistT>::index(raft::resources const& res)
// This constructor only creates a temporary index for use by the deserialization
// API. All of the values set here get replaced with the loaded values, which
// aren't necessarily known ahead of time before deserialization.
// TODO: do we even need a handle here - could just construct one?
: cuvs::neighbors::index(),
// placeholder metric; overwritten when an index is loaded
metric_(cuvs::distance::DistanceType::L2Expanded),
// empty (0 rows x 0 columns) dataset matrix
dataset_(raft::make_device_matrix<T, int64_t>(res, 0, 0)),
// no norms are stored for the empty index
norms_(std::nullopt),
metric_arg_(0)
{
}

template <typename T, typename DistT>
index<T, DistT>::index(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset,
Expand Down
Loading

0 comments on commit e1359e1

Please sign in to comment.