Skip to content

Commit

Permalink
add in review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 committed Oct 31, 2024
1 parent a70619e commit 2b0202a
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 47 deletions.
4 changes: 3 additions & 1 deletion cpp/include/raft/sparse/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,8 @@ namespace raft::sparse::neighbors::brute_force {
/**
* Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors
* using some distance implementation
* @tparam value_idx type of the indptr and indices arrays
* @tparam value_t type of the data array
* @param[in] idxIndptr csr indptr of the index matrix (size n_idx_rows + 1)
* @param[in] idxIndices csr column indices array of the index matrix (size n_idx_nnz)
* @param[in] idxData csr data array of the index matrix (size idxNNZ)
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/sparse/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ namespace raft::sparse::neighbors {
* @param[in] metric distance metric/measure to use
* @param[in] metricArg potential argument for metric (currently unused)
*/
template <typename value_idx = int, typename value_t = float, int TPB_X = 32>
template <typename value_idx = int, typename value_t = float>
void brute_force_knn(const value_idx* idxIndptr,
const value_idx* idxIndices,
const value_t* idxData,
Expand Down Expand Up @@ -120,7 +120,7 @@ void brute_force_knn(const value_idx* idxIndptr,
* @param[in] metric distance metric/measure to use
* @param[in] metricArg potential argument for metric (currently unused)
*/
template <typename value_idx = int, typename value_t = float, int TPB_X = 32>
template <typename value_idx = int, typename value_t = float>
void brute_force_knn(raft::device_csr_matrix<value_t,
value_idx,
value_idx,
Expand Down Expand Up @@ -186,7 +186,7 @@ void brute_force_knn(raft::device_csr_matrix<value_t,
* @param[in] metric distance metric/measure to use
* @param[in] metricArg potential argument for metric (currently unused)
*/
template <typename value_idx = int, typename value_t = float, int TPB_X = 32>
template <typename value_idx = int, typename value_t = float>
void brute_force_knn(raft::device_coo_matrix<value_t,
value_idx,
value_idx,
Expand Down
12 changes: 10 additions & 2 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,16 @@ if(BUILD_TESTS)
)

ConfigureTest(
NAME SPARSE_NEIGHBORS_TEST PATH sparse/neighbors/cross_component_nn.cu
sparse/neighbors/brute_force.cu sparse/neighbors/knn_graph.cu LIB EXPLICIT_INSTANTIATE_ONLY
NAME
SPARSE_NEIGHBORS_TEST
PATH
sparse/neighbors/cross_component_nn.cu
sparse/neighbors/brute_force.cu
sparse/neighbors/brute_force_coo.cu
sparse/neighbors/brute_force_csr.cu
sparse/neighbors/knn_graph.cu
LIB
EXPLICIT_INSTANTIATE_ONLY
)

ConfigureTest(
Expand Down
80 changes: 41 additions & 39 deletions cpp/test/preprocess_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,26 @@ struct check_zeroes {
};

template <typename T1, typename T2>
void preproc_kernel(raft::resources& handle,
raft::host_vector_view<T1> h_rows,
raft::host_vector_view<T1> h_cols,
raft::host_vector_view<T2> h_elems,
raft::device_vector_view<T2> results,
int num_rows,
int num_cols,
bool tf_idf)
void preproc_coo(raft::resources& handle,
raft::host_vector_view<T1> h_rows,
raft::host_vector_view<T1> h_cols,
raft::host_vector_view<T2> h_elems,
raft::device_vector_view<T2> results,
int num_rows,
int num_cols,
bool tf_idf)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
int rows_size = h_rows.size();
int cols_size = h_cols.size();
int elements_size = h_elems.size();
auto device_matrix = raft::make_device_matrix<T2, int64_t>(handle, num_rows, num_cols);
raft::matrix::fill<float>(handle, device_matrix.view(), 0.0f);
raft::matrix::fill<T2>(handle, device_matrix.view(), 0.0f);
auto host_matrix = raft::make_host_matrix<T2, int64_t>(handle, num_rows, num_cols);
raft::copy(host_matrix.data_handle(), device_matrix.data_handle(), device_matrix.size(), stream);

raft::resource::sync_stream(handle, stream);

for (int i = 0; i < elements_size; i++) {
int row = h_rows(i);
int col = h_cols(i);
Expand All @@ -81,20 +83,20 @@ void preproc_kernel(raft::resources& handle,
output_cols_lengths.size(),
stream);

auto output_cols_length_sum = raft::make_device_scalar<int>(handle, 0);
auto output_cols_length_sum = raft::make_device_scalar<T1>(handle, 0);
raft::linalg::mapReduce(output_cols_length_sum.data_handle(),
num_cols,
0,
raft::identity_op(),
raft::add_op(),
stream,
output_cols_lengths.data_handle());
auto h_output_cols_length_sum = raft::make_host_scalar<int>(handle, 0);
auto h_output_cols_length_sum = raft::make_host_scalar<T1>(handle, 0);
raft::copy(h_output_cols_length_sum.data_handle(),
output_cols_length_sum.data_handle(),
output_cols_length_sum.size(),
stream);
float avg_col_length = float(h_output_cols_length_sum(0)) / num_cols;
T2 avg_col_length = T2(h_output_cols_length_sum(0)) / num_cols;

auto output_rows_freq = raft::make_device_matrix<T2, int64_t>(handle, 1, num_rows);
raft::linalg::reduce(output_rows_freq.data_handle(),
Expand All @@ -116,13 +118,13 @@ void preproc_kernel(raft::resources& handle,
false,
stream,
false,
check_zeroes<float, float>());
check_zeroes<T2, T2>());
auto h_output_rows_cnt = raft::make_host_matrix<T2, int64_t>(handle, 1, num_rows);
raft::copy(
h_output_rows_cnt.data_handle(), output_rows_cnt.data_handle(), output_rows_cnt.size(), stream);

auto out_device_matrix = raft::make_device_matrix<T2, int64_t>(handle, num_rows, num_cols);
raft::matrix::fill<float>(handle, out_device_matrix.view(), 0.0f);
raft::matrix::fill<T2>(handle, out_device_matrix.view(), 0.0f);
auto out_host_matrix = raft::make_host_matrix<T2, int64_t>(handle, num_rows, num_cols);
auto out_host_vector = raft::make_host_vector<T2, int64_t>(handle, results.size());

Expand All @@ -137,7 +139,7 @@ void preproc_kernel(raft::resources& handle,
out_host_matrix(row, col) = 0.0f;
} else {
float tf = float(val / h_output_cols_lengths(0, col));
float idf = raft::log<float>(num_cols / h_output_rows_cnt(0, row));
float idf = raft::log<T2>(num_cols / h_output_rows_cnt(0, row));
if (tf_idf) {
result = tf * idf;
} else {
Expand Down Expand Up @@ -171,7 +173,7 @@ int get_dupe_mask_count(raft::resources& handle,
values.data_handle(),
stream);

raft::sparse::op::compute_duplicates_mask<int>(
raft::sparse::op::compute_duplicates_mask<T1>(
mask.data_handle(), rows.data_handle(), columns.data_handle(), rows.size(), stream);

int col_nnz_count = thrust::reduce(raft::resource::get_thrust_policy(handle),
Expand All @@ -193,15 +195,15 @@ void remove_dupes(raft::resources& handle,
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);

auto col_counts = raft::make_device_vector<int, int64_t>(handle, columns.size());
auto col_counts = raft::make_device_vector<T1, int64_t>(handle, columns.size());

thrust::fill(raft::resource::get_thrust_policy(handle),
col_counts.data_handle(),
col_counts.data_handle() + col_counts.size(),
1.0f);

auto keys_out = raft::make_device_vector<int, int64_t>(handle, num_rows);
auto counts_out = raft::make_device_vector<int, int64_t>(handle, num_rows);
auto keys_out = raft::make_device_vector<T1, int64_t>(handle, num_rows);
auto counts_out = raft::make_device_vector<T1, int64_t>(handle, num_rows);

thrust::reduce_by_key(raft::resource::get_thrust_policy(handle),
rows.data_handle(),
Expand All @@ -210,19 +212,19 @@ void remove_dupes(raft::resources& handle,
keys_out.data_handle(),
counts_out.data_handle());

auto mask_out = raft::make_device_vector<float, int64_t>(handle, rows.size());
auto mask_out = raft::make_device_vector<T2, int64_t>(handle, rows.size());

raft::linalg::map(handle, mask_out.view(), raft::cast_op<float>{}, raft::make_const_mdspan(mask));
raft::linalg::map(handle, mask_out.view(), raft::cast_op<T2>{}, raft::make_const_mdspan(mask));

auto values_c = raft::make_device_vector<float, int64_t>(handle, values.size());
auto values_c = raft::make_device_vector<T2, int64_t>(handle, values.size());
raft::linalg::map(handle,
values_c.view(),
raft::mul_op{},
raft::make_const_mdspan(values),
raft::make_const_mdspan(mask_out.view()));

auto keys_nnz_out = raft::make_device_vector<int, int64_t>(handle, num_rows);
auto counts_nnz_out = raft::make_device_vector<int, int64_t>(handle, num_rows);
auto keys_nnz_out = raft::make_device_vector<T1, int64_t>(handle, num_rows);
auto counts_nnz_out = raft::make_device_vector<T1, int64_t>(handle, num_rows);

thrust::reduce_by_key(raft::resource::get_thrust_policy(handle),
rows.data_handle(),
Expand All @@ -231,18 +233,18 @@ void remove_dupes(raft::resources& handle,
keys_nnz_out.data_handle(),
counts_nnz_out.data_handle());

raft::sparse::op::coo_remove_scalar<float>(rows.data_handle(),
columns.data_handle(),
values_c.data_handle(),
values_c.size(),
out_rows.data_handle(),
out_cols.data_handle(),
out_vals.data_handle(),
counts_nnz_out.data_handle(),
counts_out.data_handle(),
0,
num_rows,
stream);
raft::sparse::op::coo_remove_scalar<T2>(rows.data_handle(),
columns.data_handle(),
values_c.data_handle(),
values_c.size(),
out_rows.data_handle(),
out_cols.data_handle(),
out_vals.data_handle(),
counts_nnz_out.data_handle(),
counts_out.data_handle(),
0,
num_rows,
stream);
}

template <typename T1, typename T2>
Expand All @@ -261,7 +263,7 @@ void create_dataset(raft::resources& handle,
auto d_out = raft::make_device_vector<T1, int64_t>(handle, rows.size() * 2);

int theta_guide = max(num_rows_unique, num_cols_unique);
auto theta = raft::make_device_vector<float, int64_t>(handle, theta_guide * 4);
auto theta = raft::make_device_vector<T2, int64_t>(handle, theta_guide * 4);

raft::random::uniform(handle, rng, theta.view(), 0.0f, 1.0f);

Expand All @@ -275,9 +277,9 @@ void create_dataset(raft::resources& handle,
stream,
rng);

auto vals = raft::make_device_vector<int, int64_t>(handle, rows.size());
auto vals = raft::make_device_vector<T1, int64_t>(handle, rows.size());
raft::random::uniformInt(handle, rng, vals.view(), 1, max_term_occurence_doc);
raft::linalg::map(handle, values, raft::cast_op<float>{}, raft::make_const_mdspan(vals.view()));
raft::linalg::map(handle, values, raft::cast_op<T2>{}, raft::make_const_mdspan(vals.view()));
}

}; // namespace raft::util
2 changes: 1 addition & 1 deletion cpp/test/sparse/preprocess_coo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void calc_tfidf_bm25(raft::resources& handle,
stream);
raft::copy(
h_elems.data_handle(), coo_in.get_elements().data(), coo_in.get_elements().size(), stream);
raft::util::preproc_kernel<T1, T2>(
raft::util::preproc_coo<T1, T2>(
handle, h_rows.view(), h_cols.view(), h_elems.view(), results, num_rows, num_cols, tf_idf);
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/test/sparse/preprocess_csr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void calc_tfidf_bm25(raft::resources& handle,
raft::copy(h_rows.data_handle(), rows.data_handle(), rows.size(), stream);
raft::copy(h_cols.data_handle(), indices.data_handle(), cols_size, stream);
raft::copy(h_elems.data_handle(), values.data_handle(), values.size(), stream);
raft::util::preproc_kernel<T1, T2>(
raft::util::preproc_coo<T1, T2>(
handle, h_rows.view(), h_cols.view(), h_elems.view(), results, num_rows, num_cols, tf_idf);
}

Expand Down

0 comments on commit 2b0202a

Please sign in to comment.