From 1432ab6c5db00de79adec94a6c0cbbf1ae7b7142 Mon Sep 17 00:00:00 2001 From: B1ueber2y Date: Mon, 24 Jun 2024 14:02:36 +0200 Subject: [PATCH] update, --- _pyceres/core/crs_matrix.h | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/_pyceres/core/crs_matrix.h b/_pyceres/core/crs_matrix.h index 5bffa09..3c65f0e 100644 --- a/_pyceres/core/crs_matrix.h +++ b/_pyceres/core/crs_matrix.h @@ -13,14 +13,21 @@ namespace py = pybind11; namespace { py::tuple ConvertCRSToPyTuple(const ceres::CRSMatrix& crsMatrix) { - std::vector rows, cols; - std::vector values; + size_t n_values = crsMatrix.values.size(); + py::array_t rows(n_values), cols(n_values); + py::array_t values(n_values); + int* rows_data = static_cast(rows.request().ptr); + int* cols_data = static_cast(cols.request().ptr); + double* values_data = static_cast(values.request().ptr); + + int counter = 0; for (int row = 0; row < crsMatrix.num_rows; ++row) { for (int k = crsMatrix.rows[row]; k < crsMatrix.rows[row + 1]; ++k) { - rows.push_back(row); - cols.push_back(crsMatrix.cols[k]); - values.push_back(crsMatrix.values[k]); + rows_data[counter] = row; + cols_data[counter] = crsMatrix.cols[k]; + values_data[counter] = crsMatrix.values[k]; + counter++; } } @@ -38,6 +45,5 @@ void BindCRSMatrix(py::module& m) { .def_readonly("rows", &CRSMatrix::rows) .def_readonly("cols", &CRSMatrix::cols) .def_readonly("values", &CRSMatrix::values) - .def("to_tuple", - [](CRSMatrix& self) { return ConvertCRSToPyTuple(self); }); + .def("to_tuple", &ConvertCRSToPyTuple); }