Do specific cagra graph/dataset memory allocation in the benchmark
tfeher committed Oct 20, 2023
1 parent b33363e commit 32e7b31
Showing 7 changed files with 150 additions and 98 deletions.
17 changes: 17 additions & 0 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
@@ -157,6 +157,21 @@ void parse_build_param(const nlohmann::json& conf,
if (conf.contains("nn_descent_niter")) { param.nn_descent_niter = conf.at("nn_descent_niter"); }
}

AllocatorType parse_allocator(std::string mem_type)
{
if (mem_type == "device") {
return AllocatorType::Device;
} else if (mem_type == "host_pinned") {
return AllocatorType::HostPinned;
} else if (mem_type == "host_huge_page") {
return AllocatorType::HostHugePage;
}
THROW(
"Invalid value for memory type %s, must be one of [\"device\", \"host_pinned\", "
"\"host_huge_page\"",
mem_type.c_str());
}

template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
typename raft::bench::ann::RaftCagra<T, IdxT>::SearchParam& param)
@@ -178,6 +193,8 @@ void parse_search_param(const nlohmann::json& conf,
THROW("Invalid value for algo: %s", tmp.c_str());
}
}
if (conf.contains("graph_mem")) { param.graph_mem = parse_allocator(conf.at("graph_mem")); }
if (conf.contains("dataset_mem")) { param.dataset_mem = parse_allocator(conf.at("dataset_mem")); }
}
#endif
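For reference, these keys are read straight from the benchmark's JSON config. Below is a minimal sketch of a search-param entry and how it round-trips through nlohmann::json (the "itopk" key is an illustrative assumption; only "graph_mem"/"dataset_mem" and their three accepted values come from the code above):

#include <iostream>
#include <nlohmann/json.hpp>

int main()
{
  // Hypothetical search-param entry; "graph_mem"/"dataset_mem" and the values
  // "device", "host_pinned", "host_huge_page" match parse_allocator() above.
  auto conf = nlohmann::json::parse(R"({
    "itopk": 64,
    "graph_mem": "host_huge_page",
    "dataset_mem": "host_pinned"
  })");
  if (conf.contains("graph_mem")) {
    std::cout << conf.at("graph_mem").get<std::string>() << "\n";  // host_huge_page
  }
}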

93 changes: 90 additions & 3 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
@@ -38,15 +38,24 @@
#include "raft_ann_bench_utils.h"
#include <raft/util/cudart_utils.hpp>

#include "../common/cuda_huge_page_resource.hpp"
#include "../common/cuda_pinned_resource.hpp"

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

namespace raft::bench::ann {

enum class AllocatorType { HostPinned, HostHugePage, Device };
template <typename T, typename IdxT>
class RaftCagra : public ANN<T> {
public:
using typename ANN<T>::AnnSearchParam;

struct SearchParam : public AnnSearchParam {
raft::neighbors::experimental::cagra::search_params p;
AllocatorType graph_mem = AllocatorType::Device;
AllocatorType dataset_mem = AllocatorType::Device;
auto needs_dataset() const -> bool override { return true; }
};

@@ -56,7 +65,11 @@ class RaftCagra : public ANN<T> {
: ANN<T>(metric, dim),
index_params_(param),
dimension_(dim),
mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull)
mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull),
need_dataset_update_(true),
dataset_(make_device_matrix<T, int64_t>(handle_, 0, 0)),
graph_(make_device_matrix<IdxT, int64_t>(handle_, 0, 0)),
graph_mem_(AllocatorType::Device)
{
rmm::mr::set_current_device_resource(&mr_);
index_params_.metric = parse_metric_type(metric);
@@ -92,14 +105,29 @@
void load(const std::string&) override;

private:
inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type)
{
switch (mem_type) {
case (AllocatorType::HostPinned): return &mr_pinned_;
case (AllocatorType::HostHugePage): return &mr_huge_page_;
default: return rmm::mr::get_current_device_resource();
}
}
// `mr_` must go first to make sure it dies last
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> mr_;
rmm::mr::cuda_pinned_resource mr_pinned_;
rmm::mr::cuda_huge_page_resource mr_huge_page_;
raft::device_resources handle_;
AllocatorType graph_mem_;
AllocatorType dataset_mem_;
BuildParam index_params_;
bool need_dataset_update_;
raft::neighbors::cagra::search_params search_params_;
std::optional<raft::neighbors::cagra::index<T, IdxT>> index_;
int device_;
int dimension_;
raft::device_matrix<IdxT, int64_t, row_major> graph_;
raft::device_matrix<T, int64_t, row_major> dataset_;
};

template <typename T, typename IdxT>
@@ -118,18 +146,77 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
}
}

inline std::string allocator_to_string(AllocatorType mem_type)
{
if (mem_type == AllocatorType::Device) {
return "device";
} else if (mem_type == AllocatorType::HostPinned) {
return "host_pinned";
} else if (mem_type == AllocatorType::HostHugePage) {
return "host_huge_page";
}
return "<invalid allocator type>";
}

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::set_search_param(const AnnSearchParam& param)
{
auto search_param = dynamic_cast<const SearchParam&>(param);
search_params_ = search_param.p;
if (search_param.graph_mem != graph_mem_) {
// Move graph to correct memory space
graph_mem_ = search_param.graph_mem;
std::cout << "Moving graph to new memory space " << allocator_to_string(graph_mem_)
<< std::endl;
// Create a new graph in the requested memory space and copy the existing graph into it
auto mr = get_mr(graph_mem_);
auto new_graph = make_device_mdarray<IdxT, int64_t>(
handle_, mr, make_extents<int64_t>(index_->graph().extent(0), index_->graph_degree()));

std::cout << "new_grap " << new_graph.extent(0) << "x" << new_graph.extent(1) << std::endl;
std::cout << "graph size " << index_->graph().size() << std::endl;
raft::copy(new_graph.data_handle(),
index_->graph().data_handle(),
index_->graph().size(),
resource::get_cuda_stream(handle_));

index_->update_graph(handle_, make_const_mdspan(new_graph.view()));
// update_graph() only stores a view in the index. We need to keep the graph object alive.
graph_ = std::move(new_graph);
}

if (search_param.dataset_mem != dataset_mem_) {
need_dataset_update_ = true;
dataset_mem_ = search_param.dataset_mem;
}
}
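The comment above is the key invariant: update_graph() stores only a non-owning view, so the wrapper must keep the owning mdarray alive in its graph_ member. Here is a minimal standalone analogue of that owner-plus-view pattern, with std::span standing in for raft's mdspan (all names are illustrative):

#include <cstdio>
#include <span>
#include <utility>
#include <vector>

// Stand-in for the CAGRA index: it keeps a non-owning view of the graph,
// just as index_t::update_graph() stores only an mdspan.
struct IndexLike {
  std::span<const int> graph_view;
};

int main()
{
  IndexLike index;
  std::vector<int> new_graph = {7, 8, 9};  // freshly allocated storage
  index.graph_view = new_graph;            // the index references, it does not own
  // Move construction transfers the heap buffer without reallocating, so the
  // view stays valid; this mirrors `graph_ = std::move(new_graph);` above,
  // which parks the storage in a long-lived member.
  std::vector<int> graph_ = std::move(new_graph);
  std::printf("%d\n", index.graph_view[0]);  // prints 7: graph_ now owns the data
}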

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
{
index_->update_dataset(handle_,
raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_));
// We may be re-using a previous algo object that already has the dataset set.
// Check whether an update is needed.
if (index_->size() != nrow || need_dataset_update_) {
// First free up existing memory
dataset_ = make_device_matrix<T, int64_t>(handle_, 0, 0);
index_->update_dataset(handle_, make_const_mdspan(dataset_.view()));

// Allocate space using the correct memory resource
auto mr = get_mr(dataset_mem_);

std::cout << "Moving dataset to new memory space " << allocator_to_string(dataset_mem_)
<< std::endl;
auto input_dataset_view = make_device_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
raft::neighbors::cagra::detail::copy_with_padding(handle_, dataset_, input_dataset_view, mr);

index_->update_dataset(handle_, make_const_mdspan(dataset_.view()));

// Ideally, instead of dataset_.view(), we should pass a strided matrix view to update.
// auto dataset_view = make_device_strided_matrix_view<const T, int64_t>(
// dataset_.data_handle(), dataset_.extent(0), this->dim_, dataset_.extent(1));
// index_->update_dataset(handle_, dataset_view);
need_dataset_update_ = false;
}
}
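Both paths above funnel their allocations through the rmm::mr::device_memory_resource* returned by get_mr(). As context, a minimal sketch of routing an RMM allocation through an explicitly chosen resource; managed memory stands in here for the benchmark's cuda_pinned_resource/cuda_huge_page_resource:

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>

int main()
{
  // Any device_memory_resource can back the allocation; the wrapper picks the
  // pinned or huge-page resource the same way and hands it to
  // make_device_mdarray()/copy_with_padding().
  rmm::mr::managed_memory_resource mr;
  rmm::device_uvector<float> buf(1024, rmm::cuda_stream_default, &mr);
  return buf.size() == 1024 ? 0 : 1;
}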

template <typename T, typename IdxT>
102 changes: 8 additions & 94 deletions cpp/include/raft/neighbors/cagra_types.hpp
@@ -25,6 +25,7 @@
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/detail/cagra/utils.hpp>
#include <raft/util/integer_utils.hpp>

#include <memory>
@@ -33,12 +34,6 @@
#include <thrust/fill.h>
#include <type_traits>

#include "cuda_huge_page_resource.hpp"
#include "cuda_pinned_resource.hpp"

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

#include <raft/core/logger.hpp>
namespace raft::neighbors::cagra {
/**
@@ -189,13 +184,9 @@ struct index : ann::index {
index(raft::resources const& res,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
: ann::index(),
mr_(new rmm::mr::cuda_pinned_resource()),
mr_huge_(new rmm::mr::cuda_huge_page_resource()),
metric_(raft::distance::DistanceType::L2Expanded),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_pinned_(0, resource::get_cuda_stream(res), mr_.get()),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0)),
graph_pinned_(0, resource::get_cuda_stream(res), mr_.get())
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
}

@@ -257,64 +248,16 @@
index(raft::resources const& res,
raft::distance::DistanceType metric,
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset,
mdspan<const IdxT, matrix_extent<int64_t>, row_major, graph_accessor> knn_graph,
bool graph_pinned = true,
bool data_pinned = true)
mdspan<const IdxT, matrix_extent<int64_t>, row_major, graph_accessor> knn_graph)
: ann::index(),
mr_(new rmm::mr::cuda_pinned_resource()),
mr_huge_(new rmm::mr::cuda_huge_page_resource()),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_pinned_(0, resource::get_cuda_stream(res), mr_huge_.get()),
// dataset_pinned_(0, resource::get_cuda_stream(res), mr_.get()),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0)),
graph_pinned_(0, resource::get_cuda_stream(res), mr_huge_.get())
// graph_pinned_(0, resource::get_cuda_stream(res), mr_.get())
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
if (data_pinned) {
// copy with padding
int64_t aligned_dim = round_up_safe<size_t>(dataset.extent(1) * sizeof(T), 16) / sizeof(T);
dataset_pinned_.resize(dataset.extent(0) * aligned_dim, resource::get_cuda_stream(res));
resource::sync_stream(res);

RAFT_LOG_INFO("Allocated pinned dataset");

memset(dataset_pinned_.data(), 0, dataset_pinned_.size() * sizeof(T));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(dataset_pinned_.data(),
sizeof(T) * aligned_dim,
dataset.data_handle(),
sizeof(T) * dataset.extent(1),
sizeof(T) * dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));

dataset_view_ = make_device_strided_matrix_view<const T, int64_t>(
dataset_pinned_.data(), dataset.extent(0), dataset.extent(1), aligned_dim);
RAFT_LOG_INFO("CAGRA dataset strided matrix view %zux%zu, stride %zu",
static_cast<size_t>(dataset_view_.extent(0)),
static_cast<size_t>(dataset_view_.extent(1)),
static_cast<size_t>(dataset_view_.stride(0)));
} else {
update_dataset(res, dataset);
}
if (graph_pinned) {
graph_pinned_.resize(knn_graph.size(), resource::get_cuda_stream(res));
resource::sync_stream(res);
RAFT_LOG_INFO("Allocated pinned graph");

memset(graph_pinned_.data(), 0, sizeof(IdxT) * graph_pinned_.size());
graph_view_ = make_device_matrix_view<IdxT, int64_t, row_major>(
graph_pinned_.data(), knn_graph.extent(0), knn_graph.extent(1));
raft::copy(graph_pinned_.data(),
knn_graph.data_handle(),
knn_graph.size(),
resource::get_cuda_stream(res));
} else {
update_graph(res, knn_graph);
}
update_dataset(res, dataset);
update_graph(res, knn_graph);
resource::sync_stream(res);
}

@@ -388,32 +331,8 @@ struct index : ann::index {
void copy_padded(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
{
size_t padded_dim = round_up_safe<size_t>(dataset.extent(1) * sizeof(T), 16) / sizeof(T);
detail::copy_with_padding(res, dataset_, dataset);

if ((dataset_.extent(0) != dataset.extent(0)) ||
(static_cast<size_t>(dataset_.extent(1)) != padded_dim)) {
// clear existing memory before allocating to prevent OOM errors on large datasets
if (dataset_.size()) { dataset_ = make_device_matrix<T, int64_t>(res, 0, 0); }
dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);
}
if (dataset_.extent(1) == dataset.extent(1)) {
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
dataset.size(),
resource::get_cuda_stream(res));
} else {
// copy with padding
RAFT_CUDA_TRY(cudaMemsetAsync(
dataset_.data_handle(), 0, dataset_.size() * sizeof(T), resource::get_cuda_stream(res)));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(dataset_.data_handle(),
sizeof(T) * dataset_.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.extent(1),
sizeof(T) * dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
}
dataset_view_ = make_device_strided_matrix_view<const T, int64_t>(
dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1));
RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu",
static_cast<size_t>(dataset_view_.extent(0)),
static_cast<size_t>(dataset_view_.extent(1)),
static_cast<size_t>(dataset_view_.stride(0)));
}

private:
std::unique_ptr<rmm ::mr::cuda_pinned_resource> mr_;
std::unique_ptr<rmm ::mr::cuda_huge_page_resource> mr_huge_;
raft::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, row_major> dataset_;
rmm::device_uvector<T> dataset_pinned_;
raft::device_matrix<IdxT, int64_t, row_major> graph_;
rmm::device_uvector<IdxT> graph_pinned_;
raft::device_matrix_view<const T, int64_t, layout_stride> dataset_view_;
raft::device_matrix_view<const IdxT, int64_t, row_major> graph_view_;
};
@@ -36,7 +36,7 @@ struct check_index_layout {
"paste in the new size and consider updating the serialization logic");
};

constexpr size_t expected_size = 296;
constexpr size_t expected_size = 200;
template struct check_index_layout<sizeof(index<double, std::uint64_t>), expected_size>;

/**
34 changes: 34 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/utils.hpp
@@ -22,6 +22,8 @@
#include <raft/core/detail/macros.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/util/integer_utils.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <type_traits>

namespace raft::neighbors::cagra::detail {
@@ -245,4 +247,36 @@ class host_matrix_view_from_device {
device_matrix_view<T, IdxT> device_view_;
T* host_ptr;
};

// Copy matrix src to dst, padding rows with 0 if necessary to make them 16-byte aligned.
template <typename T, typename data_accessor>
void copy_with_padding(raft::resources const& res,
raft::device_matrix<T, int64_t, row_major>& dst,
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> src,
rmm::mr::device_memory_resource* mr = nullptr)
{
if (!mr) { mr = rmm::mr::get_current_device_resource(); }
size_t padded_dim = round_up_safe<size_t>(src.extent(1) * sizeof(T), 16) / sizeof(T);

if ((dst.extent(0) != src.extent(0)) || (static_cast<size_t>(dst.extent(1)) != padded_dim)) {
// clear existing memory before allocating to prevent OOM errors on large datasets
if (dst.size()) { dst = make_device_matrix<T, int64_t>(res, 0, 0); }
dst = make_device_mdarray<T>(res, mr, make_extents<int64_t>(src.extent(0), padded_dim));
}
if (dst.extent(1) == src.extent(1)) {
raft::copy(dst.data_handle(), src.data_handle(), src.size(), resource::get_cuda_stream(res));
} else {
// copy with padding
RAFT_CUDA_TRY(cudaMemsetAsync(
dst.data_handle(), 0, dst.size() * sizeof(T), resource::get_cuda_stream(res)));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(),
sizeof(T) * dst.extent(1),
src.data_handle(),
sizeof(T) * src.extent(1),
sizeof(T) * src.extent(1),
src.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
}
}
} // namespace raft::neighbors::cagra::detail
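The padding rule used above is: round each row up to a multiple of 16 bytes, then convert back to an element count. A self-contained sketch of that arithmetic (it mirrors the round_up_safe expression; the helper name is illustrative):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Mirrors the padded_dim computation in copy_with_padding().
template <typename T>
constexpr std::size_t padded_dim(std::size_t dim)
{
  return ((dim * sizeof(T) + 15) / 16) * 16 / sizeof(T);
}

int main()
{
  // float rows of 50 elements (200 bytes) pad to 208 bytes = 52 elements;
  // uint32_t rows of 64 neighbors (256 bytes) are already 16-byte aligned.
  std::printf("%zu %zu\n", padded_dim<float>(50), padded_dim<std::uint32_t>(64));  // 52 64
}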
