Skip to content

Commit

Permalink
Merge branch 'branch-23.12' into CAGRA-remove-max_dim-template-param
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 authored Nov 19, 2023
2 parents 2ff87ef + 3a7f33f commit 4ea6a39
Show file tree
Hide file tree
Showing 20 changed files with 1,365 additions and 146 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ on:
default: nightly

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }}
cancel-in-progress: true

jobs:
Expand Down
2 changes: 1 addition & 1 deletion conda/recipes/pylibraft/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ requirements:
- cython >=3.0.0
- libraft {{ version }}
- libraft-headers {{ version }}
- numpy >=1.21
- python x.x
- rmm ={{ minor_version }}
- scikit-build >=0.13.1
Expand All @@ -60,6 +59,7 @@ requirements:
{% endif %}
- libraft {{ version }}
- libraft-headers {{ version }}
- numpy >=1.21
- python x.x
- rmm ={{ minor_version }}

Expand Down
5 changes: 3 additions & 2 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
)
set(RAFT_ANN_BENCH_USE_RAFT ON)
endif()
Expand Down Expand Up @@ -263,7 +263,8 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib
LINKS
raft::compiled
CXXFLAGS "${HNSW_CXX_FLAGS}"
CXXFLAGS
"${HNSW_CXX_FLAGS}"
)
endif()

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct l2_exp_cutlass_op {

__device__ l2_exp_cutlass_op() noexcept : sqrt(false) {}
__device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {}
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept
{
AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/brute_force-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances) RAFT_EXPLICIT;
raft::device_matrix_view<T, int64_t, row_major> distances) RAFT_EXPLICIT;

template <typename idx_t,
typename value_t,
Expand Down
31 changes: 2 additions & 29 deletions cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -346,36 +346,9 @@ void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
raft::device_matrix_view<T, int64_t, row_major> distances)
{
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs");
RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1),
"Number of columns in queries must match brute force index");

auto k = neighbors.extent(1);
auto d = idx.dataset().extent(1);

std::vector<T*> dataset = {const_cast<T*>(idx.dataset().data_handle())};
std::vector<int64_t> sizes = {idx.dataset().extent(0)};
std::vector<T*> norms;
if (idx.has_norms()) { norms.push_back(const_cast<T*>(idx.norms().data_handle())); }

detail::brute_force_knn_impl<int64_t, IdxT, T>(res,
dataset,
sizes,
d,
const_cast<T*>(queries.data_handle()),
queries.extent(0),
neighbors.data_handle(),
distances.data_handle(),
k,
true,
true,
nullptr,
idx.metric(),
idx.metric_arg(),
raft::identity_op(),
norms.size() ? &norms : nullptr);
raft::neighbors::detail::brute_force_search<T, IdxT>(res, idx, queries, neighbors, distances);
}
/** @} */ // end group brute_force_knn
} // namespace raft::neighbors::brute_force
68 changes: 68 additions & 0 deletions cpp/include/raft/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#pragma once
#include <memory>

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "brute_force-inl.cuh"
Expand All @@ -22,3 +23,70 @@
#ifdef RAFT_COMPILED
#include "brute_force-ext.cuh"
#endif

#include <raft/neighbors/detail/knn_brute_force_batch_k_query.cuh>

namespace raft::neighbors::brute_force {
/**
* @brief Make a brute force query over batches of k
*
* This lets you query for batches of k. For example, you can get
* the first 100 neighbors, then the next 100 neighbors etc.
*
* Example usage:
* @code{.cpp}
* #include <raft/neighbors/brute_force.cuh>
* #include <raft/core/device_mdarray.hpp>
* #include <raft/random/make_blobs.cuh>
* // create a random dataset
* int n_rows = 10000;
* int n_cols = 10000;
* raft::device_resources res;
* auto dataset = raft::make_device_matrix<float, int>(res, n_rows, n_cols);
* auto labels = raft::make_device_vector<float, int>(res, n_rows);
* raft::make_blobs(res, dataset.view(), labels.view());
*
* // create a brute_force knn index from the dataset
* auto index = raft::neighbors::brute_force::build(res,
* raft::make_const_mdspan(dataset.view()));
*
* // search the index in batches of 128 nearest neighbors
* auto search = raft::make_const_mdspan(dataset.view());
* auto query = make_batch_k_query<float, int>(res, index, search, 128);
* for (auto & batch: *query) {
* // batch.indices() and batch.distances() contain the information on the current batch
* }
*
* // we can also support variable sized batches - loaded up a different number
* // of neighbors at each iteration through the ::advance method
* int64_t batch_size = 128;
* query = make_batch_k_query<float, int>(res, index, search, batch_size);
* for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) {
* // batch.indices() and batch.distances() contain the information on the current batch
*
* batch_size += 16; // load up an extra 16 items in the next batch
* }
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
* @param[in] res
* @param[in] index The index to query
* @param[in] query A device matrix view to query for [n_queries, index->dim()]
* @param[in] batch_size The size of each batch
*/

template <typename T, typename IdxT>
std::shared_ptr<batch_k_query<T, IdxT>> make_batch_k_query(
const raft::resources& res,
const raft::neighbors::brute_force::index<T>& index,
raft::device_matrix_view<const T, int64_t, row_major> query,
int64_t batch_size)
{
return std::shared_ptr<batch_k_query<T, IdxT>>(
new detail::gpu_batch_k_query<T, IdxT>(res, index, query, batch_size));
}
} // namespace raft::neighbors::brute_force
119 changes: 118 additions & 1 deletion cpp/include/raft/neighbors/brute_force_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/neighbors_types.hpp>

#include <raft/core/logger.hpp>

Expand Down Expand Up @@ -69,7 +70,7 @@ struct index : ann::index {
return norms_view_.value();
}

/** Whether ot not this index has dataset norms */
/** Whether or not this index has dataset norms */
[[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); }

[[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; }
Expand Down Expand Up @@ -160,6 +161,122 @@ struct index : ann::index {
T metric_arg_;
};

/**
* @brief Interface for performing queries over values of k
*
* This interface lets you iterate over batches of k from a brute_force::index.
* This lets you do things like retrieve the first 100 neighbors for a query,
* apply post processing to remove any unwanted items and then if needed get the
* next 100 closest neighbors for the query.
*
* This query interface exposes C++ iterators through the ::begin and ::end, and
* is compatible with range based for loops.
*
* Note that this class is an abstract class without any cuda dependencies, meaning
* that it doesn't require a cuda compiler to use - but also means it can't be directly
* instantiated. See the raft::neighbors::brute_force::make_batch_k_query
* function for usage examples.
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*/
template <typename T, typename IdxT = int64_t>
class batch_k_query {
public:
batch_k_query(const raft::resources& res,
int64_t index_size,
int64_t query_size,
int64_t batch_size)
: res(res), index_size(index_size), query_size(query_size), batch_size(batch_size)
{
}
virtual ~batch_k_query() {}

using value_type = raft::neighbors::batch<T, IdxT>;

class iterator {
public:
using value_type = raft::neighbors::batch<T, IdxT>;
using reference = const value_type&;
using pointer = const value_type*;

iterator(const batch_k_query<T, IdxT>* query, int64_t offset = 0)
: current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset)
{
query->load_batch(offset, query->batch_size, &batches);
query->slice_batch(batches, offset, query->batch_size, &current);
}

reference operator*() const { return current; }

pointer operator->() const { return &current; }

iterator& operator++()
{
advance(query->batch_size);
return *this;
}

iterator operator++(int)
{
iterator previous(*this);
operator++();
return previous;
}

/**
* @brief Advance the iterator, using a custom size for the next batch
*
* Using operator++ means that we will load up the same batch_size for each
* batch. This method allows us to get around this restriction, and load up
* arbitrary batch sizes on each iteration.
* See raft::neighbors::brute_force::make_batch_k_query for a usage example.
*
* @param[in] next_batch_size: size of the next batch to load up
*/
void advance(int64_t next_batch_size)
{
offset = std::min(offset + current.batch_size(), query->index_size);
if (offset + next_batch_size > batches.batch_size()) {
query->load_batch(offset, next_batch_size, &batches);
}
query->slice_batch(batches, offset, next_batch_size, &current);
}

friend bool operator==(const iterator& lhs, const iterator& rhs)
{
return (lhs.query == rhs.query) && (lhs.offset == rhs.offset);
};
friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); };

protected:
// the current batch of data
value_type current;

// the currently loaded group of data (containing multiple batches of data that we can iterate
// through)
value_type batches;

const batch_k_query<T, IdxT>* query;
int64_t offset, current_batch_size;
};

iterator begin() const { return iterator(this); }
iterator end() const { return iterator(this, index_size); }

protected:
// these two methods need cuda code, and are implemented in the subclass
virtual void load_batch(int64_t offset,
int64_t next_batch_size,
batch<T, IdxT>* output) const = 0;
virtual void slice_batch(const value_type& input,
int64_t offset,
int64_t batch_size,
value_type* output) const = 0;

const raft::resources& res;
int64_t index_size, query_size, batch_size;
};
/** @} */

} // namespace raft::neighbors::brute_force
66 changes: 66 additions & 0 deletions cpp/include/raft/neighbors/detail/div_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef _RAFT_HAS_CUDA
#include <raft/util/pow2_utils.cuh>
#else
#include <raft/util/integer_utils.hpp>
#endif

/**
* @brief A simple wrapper for raft::Pow2 which uses Pow2 utils only when available and regular
* integer division otherwise. This is done to allow a common interface for division arithmetic for
* non CUDA headers.
*
* @tparam Value_ a compile-time value representable as a power-of-two.
*/
namespace raft::neighbors::detail {
template <auto Value_>
struct div_utils {
typedef decltype(Value_) Type;
static constexpr Type Value = Value_;

template <typename T>
static constexpr _RAFT_HOST_DEVICE inline auto roundDown(T x)
{
#if defined(_RAFT_HAS_CUDA)
return Pow2<Value_>::roundDown(x);
#else
return raft::round_down_safe(x, Value_);
#endif
}

template <typename T>
static constexpr _RAFT_HOST_DEVICE inline auto mod(T x)
{
#if defined(_RAFT_HAS_CUDA)
return Pow2<Value_>::mod(x);
#else
return x % Value_;
#endif
}

template <typename T>
static constexpr _RAFT_HOST_DEVICE inline auto div(T x)
{
#if defined(_RAFT_HAS_CUDA)
return Pow2<Value_>::div(x);
#else
return x / Value_;
#endif
}
};
} // namespace raft::neighbors::detail
Loading

0 comments on commit 4ea6a39

Please sign in to comment.