Skip to content

Commit

Permalink
add test for sample rows
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Mar 13, 2024
1 parent eb73ef5 commit cc2cf24
Show file tree
Hide file tree
Showing 3 changed files with 446 additions and 448 deletions.
82 changes: 0 additions & 82 deletions cpp/include/raft/spatial/knn/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -577,86 +577,4 @@ struct batch_load_iterator {
size_type cur_pos_;
};

template <typename IdxT>
auto get_subsample_indices(raft::resources const& res, IdxT n_samples, IdxT n_subsamples, int seed)
-> raft::device_vector<IdxT, IdxT>
{
RAFT_EXPECTS(n_subsamples <= n_samples, "Cannot have more training samples than dataset vectors");
// size_t free, total;
// float GiB = 1073741824.0f;
// cudaMemGetInfo(&free, &total);
// RAFT_LOG_INFO(
// "get_subsample_indices::data free mem %6.1f, used mem %6.1f", free / GiB, (total - free) /
// GiB);

auto data_indices = raft::make_device_vector<IdxT, IdxT>(res, n_samples);
// cudaMemGetInfo(&free, &total);
// RAFT_LOG_INFO("get_subsample_indices::train free mem %6.1f, used mem %6.1f",
// free / GiB,
// (total - free) / GiB);

auto train_indices = raft::make_device_vector<IdxT, IdxT>(res, n_subsamples);
raft::linalg::map_offset(res, data_indices.view(), identity_op());
raft::random::RngState rng(seed);
raft::random::sample_without_replacement(res,
rng,
raft::make_const_mdspan(data_indices.view()),
std::nullopt,
train_indices.view(),
std::nullopt);
return train_indices;
}

/** Subsample the dataset to create a training set*/
template <typename T, typename IdxT = int64_t>
void subsample(raft::resources const& res,
const T* input,
IdxT n_samples,
raft::device_matrix_view<T, IdxT> output,
int seed)
{
IdxT n_dim = output.extent(1);
IdxT n_train = output.extent(0);

raft::device_vector<IdxT, IdxT> train_indices =
get_subsample_indices<IdxT>(res, n_samples, n_train, seed);

cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input));
T* ptr = reinterpret_cast<T*>(attr.devicePointer);
if (ptr != nullptr) {
raft::matrix::gather(res,
raft::make_device_matrix_view<const T, IdxT>(ptr, n_samples, n_dim),
raft::make_const_mdspan(train_indices.view()),
output);
} else {
auto dataset = raft::make_host_matrix_view<const T, IdxT>(input, n_samples, n_dim);
raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output);
}
}

/** Subsample the dataset to create a training set*/
template <typename T, typename IdxT = int64_t>
raft::device_matrix<T, IdxT> subsample(
raft::resources const& res, const T* input, IdxT n_samples, IdxT n_train, IdxT n_dim, int seed)
{
raft::device_vector<IdxT, IdxT> train_indices =
get_subsample_indices<IdxT>(res, n_samples, n_train, seed);

auto output = raft::make_device_matrix<T, IdxT>(res, n_train, n_dim);
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input));
T* ptr = reinterpret_cast<T*>(attr.devicePointer);
if (ptr != nullptr) {
raft::matrix::gather(res,
raft::make_device_matrix_view<const T, IdxT>(ptr, n_samples, n_dim),
raft::make_const_mdspan(train_indices.view()),
output.view());
} else {
auto dataset = raft::make_host_matrix_view<const T, IdxT>(input, n_samples, n_dim);
raft::matrix::detail::gather(
res, dataset, make_const_mdspan(train_indices.view()), output.view());
}
return output;
}
} // namespace raft::spatial::knn::detail::utils
Loading

0 comments on commit cc2cf24

Please sign in to comment.