diff --git a/cpp/include/raft/sparse/op/detail/sort.h b/cpp/include/raft/sparse/op/detail/sort.h index 85ae825035..02287c2367 100644 --- a/cpp/include/raft/sparse/op/detail/sort.h +++ b/cpp/include/raft/sparse/op/detail/sort.h @@ -68,8 +68,8 @@ struct TupleComp { * @param vals vals array from coo matrix * @param stream: cuda stream to use */ -template -void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t stream) +template +void coo_sort(IdxT m, IdxT n, IdxT nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream) { auto coo_indices = thrust::make_zip_iterator(thrust::make_tuple(rows, cols)); @@ -83,10 +83,10 @@ void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t * @param in: COO to sort by row * @param stream: the cuda stream to use */ -template -void coo_sort(COO* const in, cudaStream_t stream) +template +void coo_sort(COO* const in, cudaStream_t stream) { - coo_sort(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream); + coo_sort(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream); } /** diff --git a/cpp/include/raft/sparse/op/sort.cuh b/cpp/include/raft/sparse/op/sort.cuh index c6c3c2e220..5b8a792429 100644 --- a/cpp/include/raft/sparse/op/sort.cuh +++ b/cpp/include/raft/sparse/op/sort.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,8 +37,8 @@ namespace op { * @param vals vals array from coo matrix * @param stream: cuda stream to use */ -template -void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t stream) +template +void coo_sort(IdxT m, IdxT n, IdxT nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream) { detail::coo_sort(m, n, nnz, rows, cols, vals, stream); } @@ -49,10 +49,10 @@ void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t * @param in: COO to sort by row * @param stream: the cuda stream to use */ -template -void coo_sort(COO* const in, cudaStream_t stream) +template +void coo_sort(COO* const in, cudaStream_t stream) { - coo_sort(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream); + coo_sort(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream); } /** @@ -75,4 +75,4 @@ void coo_sort_by_weight( }; // end NAMESPACE sparse }; // end NAMESPACE raft -#endif \ No newline at end of file +#endif