Skip to content

Commit

Permalink
build fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Dec 3, 2024
1 parent 591ac65 commit 6e2f8d1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
9 changes: 7 additions & 2 deletions cpp/src/distance/detail/kernels/gram_matrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -465,14 +465,19 @@ void GramMatrixBase<math_t>::linear(raft::resources const& handle,
if (is_col_major_nopad) {
auto out_row_major = raft::make_device_matrix_view<math_t, int, raft::row_major>(
out.data_handle(), out.extent(1), out.extent(0));

// TODO: use PW distance from cuvs
raft::sparse::distance::pairwise_distance(
handle, x2, x1, out_row_major, cuvs::distance::DistanceType::InnerProduct, 0.0);
handle, x2, x1, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0);
} else {
auto out_row_major = raft::make_device_matrix_view<math_t, int, raft::row_major>(
out.data_handle(), out.extent(0), out.extent(1));
raft::sparse::distance::pairwise_distance(
handle, x1, x2, out_row_major, cuvs::distance::DistanceType::InnerProduct, 0.0);
handle, x1, x2, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0);
}
}

template class GramMatrixBase<float>;
template class GramMatrixBase<double>;

}; // namespace cuvs::distance::kernels
3 changes: 3 additions & 0 deletions cpp/src/distance/detail/kernels/kernel_factory.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,7 @@ template <typename math_t>
return res;
}

template class KernelFactory<float>;
template class KernelFactory<double>;

}; // end namespace cuvs::distance::kernels

0 comments on commit 6e2f8d1

Please sign in to comment.