From 430424bc4813737d62438128b6045cf9daefabdf Mon Sep 17 00:00:00 2001 From: Artem Chirkin <9253178+achirkin@users.noreply.github.com> Date: Wed, 11 Dec 2024 04:49:29 -0800 Subject: [PATCH] Keep the host dataset alive in case if CAGRA index is not-owning --- cpp/test/neighbors/ann_cagra.cuh | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 8d5701439..c1cd3ca09 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -389,12 +389,13 @@ class AnnCagraTest : public ::testing::TestWithParam { (const DataT*)database.data(), ps.n_rows, ps.dim); { + std::optional> database_host{std::nullopt}; cagra::index index(handle_, index_params.metric); if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host->data_handle(), database.data(), database.size(), stream_); auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + (const DataT*)database_host->data_handle(), ps.n_rows, ps.dim); index = cagra::build(handle_, index_params, database_host_view); } else { @@ -567,13 +568,16 @@ class AnnCagraAddNodesTest : public ::testing::TestWithParam { auto initial_database_view = raft::make_device_matrix_view( (const DataT*)database.data(), initial_database_size, ps.dim); + std::optional> database_host{std::nullopt}; cagra::index index(handle_); if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + database_host = raft::make_host_matrix(ps.n_rows, ps.dim); raft::copy( - database_host.data_handle(), database.data(), initial_database_view.size(), stream_); + database_host->data_handle(), database.data(), initial_database_view.size(), stream_); auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), initial_database_size, ps.dim); + (const DataT*)database_host->data_handle(), initial_database_size, ps.dim); + // NB: database_host must live no less than the index, because the index _may_be_ + // non-onwning index = cagra::build(handle_, index_params, database_host_view); } else { index = cagra::build(handle_, index_params, initial_database_view); @@ -763,12 +767,13 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); + std::optional> database_host{std::nullopt}; cagra::index index(handle_); if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host->data_handle(), database.data(), database.size(), stream_); auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + (const DataT*)database_host->data_handle(), ps.n_rows, ps.dim); index = cagra::build(handle_, index_params, database_host_view); } else { index = cagra::build(handle_, index_params, database_view);