diff --git a/cpp/include/raft/spectral/eigen_solvers.cuh b/cpp/include/raft/spectral/eigen_solvers.cuh index d98e90532e..324f16ac7b 100644 --- a/cpp/include/raft/spectral/eigen_solvers.cuh +++ b/cpp/include/raft/spectral/eigen_solvers.cuh @@ -18,7 +18,6 @@ #pragma once -#include #include #include @@ -58,31 +57,18 @@ struct lanczos_solver_t { { RAFT_EXPECTS(eigVals != nullptr, "Null eigVals buffer."); RAFT_EXPECTS(eigVecs != nullptr, "Null eigVecs buffer."); - index_type_t iters{0}; // TODO: return total number of iter - auto lanczos_config = raft::sparse::solver::lanczos_solver_config{ - config_.n_eigVecs, config_.maxIter, config_.restartIter, config_.tol, config_.seed}; - auto csr_structure = - raft::make_device_compressed_structure_view( - const_cast(A.row_offsets_), - const_cast(A.col_indices_), - A.nrows_, - A.ncols_, - A.nnz_); - - auto csr_matrix = - raft::make_device_csr_matrix_view( - const_cast(A.values_), csr_structure); - std::optional> v0_opt; - - sparse::solver::lanczos_compute_smallest_eigenvectors( - handle, - lanczos_config, - csr_matrix, - v0_opt, - raft::make_device_vector_view(eigVals, - config_.n_eigVecs), - raft::make_device_matrix_view( - eigVecs, A.nrows_, config_.n_eigVecs)); + index_type_t iters{}; + sparse::solver::computeSmallestEigenvectors(handle, + A, + config_.n_eigVecs, + config_.maxIter, + config_.restartIter, + config_.tol, + config_.reorthogonalize, + iters, + eigVals, + eigVecs, + config_.seed); return iters; }