From 4fce7589aa50e450f35240b44c1f5f50ee62f5d2 Mon Sep 17 00:00:00 2001 From: Shaohui Liu Date: Wed, 26 Jun 2024 09:54:51 +0200 Subject: [PATCH] Evaluation of residuals and Jacobian from pyceres.Problem (#49) --- _pyceres/core/bindings.h | 2 ++ _pyceres/core/crs_matrix.h | 49 ++++++++++++++++++++++++++++++++++++++ _pyceres/core/problem.h | 32 +++++++++++++++++++++---- 3 files changed, 79 insertions(+), 4 deletions(-) create mode 100644 _pyceres/core/crs_matrix.h diff --git a/_pyceres/core/bindings.h b/_pyceres/core/bindings.h index 872bd44..c2e67a9 100644 --- a/_pyceres/core/bindings.h +++ b/_pyceres/core/bindings.h @@ -3,6 +3,7 @@ #include "_pyceres/core/callbacks.h" #include "_pyceres/core/cost_functions.h" #include "_pyceres/core/covariance.h" +#include "_pyceres/core/crs_matrix.h" #include "_pyceres/core/loss_functions.h" #include "_pyceres/core/manifold.h" #include "_pyceres/core/problem.h" @@ -17,6 +18,7 @@ void BindCore(py::module& m) { BindTypes(m); BindCallbacks(m); BindCovariance(m); + BindCRSMatrix(m); BindSolver(m); BindLossFunctions(m); BindCostFunctions(m); diff --git a/_pyceres/core/crs_matrix.h b/_pyceres/core/crs_matrix.h new file mode 100644 index 0000000..c21577b --- /dev/null +++ b/_pyceres/core/crs_matrix.h @@ -0,0 +1,49 @@ +#pragma once + +#include "_pyceres/core/wrappers.h" +#include "_pyceres/helpers.h" +#include "_pyceres/logging.h" + +#include +#include +#include +#include + +namespace py = pybind11; + +namespace { +py::tuple ConvertCRSToPyTuple(const ceres::CRSMatrix& crsMatrix) { + const size_t n_values = crsMatrix.values.size(); + py::array_t rows(n_values), cols(n_values); + py::array_t values(n_values); + + int* const rows_data = static_cast(rows.request().ptr); + int* const cols_data = static_cast(cols.request().ptr); + double* const values_data = static_cast(values.request().ptr); + + int counter = 0; + for (int row = 0; row < crsMatrix.num_rows; ++row) { + for (int k = crsMatrix.rows[row]; k < crsMatrix.rows[row + 1]; ++k) { + rows_data[counter] = row; + cols_data[counter] = crsMatrix.cols[k]; + values_data[counter] = crsMatrix.values[k]; + counter++; + } + } + + // return as a tuple + return py::make_tuple(rows, cols, values); +} +} // namespace + +void BindCRSMatrix(py::module& m) { + using CRSMatrix = ceres::CRSMatrix; + py::class_ PyCRSMatrix(m, "CRSMatrix"); + PyCRSMatrix.def(py::init<>()) + .def_readonly("num_rows", &CRSMatrix::num_rows) + .def_readonly("num_cols", &CRSMatrix::num_cols) + .def_readonly("rows", &CRSMatrix::rows) + .def_readonly("cols", &CRSMatrix::cols) + .def_readonly("values", &CRSMatrix::values) + .def("to_tuple", &ConvertCRSToPyTuple); +} diff --git a/_pyceres/core/problem.h b/_pyceres/core/problem.h index 562fd76..10d7627 100644 --- a/_pyceres/core/problem.h +++ b/_pyceres/core/problem.h @@ -40,11 +40,17 @@ void BindProblem(py::module& m) { .def_readwrite("disable_all_safety_checks", &options::disable_all_safety_checks); - // TODO: bind Problem::Evaluate py::class_(m, "EvaluateOptions") .def(py::init<>()) - // Doesn't make sense to wrap this as you can't see the pointers in python - //.def_readwrite("parameter_blocks",&ceres::Problem::EvaluateOptions) + .def("set_parameter_blocks", + [](ceres::Problem::EvaluateOptions& self, + std::vector>& blocks) { + self.parameter_blocks.clear(); + for (auto it = blocks.begin(); it != blocks.end(); ++it) { + py::buffer_info info = it->request(); + self.parameter_blocks.push_back(static_cast(info.ptr)); + } + }) .def_readwrite("apply_loss_function", &ceres::Problem::EvaluateOptions::apply_loss_function) .def_readwrite("num_threads", @@ -233,5 +239,23 @@ void BindProblem(py::module& m) { .def("remove_residual_block", [](ceres::Problem& self, ResidualBlockIDWrapper& residual_block_id) { self.RemoveResidualBlock(residual_block_id.id); - }); + }) + .def( + "evaluate_residuals", + [](ceres::Problem& self, + const ceres::Problem::EvaluateOptions& options) { + std::vector residuals; + self.Evaluate(options, nullptr, &residuals, nullptr, nullptr); + return residuals; + }, + py::arg("options") = ceres::Problem::EvaluateOptions()) + .def( + "evaluate_jacobian", + [](ceres::Problem& self, + const ceres::Problem::EvaluateOptions& options) { + ceres::CRSMatrix jacobian; + self.Evaluate(options, nullptr, nullptr, nullptr, &jacobian); + return jacobian; + }, + py::arg("options") = ceres::Problem::EvaluateOptions()); }