Skip to content

Commit

Permalink
Add dataset serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 8, 2024
1 parent 833b50f commit 53a5c14
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 35 deletions.
3 changes: 2 additions & 1 deletion cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ inline dtype_t get_numpy_dtype()
}

#if defined(_RAFT_HAS_CUDA)
template <typename T, typename std::enable_if_t<std::is_same_v<T, half>, bool> = true>
template <typename T,
typename std::enable_if_t<std::is_same_v<std::remove_cv_t<T>, half>, bool> = true>
inline dtype_t get_numpy_dtype()
{
return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'e', sizeof(T)};
Expand Down
14 changes: 13 additions & 1 deletion cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ struct index : ann::index {
}

/** Dataset [size, dim] */
[[nodiscard]] inline auto dataset() const noexcept
[[nodiscard]] inline auto dataset_view() const noexcept
-> device_matrix_view<const T, int64_t, layout_stride>
{
auto p = dynamic_cast<strided_dataset<T, int64_t>*>(dataset_.get());
Expand All @@ -171,6 +171,11 @@ struct index : ann::index {
return make_device_strided_matrix_view<const T, int64_t>(nullptr, 0, d, d);
}

/**
 * Dataset object held by the index, returned as a const reference to the
 * `neighbors::dataset<int64_t>` base type.
 *
 * NOTE(review): `dataset_` is dereferenced unconditionally, so this assumes it is never null —
 * confirm against the constructors.
 */
[[nodiscard]] inline auto dataset() const noexcept -> const neighbors::dataset<int64_t>&
{
return *dataset_;
}

/** neighborhood graph [size, graph-degree] */
[[nodiscard]] inline auto graph() const noexcept
-> device_matrix_view<const IdxT, int64_t, row_major>
Expand Down Expand Up @@ -304,6 +309,13 @@ struct index : ann::index {
upcast_dataset_ptr(std::make_unique<DatasetT>(std::move(dataset))).swap(dataset_);
}

/**
 * Replace the dataset with one handed over by owning pointer.
 *
 * This overload is enabled only when `DatasetT` derives from `neighbors::dataset<int64_t>`.
 * The pointer is upcast via `upcast_dataset_ptr` and swapped into `dataset_`; `res` is not used
 * in the body.
 */
template <typename DatasetT>
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
upcast_dataset_ptr(std::move(dataset)).swap(dataset_);
}

/**
* Replace the graph with a new graph.
*
Expand Down
14 changes: 7 additions & 7 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ void search_main(raft::resources const& res,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(index.dataset().extent(0)),
static_cast<size_t>(index.dataset().extent(1)));
static_cast<size_t>(index.dataset_view().extent(0)),
static_cast<size_t>(index.dataset_view().extent(1)));
RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n",
static_cast<size_t>(queries.extent(0)),
static_cast<size_t>(queries.extent(1)));
Expand Down Expand Up @@ -151,11 +151,11 @@ void search_main(raft::resources const& res,
: nullptr;
uint32_t* _num_executed_iterations = nullptr;

auto dataset_internal =
make_device_strided_matrix_view<const T, int64_t, row_major>(index.dataset().data_handle(),
index.dataset().extent(0),
index.dataset().extent(1),
index.dataset().stride(0));
auto dataset_internal = make_device_strided_matrix_view<const T, int64_t, row_major>(
index.dataset_view().data_handle(),
index.dataset_view().extent(0),
index.dataset_view().extent(1),
index.dataset_view().stride(0));
auto graph_internal = raft::make_device_matrix_view<const internal_IdxT, int64_t, row_major>(
reinterpret_cast<const internal_IdxT*>(index.graph().data_handle()),
index.graph().extent(0),
Expand Down
37 changes: 11 additions & 26 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/serialize.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/detail/dataset_serialize.hpp>

#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -65,26 +66,14 @@ void serialize(raft::resources const& res,
serialize_scalar(res, os, index_.metric());
serialize_mdspan(res, os, index_.graph());

include_dataset &= (index_.dataset().extent(0) > 0);
include_dataset &= (index_.dataset().n_rows() > 0);

serialize_scalar(res, os, include_dataset);
if (include_dataset) {
RAFT_LOG_INFO("Saving CAGRA index with dataset");
auto dataset = index_.dataset();
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
neighbors::detail::serialize(res, os, index_.dataset());
} else {
RAFT_LOG_INFO("Saving CAGRA index WITHOUT dataset");
RAFT_LOG_DEBUG("Saving CAGRA index WITHOUT dataset");
}
}

Expand Down Expand Up @@ -158,7 +147,7 @@ void serialize_to_hnswlib(raft::resources const& res,
std::size_t efConstruction = 500;
os.write(reinterpret_cast<char*>(&efConstruction), sizeof(std::size_t));

auto dataset = index_.dataset();
auto dataset = index_.dataset_view();
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
Expand Down Expand Up @@ -256,19 +245,15 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
deserialize_mdspan(res, is, graph.view());

index<T, IdxT> idx(res, metric);
idx.update_graph(res, raft::make_const_mdspan(graph.view()));
bool has_dataset = deserialize_scalar<bool>(res, is);
if (has_dataset) {
auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
deserialize_mdspan(res, is, dataset.view());
return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
} else {
// create a new index with no dataset - the user must supply via update_dataset themselves
// later (this avoids allocating GPU memory in the meantime)
index<T, IdxT> idx(res, metric);
idx.update_graph(res, raft::make_const_mdspan(graph.view()));
return idx;
std::unique_ptr<dataset<int64_t>> dataset;
neighbors::detail::deserialize(res, is, dataset);
idx.update_dataset(res, std::move(dataset));
}
return idx;
}

template <typename T, typename IdxT>
Expand Down
238 changes: 238 additions & 0 deletions cpp/include/raft/neighbors/detail/dataset_serialize.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "../dataset.hpp"

#include <raft/core/host_mdarray.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>

#include <cuda_fp16.h>

#include <fstream>
#include <memory>

namespace raft::neighbors::detail {

/**
 * Integer tag written to the stream ahead of a dataset payload. The generic `deserialize` reads it
 * back to decide which concrete dataset type to reconstruct, so the numeric values below are part
 * of the serialized format.
 */
using dataset_instance_tag = uint32_t;
/** Tag written for `empty_dataset`. */
constexpr dataset_instance_tag kSerializeEmptyDataset = 1;
/** Tag written for `strided_dataset`; it is followed by a `cudaDataType_t` element-type code. */
constexpr dataset_instance_tag kSerializeStridedDataset = 2;
/** Tag written for `vpq_dataset`; it is followed by a `cudaDataType_t` element-type code. */
constexpr dataset_instance_tag kSerializeVPQDataset = 3;

/**
 * @brief Write the payload of an empty dataset, which consists of its `suggested_dim` only.
 *
 * @param[in] res raft resources, forwarded to `serialize_scalar`
 * @param[in] os output stream
 * @param[in] dataset the dataset to save
 */
template <typename IdxT>
void serialize(const raft::resources& res, std::ostream& os, const empty_dataset<IdxT>& dataset)
{
  const auto& suggested_dim = dataset.suggested_dim;
  serialize_scalar(res, os, suggested_dim);
}

/**
 * @brief Write a strided dataset to a stream.
 *
 * Emits `n_rows()`, `dim()` and `stride()` as scalars, followed by the data copied into a host
 * array that has the same extents as the dataset view (i.e. without the per-row padding).
 *
 * @param[in] res raft resources; provides the CUDA stream used for the copy
 * @param[in] os output stream
 * @param[in] dataset the dataset to save
 */
template <typename DataT, typename IdxT>
void serialize(const raft::resources& res,
std::ostream& os,
const strided_dataset<DataT, IdxT>& dataset)
{
// Header: the matching deserialize overload reads these three scalars back in the same order.
serialize_scalar(res, os, dataset.n_rows());
serialize_scalar(res, os, dataset.dim());
serialize_scalar(res, os, dataset.stride());
// Remove padding before saving the dataset
auto src = dataset.view();
auto dst = make_host_mdarray<DataT, IdxT>(src.extents());
// 2D copy: the source pitch is the row stride, the destination pitch is the bare row width.
RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(),
sizeof(DataT) * dst.extent(1),
src.data_handle(),
sizeof(DataT) * src.stride(0),
sizeof(DataT) * dst.extent(1),
src.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
// The copy was issued asynchronously; wait for it before `dst` is read below.
resource::sync_stream(res);
serialize_mdspan(res, os, dst.view());
}

/**
 * @brief Write a VPQ-compressed dataset to a stream.
 *
 * Layout: six scalars (`n_rows`, `dim`, `vq_n_centers`, `pq_n_centers`, `pq_len`,
 * `encoded_row_length`) followed by three arrays (`vq_code_book`, `pq_code_book`, `data`).
 *
 * @param[in] res raft resources
 * @param[in] os output stream
 * @param[in] dataset the dataset to save
 */
template <typename MathT, typename IdxT>
void serialize(const raft::resources& res,
               std::ostream& os,
               const vpq_dataset<MathT, IdxT>& dataset)
{
  // Scalar header; the matching deserialize overload reads these back in the same order.
  serialize_scalar(res, os, dataset.n_rows());
  serialize_scalar(res, os, dataset.dim());
  serialize_scalar(res, os, dataset.vq_n_centers());
  serialize_scalar(res, os, dataset.pq_n_centers());
  serialize_scalar(res, os, dataset.pq_len());
  serialize_scalar(res, os, dataset.encoded_row_length());

  // Array payloads, written through read-only views.
  auto vq_view = make_const_mdspan(dataset.vq_code_book.view());
  serialize_mdspan(res, os, vq_view);

  auto pq_view = make_const_mdspan(dataset.pq_code_book.view());
  serialize_mdspan(res, os, pq_view);

  auto codes_view = make_const_mdspan(dataset.data.view());
  serialize_mdspan(res, os, codes_view);
}

/**
 * @brief Write a dataset of unknown concrete type to a stream.
 *
 * The dynamic type is probed with `dynamic_cast` against each supported dataset type. On the
 * first match an instance tag is written (plus a `cudaDataType_t` element-type code for strided
 * and VPQ datasets), followed by the payload produced by the type-specific `serialize` overload.
 * Fails via `RAFT_FAIL` if no supported type matches.
 *
 * @param[in] res raft resources
 * @param[in] os output stream
 * @param[in] dataset the dataset to save
 */
template <typename IdxT>
void serialize(const raft::resources& res, std::ostream& os, const dataset<IdxT>& dataset)
{
  // --- empty dataset: tag only, no element type ---
  if (const auto* typed = dynamic_cast<const empty_dataset<IdxT>*>(&dataset)) {
    serialize_scalar(res, os, kSerializeEmptyDataset);
    serialize(res, os, *typed);
    return;
  }

  // --- strided datasets: tag + element type ---
  if (const auto* typed = dynamic_cast<const strided_dataset<float, IdxT>*>(&dataset)) {
    serialize_scalar(res, os, kSerializeStridedDataset);
    serialize_scalar(res, os, CUDA_R_32F);
    serialize(res, os, *typed);
    return;
  }
  if (const auto* typed = dynamic_cast<const strided_dataset<half, IdxT>*>(&dataset)) {
    serialize_scalar(res, os, kSerializeStridedDataset);
    serialize_scalar(res, os, CUDA_R_16F);
    serialize(res, os, *typed);
    return;
  }
  if (const auto* typed = dynamic_cast<const strided_dataset<int8_t, IdxT>*>(&dataset)) {
    serialize_scalar(res, os, kSerializeStridedDataset);
    serialize_scalar(res, os, CUDA_R_8I);
    serialize(res, os, *typed);
    return;
  }
  if (const auto* typed = dynamic_cast<const strided_dataset<uint8_t, IdxT>*>(&dataset)) {
    serialize_scalar(res, os, kSerializeStridedDataset);
    serialize_scalar(res, os, CUDA_R_8U);
    serialize(res, os, *typed);
    return;
  }

  // --- VPQ datasets: tag + element type of the codebooks ---
  if (const auto* typed = dynamic_cast<const vpq_dataset<float, IdxT>*>(&dataset)) {
    serialize_scalar(res, os, kSerializeVPQDataset);
    serialize_scalar(res, os, CUDA_R_32F);
    serialize(res, os, *typed);
    return;
  }
  if (const auto* typed = dynamic_cast<const vpq_dataset<half, IdxT>*>(&dataset)) {
    serialize_scalar(res, os, kSerializeVPQDataset);
    serialize_scalar(res, os, CUDA_R_16F);
    serialize(res, os, *typed);
    return;
  }

  RAFT_FAIL("unsupported dataset type.");
}

/**
 * @brief Read an empty dataset from a stream.
 *
 * Reads one `uint32_t` (the suggested dimensionality) and replaces `out` with a freshly
 * constructed `empty_dataset`.
 *
 * @param[in] res raft resources
 * @param[in] is input stream
 * @param[out] out receives the new dataset; whatever it owned before is released
 */
template <typename IdxT>
void deserialize(raft::resources const& res,
                 std::istream& is,
                 std::unique_ptr<empty_dataset<IdxT>>& out)
{
  const auto dim = deserialize_scalar<uint32_t>(res, is);
  out            = std::make_unique<empty_dataset<IdxT>>(dim);
}

/**
 * @brief Read a strided dataset from a stream.
 *
 * Reads `n_rows` (IdxT), `dim` and `stride` (both uint32_t), then the densely packed data, and
 * copies the rows into a newly allocated device array whose row pitch is `stride` elements.
 *
 * @param[in] res raft resources; provides the CUDA stream used for the copy
 * @param[in] is input stream
 * @param[out] out receives the new owning dataset; whatever it owned before is released
 */
template <typename DataT, typename IdxT>
void deserialize(raft::resources const& res,
                 std::istream& is,
                 std::unique_ptr<strided_dataset<DataT, IdxT>>& out)
{
  using out_mdarray_type          = device_mdarray<DataT, matrix_extent<IdxT>, layout_stride>;
  using out_layout_type           = typename out_mdarray_type::layout_type;
  using out_container_policy_type = typename out_mdarray_type::container_policy_type;
  using out_owning_type = owning_dataset<DataT, IdxT, out_layout_type, out_container_policy_type>;

  // Header scalars, in the order the serialize overload wrote them.
  auto n_rows = deserialize_scalar<IdxT>(res, is);
  auto dim    = deserialize_scalar<uint32_t>(res, is);
  auto stride = deserialize_scalar<uint32_t>(res, is);

  auto out_extents = make_extents<IdxT>(n_rows, dim);
  // The explicit cast keeps the braced initializer well-formed for any IdxT.
  auto out_layout =
    make_strided_layout(out_extents, std::array<IdxT, 2>{static_cast<IdxT>(stride), 1});
  auto out_array = out_mdarray_type{res, out_layout, out_container_policy_type{}};

  // The stream stores the rows densely packed; stage them on the host first.
  auto host_array = make_host_mdarray<DataT, IdxT>(out_extents);
  deserialize_mdspan(res, is, host_array.view());

  // NOTE(review): only `dim` elements per row are written; when stride > dim the padding is left
  // as allocated - confirm consumers never read it or that the allocation is pre-initialized.
  RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(),
                                  sizeof(DataT) * stride,
                                  host_array.data_handle(),
                                  sizeof(DataT) * dim,
                                  sizeof(DataT) * dim,
                                  n_rows,
                                  cudaMemcpyDefault,
                                  resource::get_cuda_stream(res)));
  // `host_array` is destroyed when this function returns, so the asynchronous copy must have
  // finished reading from it by then (mirrors the sync in the serialize counterpart).
  resource::sync_stream(res);

  std::unique_ptr<strided_dataset<DataT, IdxT>>{
    new out_owning_type{std::move(out_array), out_layout}}
    .swap(out);
}

/**
 * @brief Read a VPQ-compressed dataset from a stream.
 *
 * Reads six scalars (`n_rows`, `dim`, `vq_n_centers`, `pq_n_centers`, `pq_len`,
 * `encoded_row_length`), allocates the two codebooks and the encoded data with the corresponding
 * shapes, fills them from the stream, and hands the result over through `out`.
 *
 * @param[in] res raft resources
 * @param[in] is input stream
 * @param[out] out receives the new dataset; whatever it owned before is released
 */
template <typename MathT, typename IdxT>
void deserialize(raft::resources const& res,
                 std::istream& is,
                 std::unique_ptr<vpq_dataset<MathT, IdxT>>& out)
{
  // Scalar header, in the order the serialize overload wrote it.
  const auto n_rows             = deserialize_scalar<IdxT>(res, is);
  const auto dim                = deserialize_scalar<uint32_t>(res, is);
  const auto vq_n_centers       = deserialize_scalar<uint32_t>(res, is);
  const auto pq_n_centers       = deserialize_scalar<uint32_t>(res, is);
  const auto pq_len             = deserialize_scalar<uint32_t>(res, is);
  const auto encoded_row_length = deserialize_scalar<uint32_t>(res, is);

  // Allocate all three arrays before reading any of them.
  auto vq_centers = make_device_matrix<MathT, uint32_t, row_major>(res, vq_n_centers, dim);
  auto pq_centers = make_device_matrix<MathT, uint32_t, row_major>(res, pq_n_centers, pq_len);
  auto codes      = make_device_matrix<uint8_t, IdxT, row_major>(res, n_rows, encoded_row_length);

  deserialize_mdspan(res, is, vq_centers.view());
  deserialize_mdspan(res, is, pq_centers.view());
  deserialize_mdspan(res, is, codes.view());

  out = std::unique_ptr<vpq_dataset<MathT, IdxT>>{new vpq_dataset<MathT, IdxT>{
    std::move(vq_centers), std::move(pq_centers), std::move(codes)}};
}

/**
 * @brief Read a dataset of any supported concrete type from a stream.
 *
 * Reads the instance tag and, for strided and VPQ datasets, the `cudaDataType_t` element-type
 * code; then dispatches to the matching type-specific `deserialize` overload and stores the
 * upcast result in `out`.
 *
 * Fails via `RAFT_FAIL` when the tag or the element type is not one of the supported
 * combinations. `out` is left untouched in that case.
 *
 * @param[in] res raft resources
 * @param[in] is input stream
 * @param[out] out receives the new dataset; whatever it owned before is released
 */
template <typename IdxT>
void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr<dataset<IdxT>>& out)
{
  switch (deserialize_scalar<dataset_instance_tag>(res, is)) {
    case kSerializeEmptyDataset: {
      std::unique_ptr<empty_dataset<IdxT>> p;
      deserialize(res, is, p);
      return upcast_dataset_ptr(std::move(p)).swap(out);
    }
    case kSerializeStridedDataset:
      switch (deserialize_scalar<cudaDataType_t>(res, is)) {
        case CUDA_R_32F: {
          std::unique_ptr<strided_dataset<float, IdxT>> p;
          deserialize(res, is, p);
          return upcast_dataset_ptr(std::move(p)).swap(out);
        }
        case CUDA_R_16F: {
          std::unique_ptr<strided_dataset<half, IdxT>> p;
          deserialize(res, is, p);
          return upcast_dataset_ptr(std::move(p)).swap(out);
        }
        case CUDA_R_8I: {
          std::unique_ptr<strided_dataset<int8_t, IdxT>> p;
          deserialize(res, is, p);
          return upcast_dataset_ptr(std::move(p)).swap(out);
        }
        case CUDA_R_8U: {
          std::unique_ptr<strided_dataset<uint8_t, IdxT>> p;
          deserialize(res, is, p);
          return upcast_dataset_ptr(std::move(p)).swap(out);
        }
        default: break;
      }
      // Unsupported element type. Leave the outer switch here: falling through into the VPQ case
      // would read a second type code from the stream and could decode misaligned bytes as a
      // VPQ dataset.
      break;
    case kSerializeVPQDataset:
      switch (deserialize_scalar<cudaDataType_t>(res, is)) {
        case CUDA_R_32F: {
          std::unique_ptr<vpq_dataset<float, IdxT>> p;
          deserialize(res, is, p);
          return upcast_dataset_ptr(std::move(p)).swap(out);
        }
        case CUDA_R_16F: {
          std::unique_ptr<vpq_dataset<half, IdxT>> p;
          deserialize(res, is, p);
          return upcast_dataset_ptr(std::move(p)).swap(out);
        }
        default: break;
      }
      // Unsupported element type: report the failure below.
      break;
    default: break;
  }
  RAFT_FAIL("Failed to deserialize dataset: unsupported combination of instance tags.");
}

} // namespace raft::neighbors::detail

0 comments on commit 53a5c14

Please sign in to comment.