Skip to content

Commit

Permalink
Merge pull request #259 from rapidsai/branch-24.08
Browse files Browse the repository at this point in the history
Forward-merge branch-24.08 into branch-24.10
  • Loading branch information
GPUtester authored Jul 29, 2024
2 parents a1a7291 + 33698a5 commit 4ef1611
Showing 1 changed file with 8 additions and 31 deletions.
39 changes: 8 additions & 31 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
#include <raft/sparse/convert/coo.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/distance/detail/utils.cuh>
#include <raft/sparse/linalg/sddmm.hpp>
#include <raft/sparse/linalg/masked_matmul.hpp>
#include <raft/sparse/matrix/select_k.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
Expand Down Expand Up @@ -636,36 +636,13 @@ void brute_force_search_filtered(
rows.data(),
compressed_csr_view.get_nnz(),
stream);
if (n_queries > 10) {
auto csr_view = raft::make_device_csr_matrix_view<T, IdxT, IdxT, IdxT>(
csr.get_elements().data(), compressed_csr_view);

// create dataset view
auto dataset_view = raft::make_device_matrix_view<const T, IdxT, raft::col_major>(
idx.dataset().data_handle(), dim, n_dataset);

// calc dot
T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);
raft::sparse::linalg::sddmm(res,
queries,
dataset_view,
csr_view,
raft::linalg::Operation::NON_TRANSPOSE,
raft::linalg::Operation::NON_TRANSPOSE,
raft::make_host_scalar_view<T>(&alpha),
raft::make_host_scalar_view<T>(&beta));
} else {
raft::sparse::distance::detail::faster_dot_on_csr(res,
csr.get_elements().data(),
compressed_csr_view.get_nnz(),
compressed_csr_view.get_indptr().data(),
compressed_csr_view.get_indices().data(),
queries.data_handle(),
idx.dataset().data_handle(),
compressed_csr_view.get_n_rows(),
dim);
}
auto dataset_view = raft::make_device_matrix_view<const T, IdxT, raft::row_major>(
idx.dataset().data_handle(), n_dataset, dim);

auto csr_view = raft::make_device_csr_matrix_view<T, IdxT, IdxT, IdxT>(
csr.get_elements().data(), compressed_csr_view);

raft::sparse::linalg::masked_matmul(res, queries, dataset_view, filter, csr_view);

// post process
std::optional<raft::device_vector<T, IdxT>> query_norms_;
Expand Down

0 comments on commit 4ef1611

Please sign in to comment.