From 2b0202ae2cd75a68471d00787009723683799c5b Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Thu, 31 Oct 2024 15:19:32 -0400 Subject: [PATCH] add in review comments --- .../raft/sparse/neighbors/brute_force.cuh | 4 +- cpp/include/raft/sparse/neighbors/knn.cuh | 6 +- cpp/test/CMakeLists.txt | 12 ++- cpp/test/preprocess_utils.cu | 80 ++++++++++--------- cpp/test/sparse/preprocess_coo.cu | 2 +- cpp/test/sparse/preprocess_csr.cu | 2 +- 6 files changed, 59 insertions(+), 47 deletions(-) diff --git a/cpp/include/raft/sparse/neighbors/brute_force.cuh b/cpp/include/raft/sparse/neighbors/brute_force.cuh index 47e00a012f..8e8f36c2c3 100644 --- a/cpp/include/raft/sparse/neighbors/brute_force.cuh +++ b/cpp/include/raft/sparse/neighbors/brute_force.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ namespace raft::sparse::neighbors::brute_force { /** * Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors * using some distance implementation + * template parameter value_idx is the type of the Indptr and Indices arrays. + * template parameter value_t is the type of the Data array. * @param[in] idxIndptr csr indptr of the index matrix (size n_idx_rows + 1) * @param[in] idxIndices csr column indices array of the index matrix (size n_idx_nnz) * @param[in] idxData csr data array of the index matrix (size idxNNZ) diff --git a/cpp/include/raft/sparse/neighbors/knn.cuh b/cpp/include/raft/sparse/neighbors/knn.cuh index bffbf6c943..7b93ea4d0d 100644 --- a/cpp/include/raft/sparse/neighbors/knn.cuh +++ b/cpp/include/raft/sparse/neighbors/knn.cuh @@ -62,7 +62,7 @@ namespace raft::sparse::neighbors { * @param[in] metric distance metric/measure to use * @param[in] metricArg potential argument for metric (currently unused) */ -template +template void brute_force_knn(const value_idx* idxIndptr, const value_idx* idxIndices, const value_t* idxData, @@ -120,7 +120,7 @@ void brute_force_knn(const value_idx* idxIndptr, * @param[in] metric distance metric/measure to use * @param[in] metricArg potential argument for metric (currently unused) */ -template +template void brute_force_knn(raft::device_csr_matrix +template void brute_force_knn(raft::device_coo_matrix -void preproc_kernel(raft::resources& handle, - raft::host_vector_view h_rows, - raft::host_vector_view h_cols, - raft::host_vector_view h_elems, - raft::device_vector_view results, - int num_rows, - int num_cols, - bool tf_idf) +void preproc_coo(raft::resources& handle, + raft::host_vector_view h_rows, + raft::host_vector_view h_cols, + raft::host_vector_view h_elems, + raft::device_vector_view results, + int num_rows, + int num_cols, + bool tf_idf) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); int rows_size = h_rows.size(); int cols_size = h_cols.size(); int elements_size = h_elems.size(); auto device_matrix = raft::make_device_matrix(handle, num_rows, num_cols); - raft::matrix::fill(handle, device_matrix.view(), 0.0f); + raft::matrix::fill(handle, device_matrix.view(), 0.0f); auto host_matrix = raft::make_host_matrix(handle, num_rows, num_cols); raft::copy(host_matrix.data_handle(), device_matrix.data_handle(), device_matrix.size(), stream); + raft::resource::sync_stream(handle, stream); + for (int i = 0; i < elements_size; i++) { int row = h_rows(i); int col = h_cols(i); @@ -81,7 +83,7 @@ void preproc_kernel(raft::resources& handle, output_cols_lengths.size(), stream); - auto output_cols_length_sum = raft::make_device_scalar(handle, 0); + auto output_cols_length_sum = raft::make_device_scalar(handle, 0); raft::linalg::mapReduce(output_cols_length_sum.data_handle(), num_cols, 0, @@ -89,12 +91,12 @@ void preproc_kernel(raft::resources& handle, raft::add_op(), stream, output_cols_lengths.data_handle()); - auto h_output_cols_length_sum = raft::make_host_scalar(handle, 0); + auto h_output_cols_length_sum = raft::make_host_scalar(handle, 0); raft::copy(h_output_cols_length_sum.data_handle(), output_cols_length_sum.data_handle(), output_cols_length_sum.size(), stream); - float avg_col_length = float(h_output_cols_length_sum(0)) / num_cols; + T2 avg_col_length = T2(h_output_cols_length_sum(0)) / num_cols; auto output_rows_freq = raft::make_device_matrix(handle, 1, num_rows); raft::linalg::reduce(output_rows_freq.data_handle(), @@ -116,13 +118,13 @@ void preproc_kernel(raft::resources& handle, false, stream, false, - check_zeroes()); + check_zeroes()); auto h_output_rows_cnt = raft::make_host_matrix(handle, 1, num_rows); raft::copy( h_output_rows_cnt.data_handle(), output_rows_cnt.data_handle(), output_rows_cnt.size(), stream); auto out_device_matrix = raft::make_device_matrix(handle, num_rows, num_cols); - raft::matrix::fill(handle, out_device_matrix.view(), 0.0f); + raft::matrix::fill(handle, out_device_matrix.view(), 0.0f); auto out_host_matrix = raft::make_host_matrix(handle, num_rows, num_cols); auto out_host_vector = raft::make_host_vector(handle, results.size()); @@ -137,7 +139,7 @@ void preproc_kernel(raft::resources& handle, out_host_matrix(row, col) = 0.0f; } else { float tf = float(val / h_output_cols_lengths(0, col)); - float idf = raft::log(num_cols / h_output_rows_cnt(0, row)); + float idf = raft::log(num_cols / h_output_rows_cnt(0, row)); if (tf_idf) { result = tf * idf; } else { @@ -171,7 +173,7 @@ int get_dupe_mask_count(raft::resources& handle, values.data_handle(), stream); - raft::sparse::op::compute_duplicates_mask( + raft::sparse::op::compute_duplicates_mask( mask.data_handle(), rows.data_handle(), columns.data_handle(), rows.size(), stream); int col_nnz_count = thrust::reduce(raft::resource::get_thrust_policy(handle), @@ -193,15 +195,15 @@ void remove_dupes(raft::resources& handle, { cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto col_counts = raft::make_device_vector(handle, columns.size()); + auto col_counts = raft::make_device_vector(handle, columns.size()); thrust::fill(raft::resource::get_thrust_policy(handle), col_counts.data_handle(), col_counts.data_handle() + col_counts.size(), 1.0f); - auto keys_out = raft::make_device_vector(handle, num_rows); - auto counts_out = raft::make_device_vector(handle, num_rows); + auto keys_out = raft::make_device_vector(handle, num_rows); + auto counts_out = raft::make_device_vector(handle, num_rows); thrust::reduce_by_key(raft::resource::get_thrust_policy(handle), rows.data_handle(), @@ -210,19 +212,19 @@ void remove_dupes(raft::resources& handle, keys_out.data_handle(), counts_out.data_handle()); - auto mask_out = raft::make_device_vector(handle, rows.size()); + auto mask_out = raft::make_device_vector(handle, rows.size()); - raft::linalg::map(handle, mask_out.view(), raft::cast_op{}, raft::make_const_mdspan(mask)); + raft::linalg::map(handle, mask_out.view(), raft::cast_op{}, raft::make_const_mdspan(mask)); - auto values_c = raft::make_device_vector(handle, values.size()); + auto values_c = raft::make_device_vector(handle, values.size()); raft::linalg::map(handle, values_c.view(), raft::mul_op{}, raft::make_const_mdspan(values), raft::make_const_mdspan(mask_out.view())); - auto keys_nnz_out = raft::make_device_vector(handle, num_rows); - auto counts_nnz_out = raft::make_device_vector(handle, num_rows); + auto keys_nnz_out = raft::make_device_vector(handle, num_rows); + auto counts_nnz_out = raft::make_device_vector(handle, num_rows); thrust::reduce_by_key(raft::resource::get_thrust_policy(handle), rows.data_handle(), @@ -231,18 +233,18 @@ void remove_dupes(raft::resources& handle, keys_nnz_out.data_handle(), counts_nnz_out.data_handle()); - raft::sparse::op::coo_remove_scalar(rows.data_handle(), - columns.data_handle(), - values_c.data_handle(), - values_c.size(), - out_rows.data_handle(), - out_cols.data_handle(), - out_vals.data_handle(), - counts_nnz_out.data_handle(), - counts_out.data_handle(), - 0, - num_rows, - stream); + raft::sparse::op::coo_remove_scalar(rows.data_handle(), + columns.data_handle(), + values_c.data_handle(), + values_c.size(), + out_rows.data_handle(), + out_cols.data_handle(), + out_vals.data_handle(), + counts_nnz_out.data_handle(), + counts_out.data_handle(), + 0, + num_rows, + stream); } template @@ -261,7 +263,7 @@ void create_dataset(raft::resources& handle, auto d_out = raft::make_device_vector(handle, rows.size() * 2); int theta_guide = max(num_rows_unique, num_cols_unique); - auto theta = raft::make_device_vector(handle, theta_guide * 4); + auto theta = raft::make_device_vector(handle, theta_guide * 4); raft::random::uniform(handle, rng, theta.view(), 0.0f, 1.0f); @@ -275,9 +277,9 @@ void create_dataset(raft::resources& handle, stream, rng); - auto vals = raft::make_device_vector(handle, rows.size()); + auto vals = raft::make_device_vector(handle, rows.size()); raft::random::uniformInt(handle, rng, vals.view(), 1, max_term_occurence_doc); - raft::linalg::map(handle, values, raft::cast_op{}, raft::make_const_mdspan(vals.view())); + raft::linalg::map(handle, values, raft::cast_op{}, raft::make_const_mdspan(vals.view())); } }; // namespace raft::util \ No newline at end of file diff --git a/cpp/test/sparse/preprocess_coo.cu b/cpp/test/sparse/preprocess_coo.cu index b26e5122d7..44dac88cdb 100644 --- a/cpp/test/sparse/preprocess_coo.cu +++ b/cpp/test/sparse/preprocess_coo.cu @@ -60,7 +60,7 @@ void calc_tfidf_bm25(raft::resources& handle, stream); raft::copy( h_elems.data_handle(), coo_in.get_elements().data(), coo_in.get_elements().size(), stream); - raft::util::preproc_kernel( + raft::util::preproc_coo( handle, h_rows.view(), h_cols.view(), h_elems.view(), results, num_rows, num_cols, tf_idf); } diff --git a/cpp/test/sparse/preprocess_csr.cu b/cpp/test/sparse/preprocess_csr.cu index eab270ce79..e48aabcaa4 100644 --- a/cpp/test/sparse/preprocess_csr.cu +++ b/cpp/test/sparse/preprocess_csr.cu @@ -63,7 +63,7 @@ void calc_tfidf_bm25(raft::resources& handle, raft::copy(h_rows.data_handle(), rows.data_handle(), rows.size(), stream); raft::copy(h_cols.data_handle(), indices.data_handle(), cols_size, stream); raft::copy(h_elems.data_handle(), values.data_handle(), values.size(), stream); - raft::util::preproc_kernel( + raft::util::preproc_coo( handle, h_rows.view(), h_cols.view(), h_elems.view(), results, num_rows, num_cols, tf_idf); }