Skip to content

Commit

Permalink
add c++ test and fix dimension check
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Nov 29, 2023
1 parent f08da39 commit 89d1f10
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 3 deletions.
1 change: 0 additions & 1 deletion cpp/include/raft/neighbors/ball_cover-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_
void build_index(raft::resources const& handle,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index)
{
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
if (index.metric == raft::distance::DistanceType::Haversine) {
raft::spatial::knn::detail::rbc_build_index(
handle, index, spatial::knn::detail::HaversineFunc<value_t, int_t>());
Expand Down
1 change: 0 additions & 1 deletion cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ void rbc_build_index(raft::resources const& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
distance_func dfunc)
{
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
ASSERT(!index.is_index_trained(), "index cannot be previously trained");

rmm::device_uvector<value_idx> R_knn_inds(index.m, resource::get_cuda_stream(handle));
Expand Down
300 changes: 299 additions & 1 deletion cpp/test/neighbors/epsilon_neighborhood.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
#include <memory>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/neighbors/ball_cover.cuh>
#include <raft/random/make_blobs.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/spatial/knn/epsilon_neighborhood.cuh>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>
Expand Down Expand Up @@ -83,19 +85,25 @@ class EpsNeighTest : public ::testing::TestWithParam<EpsInputs<T, IdxT>> {
}; // class EpsNeighTest

const std::vector<EpsInputs<float, int>> inputsfi = {
{100, 16, 5, 2, 2.f},
{1500, 16, 5, 3, 2.f},
{15000, 16, 5, 1, 2.f},
{15000, 3, 5, 1, 2.f},
{14000, 16, 5, 1, 2.f},
{15000, 17, 5, 1, 2.f},
{14000, 17, 5, 1, 2.f},
{15000, 18, 5, 1, 2.f},
{14000, 18, 5, 1, 2.f},
{15000, 32, 5, 1, 2.f},
{14000, 32, 5, 1, 2.f},
{14000, 32, 5, 10, 2.f},
{20000, 10000, 10, 1, 2.f},
{20000, 10000, 10, 2, 2.f},
};

typedef EpsNeighTest<float, int> EpsNeighTestFI;
TEST_P(EpsNeighTestFI, Result)

TEST_P(EpsNeighTestFI, ResultBruteForce)
{
for (int i = 0; i < param.n_batches; ++i) {
RAFT_CUDA_TRY(cudaMemsetAsync(adj.data(), 0, sizeof(bool) * param.n_row * batchSize, stream));
Expand All @@ -114,8 +122,298 @@ TEST_P(EpsNeighTestFI, Result)
param.n_row / param.n_centers, vd.data(), batchSize, raft::Compare<int>(), stream));
}
}

INSTANTIATE_TEST_CASE_P(EpsNeighTests, EpsNeighTestFI, ::testing::ValuesIn(inputsfi));

// rbc examples take fewer points as correctness checks are very costly
const std::vector<EpsInputs<float, int>> inputsfi_rbc = {
{100, 16, 5, 2, 2.f},
{1500, 16, 5, 3, 2.f},
{1500, 16, 5, 1, 2.f},
{1500, 3, 5, 1, 2.f},
{1400, 16, 5, 1, 2.f},
{1500, 17, 5, 1, 2.f},
{1400, 17, 5, 1, 2.f},
{1500, 18, 5, 1, 2.f},
{1400, 18, 5, 1, 2.f},
{1500, 32, 5, 1, 2.f},
{1400, 32, 5, 1, 2.f},
{1400, 32, 5, 10, 2.f},
{2000, 1000, 10, 1, 2.f},
{2000, 1000, 10, 2, 2.f},
};

typedef EpsNeighTest<float, int> EpsNeighRbcTestFI;

TEST_P(EpsNeighRbcTestFI, DenseRbc)
{
rmm::device_uvector<bool> adj_baseline(param.n_row * batchSize,
resource::get_cuda_stream(handle));

raft::neighbors::ball_cover::BallCoverIndex<int, float, int> rbc_index(
handle, data.data(), param.n_row, param.n_col, raft::distance::DistanceType::L2SqrtUnexpanded);
raft::neighbors::ball_cover::build_index(handle, rbc_index);

for (int i = 0; i < param.n_batches; ++i) {
// invalidate
RAFT_CUDA_TRY(cudaMemsetAsync(adj.data(), 1, sizeof(bool) * param.n_row * batchSize, stream));
RAFT_CUDA_TRY(cudaMemsetAsync(vd.data(), 1, sizeof(int) * (batchSize + 1), stream));
RAFT_CUDA_TRY(
cudaMemsetAsync(adj_baseline.data(), 1, sizeof(bool) * param.n_row * batchSize, stream));

float* query = data.data() + (i * batchSize * param.n_col);

raft::neighbors::ball_cover::epsUnexpL2NeighborhoodRbc<int, float, int>(handle,
rbc_index,
adj.data(),
vd.data(),
query,
batchSize,
param.n_col,
param.eps * param.eps);

ASSERT_TRUE(raft::devArrMatch(
param.n_row / param.n_centers, vd.data(), batchSize, raft::Compare<int>(), stream));

// compute baseline via brute force + compare
epsUnexpL2SqNeighborhood<float, int>(adj_baseline.data(),
nullptr,
query,
data.data(),
batchSize,
param.n_row,
param.n_col,
param.eps * param.eps,
stream);

ASSERT_TRUE(raft::devArrMatch(
adj_baseline.data(), adj.data(), batchSize, param.n_row, raft::Compare<bool>(), stream));

// re-compute without vd
raft::neighbors::ball_cover::epsUnexpL2NeighborhoodRbc<int, float, int>(
handle, rbc_index, adj.data(), nullptr, query, batchSize, param.n_col, param.eps * param.eps);
ASSERT_TRUE(raft::devArrMatch(
adj_baseline.data(), adj.data(), batchSize, param.n_row, raft::Compare<bool>(), stream));
}
}

template <typename T>
testing::AssertionResult assertCsrEqualUnordered(
T* ia_exp, T* ja_exp, T* ia_act, T* ja_act, size_t rows, size_t cols, cudaStream_t stream)
{
std::unique_ptr<T[]> ia_exp_h(new T[rows + 1]);
std::unique_ptr<T[]> ia_act_h(new T[rows + 1]);
raft::update_host<T>(ia_exp_h.get(), ia_exp, rows + 1, stream);
raft::update_host<T>(ia_act_h.get(), ia_act, rows + 1, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));

size_t nnz = ia_exp_h.get()[rows];
std::unique_ptr<T[]> ja_exp_h(new T[nnz]);
std::unique_ptr<T[]> ja_act_h(new T[nnz]);
raft::update_host<T>(ja_exp_h.get(), ja_exp, nnz, stream);
raft::update_host<T>(ja_act_h.get(), ja_act, nnz, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));

for (size_t i(0); i < rows; ++i) {
auto row_start = ia_exp_h.get()[i];
auto row_end = ia_exp_h.get()[i + 1];

// sort ja's
std::sort(ja_exp_h.get() + row_start, ja_exp_h.get() + row_end);
std::sort(ja_act_h.get() + row_start, ja_act_h.get() + row_end);

for (size_t idx(row_start); idx < (size_t)row_end; ++idx) {
auto exp = ja_exp_h.get()[idx];
auto act = ja_act_h.get()[idx];
if (exp != act) {
return testing::AssertionFailure()
<< "actual=" << act << " != expected=" << exp << " @" << i << "," << idx;
}
}
}
return testing::AssertionSuccess();
}

TEST_P(EpsNeighRbcTestFI, SparseRbc)
{
rmm::device_uvector<int> adj_ia(batchSize + 1, resource::get_cuda_stream(handle));
rmm::device_uvector<int> adj_ja(param.n_row * batchSize, resource::get_cuda_stream(handle));

rmm::device_uvector<int> vd_baseline(batchSize + 1, resource::get_cuda_stream(handle));
rmm::device_uvector<int> adj_ia_baseline(batchSize + 1, resource::get_cuda_stream(handle));
rmm::device_uvector<int> adj_ja_baseline(param.n_row * batchSize,
resource::get_cuda_stream(handle));

raft::neighbors::ball_cover::BallCoverIndex<int, float, int> rbc_index(
handle, data.data(), param.n_row, param.n_col, raft::distance::DistanceType::L2SqrtUnexpanded);
raft::neighbors::ball_cover::build_index(handle, rbc_index);

for (int i = 0; i < param.n_batches; ++i) {
// reset full array -- that way we can compare the full size
RAFT_CUDA_TRY(cudaMemsetAsync(adj_ja.data(), 0, sizeof(int) * param.n_row * batchSize, stream));
RAFT_CUDA_TRY(
cudaMemsetAsync(adj_ja_baseline.data(), 0, sizeof(int) * param.n_row * batchSize, stream));

float* query = data.data() + (i * batchSize * param.n_col);

// compute dense baseline and convert adj to csr
{
raft::neighbors::ball_cover::epsUnexpL2NeighborhoodRbc<int, float, int>(
handle,
rbc_index,
adj.data(),
vd_baseline.data(),
query,
batchSize,
param.n_col,
param.eps * param.eps);
thrust::exclusive_scan(resource::get_thrust_policy(handle),
vd_baseline.data(),
vd_baseline.data() + batchSize + 1,
adj_ia_baseline.data());
raft::sparse::convert::adj_to_csr(handle,
adj.data(),
adj_ia_baseline.data(),
batchSize,
param.n_row,
labels.data(),
adj_ja_baseline.data());
}

// exact computation with 2 passes
{
raft::neighbors::ball_cover::epsUnexpL2NeighborhoodRbc<int, float, int>(
handle,
rbc_index,
adj_ia.data(),
nullptr,
vd.data(),
query,
batchSize,
param.n_col,
param.eps * param.eps);
raft::neighbors::ball_cover::epsUnexpL2NeighborhoodRbc<int, float, int>(
handle,
rbc_index,
adj_ia.data(),
adj_ja.data(),
nullptr,
query,
batchSize,
param.n_col,
param.eps * param.eps);
ASSERT_TRUE(raft::devArrMatch(
adj_ia_baseline.data(), adj_ia.data(), batchSize + 1, raft::Compare<int>(), stream));
ASSERT_TRUE(assertCsrEqualUnordered(adj_ia_baseline.data(),
adj_ja_baseline.data(),
adj_ia.data(),
adj_ja.data(),
batchSize,
param.n_row,
stream));
}
}
}

TEST_P(EpsNeighRbcTestFI, SparseRbcMaxK)
{
rmm::device_uvector<int> adj_ia(batchSize + 1, resource::get_cuda_stream(handle));
rmm::device_uvector<int> adj_ja(param.n_row * batchSize, resource::get_cuda_stream(handle));

rmm::device_uvector<int> vd_baseline(batchSize + 1, resource::get_cuda_stream(handle));
rmm::device_uvector<int> adj_ia_baseline(batchSize + 1, resource::get_cuda_stream(handle));
rmm::device_uvector<int> adj_ja_baseline(param.n_row * batchSize,
resource::get_cuda_stream(handle));

raft::neighbors::ball_cover::BallCoverIndex<int, float, int> rbc_index(
handle, data.data(), param.n_row, param.n_col, raft::distance::DistanceType::L2SqrtUnexpanded);
raft::neighbors::ball_cover::build_index(handle, rbc_index);

int expected_max_k = param.n_row / param.n_centers;

for (int i = 0; i < param.n_batches; ++i) {
// reset full array -- that way we can compare the full size
RAFT_CUDA_TRY(cudaMemsetAsync(adj_ja.data(), 0, sizeof(int) * param.n_row * batchSize, stream));
RAFT_CUDA_TRY(
cudaMemsetAsync(adj_ja_baseline.data(), 0, sizeof(int) * param.n_row * batchSize, stream));

float* query = data.data() + (i * batchSize * param.n_col);

// compute dense baseline and convert adj to csr
{
raft::neighbors::ball_cover::epsUnexpL2NeighborhoodRbc<int, float, int>(
handle,
rbc_index,
adj.data(),
vd_baseline.data(),
query,
batchSize,
param.n_col,
param.eps * param.eps);
thrust::exclusive_scan(resource::get_thrust_policy(handle),
vd_baseline.data(),
vd_baseline.data() + batchSize + 1,
adj_ia_baseline.data());
raft::sparse::convert::adj_to_csr(handle,
adj.data(),
adj_ia_baseline.data(),
batchSize,
param.n_row,
labels.data(),
adj_ja_baseline.data());
}

// exact computation with 1 pass
{
int max_k = expected_max_k;
raft::neighbors::ball_cover::epsUnexpL2NeighborhoodRbc<int, float, int>(handle,
rbc_index,
adj_ia.data(),
adj_ja.data(),
vd.data(),
query,
batchSize,
param.n_col,
param.eps * param.eps,
&max_k);
ASSERT_TRUE(raft::devArrMatch(
adj_ia_baseline.data(), adj_ia.data(), batchSize + 1, raft::Compare<int>(), stream));
ASSERT_TRUE(assertCsrEqualUnordered(adj_ia_baseline.data(),
adj_ja_baseline.data(),
adj_ia.data(),
adj_ja.data(),
batchSize,
param.n_row,
stream));
ASSERT_TRUE(raft::devArrMatch(
vd_baseline.data(), vd.data(), batchSize + 1, raft::Compare<int>(), stream));
ASSERT_TRUE(max_k == expected_max_k);
}

// k-limited computation with 1 pass
{
int max_k = expected_max_k / 2;
raft::neighbors::ball_cover::epsUnexpL2NeighborhoodRbc<int, float, int>(handle,
rbc_index,
adj_ia.data(),
adj_ja.data(),
vd.data(),
query,
batchSize,
param.n_col,
param.eps * param.eps,
&max_k);
ASSERT_TRUE(max_k == expected_max_k);
ASSERT_TRUE(
raft::devArrMatch(expected_max_k / 2, vd.data(), batchSize, raft::Compare<int>(), stream));
ASSERT_TRUE(raft::devArrMatch(
expected_max_k / 2 * batchSize, vd.data() + batchSize, 1, raft::Compare<int>(), stream));
}
}
}

INSTANTIATE_TEST_CASE_P(EpsNeighTests, EpsNeighRbcTestFI, ::testing::ValuesIn(inputsfi_rbc));

}; // namespace knn
}; // namespace spatial
}; // namespace raft

0 comments on commit 89d1f10

Please sign in to comment.