Skip to content

Commit

Permalink
Merge branch 'branch-23.10' into remove-block_size-from-CAGRA-2
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 authored Sep 20, 2023
2 parents f8253b8 + 6bbcf1f commit a383916
Show file tree
Hide file tree
Showing 107 changed files with 586 additions and 489 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ repos:
additional_dependencies: [toml]
args: ["--config=pyproject.toml"]
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v16.0.1
rev: v16.0.6
hooks:
- id: clang-format
types_or: [c, c++, cuda]
Expand Down
4 changes: 2 additions & 2 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ channels:
dependencies:
- breathe
- c-compiler
- clang-tools=16.0.1
- clang=16.0.1
- clang-tools=16.0.6
- clang=16.0.6
- cmake>=3.26.4
- cuda-profiler-api=11.8.86
- cuda-python>=11.7.1,<12.0a0
Expand Down
4 changes: 2 additions & 2 deletions conda/environments/all_cuda-120_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ channels:
dependencies:
- breathe
- c-compiler
- clang-tools=16.0.1
- clang=16.0.1
- clang-tools=16.0.6
- clang=16.0.6
- cmake>=3.26.4
- cuda-cudart-dev
- cuda-profiler-api
Expand Down
4 changes: 2 additions & 2 deletions conda/environments/bench_ann_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ channels:
dependencies:
- benchmark>=1.8.2
- c-compiler
- clang-tools=16.0.1
- clang=16.0.1
- clang-tools=16.0.6
- clang=16.0.6
- cmake>=3.26.4
- cuda-profiler-api=11.8.86
- cuda-version=11.8
Expand Down
2 changes: 2 additions & 0 deletions conda/recipes/raft-ann-bench/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ requirements:
- h5py {{ h5py_version }}
- benchmark
- matplotlib
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}
- python
- pandas
- pyyaml
Expand Down
5 changes: 3 additions & 2 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,10 @@ void bench_search(::benchmark::State& state,
try {
algo->set_search_dataset(dataset->base_set(algo_property.dataset_memory_type),
dataset->base_set_size());
} catch (const std::exception&) {
} catch (const std::exception& ex) {
state.SkipWithError("The algorithm '" + index.name +
"' requires the base set, but it's not available.");
"' requires the base set, but it's not available. " +
"Exception: " + std::string(ex.what()));
return;
}
}
Expand Down
3 changes: 1 addition & 2 deletions cpp/bench/ann/src/ggnn/ggnn_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ template <typename T>
void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::Ggnn<T>::BuildParam& param)
{
param.dataset_size = conf.at("dataset_size");
param.k = conf.at("k");
param.k = conf.at("k");

if (conf.contains("k_build")) { param.k_build = conf.at("k_build"); }
if (conf.contains("segment_size")) { param.segment_size = conf.at("segment_size"); }
Expand Down
22 changes: 4 additions & 18 deletions cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ class Ggnn : public ANN<T> {
int num_layers{4}; // L
float tau{0.5};
int refine_iterations{2};

size_t dataset_size;
int k; // GGNN requires to know k during building
};

Expand Down Expand Up @@ -182,24 +180,17 @@ GgnnImpl<T, measure, D, KBuild, KQuery, S>::GgnnImpl(Metric metric,
}

if (dim != D) { throw std::runtime_error("mis-matched dim"); }

int device;
RAFT_CUDA_TRY(cudaGetDevice(&device));

ggnn_ = std::make_unique<GGNNGPUInstance>(
device, build_param_.dataset_size, build_param_.num_layers, true, build_param_.tau);
}

template <typename T, DistanceMeasure measure, int D, int KBuild, int KQuery, int S>
void GgnnImpl<T, measure, D, KBuild, KQuery, S>::build(const T* dataset,
size_t nrow,
cudaStream_t stream)
{
if (nrow != build_param_.dataset_size) {
throw std::runtime_error(
"build_param_.dataset_size = " + std::to_string(build_param_.dataset_size) +
" , but nrow = " + std::to_string(nrow));
}
int device;
RAFT_CUDA_TRY(cudaGetDevice(&device));
ggnn_ = std::make_unique<GGNNGPUInstance>(
device, nrow, build_param_.num_layers, true, build_param_.tau);

ggnn_->set_base_data(dataset);
ggnn_->set_stream(stream);
Expand All @@ -212,11 +203,6 @@ void GgnnImpl<T, measure, D, KBuild, KQuery, S>::build(const T* dataset,
template <typename T, DistanceMeasure measure, int D, int KBuild, int KQuery, int S>
void GgnnImpl<T, measure, D, KBuild, KQuery, S>::set_search_dataset(const T* dataset, size_t nrow)
{
if (nrow != build_param_.dataset_size) {
throw std::runtime_error(
"build_param_.dataset_size = " + std::to_string(build_param_.dataset_size) +
" , but nrow = " + std::to_string(nrow));
}
ggnn_->set_base_data(dataset);
}

Expand Down
6 changes: 4 additions & 2 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <utility>
#include <vector>

#include <omp.h>

#include "../common/ann_types.hpp"
#include <hnswlib.h>

Expand Down Expand Up @@ -164,13 +166,13 @@ class HnswLib : public ANN<T> {
struct BuildParam {
int M;
int ef_construction;
int num_threads{1};
int num_threads = omp_get_num_procs();
};

using typename ANN<T>::AnnSearchParam;
struct SearchParam : public AnnSearchParam {
int ef;
int num_threads{1};
int num_threads = omp_get_num_procs();
};

HnswLib(Metric metric, int dim, const BuildParam& param);
Expand Down
96 changes: 79 additions & 17 deletions cpp/bench/prims/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <common/benchmark.hpp>

#include <raft/core/device_resources.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/random/rng.cuh>
#include <raft/sparse/detail/utils.h>
#include <raft/util/cudart_utils.hpp>
Expand All @@ -38,6 +39,19 @@
namespace raft::matrix {
using namespace raft::bench; // NOLINT

template <typename KeyT>
struct replace_with_mask {
KeyT replacement;
int64_t line_length;
int64_t spared_inputs;
constexpr auto inline operator()(int64_t offset, KeyT x, uint8_t mask) -> KeyT
{
auto i = offset % line_length;
// don't replace all the inputs, spare a few elements at the beginning of the input
return (mask && i >= spared_inputs) ? replacement : x;
}
};

template <typename KeyT, typename IdxT, select::Algo Algo>
struct selection : public fixture {
explicit selection(const select::params& p)
Expand Down Expand Up @@ -67,6 +81,21 @@ struct selection : public fixture {
}
}
raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value);
if (p.frac_infinities > 0.0) {
rmm::device_uvector<uint8_t> mask_buf(p.batch_size * p.len, stream);
auto mask = make_device_vector_view<uint8_t, size_t>(mask_buf.data(), mask_buf.size());
raft::random::bernoulli(handle, state, mask, p.frac_infinities);
KeyT bound = p.select_min ? raft::upper_bound<KeyT>() : raft::lower_bound<KeyT>();
auto mask_in =
make_device_vector_view<const uint8_t, size_t>(mask_buf.data(), mask_buf.size());
auto dists_in = make_device_vector_view<const KeyT>(in_dists_.data(), in_dists_.size());
auto dists_out = make_device_vector_view<KeyT>(in_dists_.data(), in_dists_.size());
raft::linalg::map_offset(handle,
dists_out,
replace_with_mask<KeyT>{bound, int64_t(p.len), int64_t(p.k / 2)},
dists_in,
mask_in);
}
}

void run_benchmark(::benchmark::State& state) override // NOLINT
Expand All @@ -75,8 +104,12 @@ struct selection : public fixture {
std::ostringstream label_stream;
label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k;
if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; }
if (params_.frac_infinities > 0) { label_stream << "#infs-" << params_.frac_infinities; }
state.SetLabel(label_stream.str());
loop_on_state(state, [this]() {
common::nvtx::range case_scope("%s - %s", state.name().c_str(), label_stream.str().c_str());
int iter = 0;
loop_on_state(state, [&iter, this]() {
common::nvtx::range lap_scope("lap-", iter++);
select::select_k_impl<KeyT, IdxT>(handle,
Algo,
in_dists_.data(),
Expand Down Expand Up @@ -149,6 +182,35 @@ const std::vector<select::params> kInputs{
{10, 1000000, 64, true, false, true},
{10, 1000000, 128, true, false, true},
{10, 1000000, 256, true, false, true},

{10, 1000000, 1, true, false, false, true, 0.1},
{10, 1000000, 16, true, false, false, true, 0.1},
{10, 1000000, 64, true, false, false, true, 0.1},
{10, 1000000, 128, true, false, false, true, 0.1},
{10, 1000000, 256, true, false, false, true, 0.1},

{10, 1000000, 1, true, false, false, true, 0.9},
{10, 1000000, 16, true, false, false, true, 0.9},
{10, 1000000, 64, true, false, false, true, 0.9},
{10, 1000000, 128, true, false, false, true, 0.9},
{10, 1000000, 256, true, false, false, true, 0.9},
{1000, 10000, 1, true, false, false, true, 0.9},
{1000, 10000, 16, true, false, false, true, 0.9},
{1000, 10000, 64, true, false, false, true, 0.9},
{1000, 10000, 128, true, false, false, true, 0.9},
{1000, 10000, 256, true, false, false, true, 0.9},

{10, 1000000, 1, true, false, false, true, 1.0},
{10, 1000000, 16, true, false, false, true, 1.0},
{10, 1000000, 64, true, false, false, true, 1.0},
{10, 1000000, 128, true, false, false, true, 1.0},
{10, 1000000, 256, true, false, false, true, 1.0},
{1000, 10000, 1, true, false, false, true, 1.0},
{1000, 10000, 16, true, false, false, true, 1.0},
{1000, 10000, 64, true, false, false, true, 1.0},
{1000, 10000, 128, true, false, false, true, 1.0},
{1000, 10000, 256, true, false, false, true, 1.0},
{1000, 10000, 256, true, false, false, true, 0.999},
};

#define SELECTION_REGISTER(KeyT, IdxT, A) \
Expand All @@ -157,28 +219,28 @@ const std::vector<select::params> kInputs{
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
}

SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT
SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT

SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, uint32_t, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT

SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT
SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kRadix11bitsExtraPass); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT

// For learning a heuristic of which selection algorithm to use, we
// have a couple of additional constraints when generating the dataset:
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ __global__ void __launch_bounds__((WarpSize * BlockDimY))
adjust_centers_kernel(MathT* centers, // [n_clusters, dim]
IdxT n_clusters,
IdxT dim,
const T* dataset, // [n_rows, dim]
const T* dataset, // [n_rows, dim]
IdxT n_rows,
const LabelT* labels, // [n_rows]
const CounterT* cluster_sizes, // [n_clusters]
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ namespace numpy_serializer {

#if RAFT_SYSTEM_LITTLE_ENDIAN == 1
#define RAFT_NUMPY_HOST_ENDIAN_CHAR RAFT_NUMPY_LITTLE_ENDIAN_CHAR
#else // RAFT_SYSTEM_LITTLE_ENDIAN == 1
#else // RAFT_SYSTEM_LITTLE_ENDIAN == 1
#define RAFT_NUMPY_HOST_ENDIAN_CHAR RAFT_NUMPY_BIG_ENDIAN_CHAR
#endif // RAFT_SYSTEM_LITTLE_ENDIAN == 1

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/core/detail/nvtx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ inline void pop_range()

} // namespace raft::common::nvtx::detail

#else // NVTX_ENABLED
#else // NVTX_ENABLED

namespace raft::common::nvtx::detail {

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/core/kvp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ struct KeyValuePair {
typedef _Key Key; ///< Key data type
typedef _Value Value; ///< Value data type

Key key; ///< Item key
Value value; ///< Item value
Key key; ///< Item key
Value value; ///< Item value

/// Constructor
RAFT_INLINE_FUNCTION KeyValuePair() {}
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/core/resource/resource_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ enum resource_type {
THRUST_POLICY, // thrust execution policy
WORKSPACE_RESOURCE, // rmm device memory resource

LAST_KEY // reserved for the last key
LAST_KEY // reserved for the last key
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase<Shape_,
TensorTileIterator
tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
MatrixCoord const&
problem_size = ///< Problem size needed to guard against out-of-bounds accesses
problem_size = ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord(Shape::kM, Shape::kN),
MatrixCoord const&
threadblock_offset = ///< Threadblock's initial offset within the problem size space
Expand All @@ -418,7 +418,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase<Shape_,
broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns
ElementVector const* broadcast_ptr, ///< Broadcast vector
MatrixCoord const&
problem_size, ///< Problem size needed to guard against out-of-bounds accesses
problem_size, ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord const&
threadblock_offset ///< Threadblock's initial offset within the problem size space
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ namespace threadblock {
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator
///
template <typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
typename Element_, ///< Element data type
template <typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
typename Element_, ///< Element data type
typename Layout_,
bool ScatterD = false, ///< Scatter D operand or not
bool UseCUDAStore = false>
Expand Down
10 changes: 5 additions & 5 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

#pragma once

#include <cstddef> // size_t
#include <limits> // std::numeric_limits
#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/l2_exp.cuh> // ops::l2_exp_distance_op
#include <cstddef> // size_t
#include <limits> // std::numeric_limits
#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/l2_exp.cuh> // ops::l2_exp_distance_op
#include <raft/distance/detail/fused_distance_nn/cutlass_base.cuh>
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
#include <raft/linalg/contractions.cuh> // Policy
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/masked_distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ struct MaskedDistances : public BaseClass {
} // tile_idx_n
} // idx_g
rowEpilog_op(tile_idx_m);
} // tile_idx_m
} // tile_idx_m
}

private:
Expand Down
Loading

0 comments on commit a383916

Please sign in to comment.