From f56d0fbc00c0539b8423d81f7202c02775b8c807 Mon Sep 17 00:00:00 2001 From: Shaohui Liu Date: Fri, 28 Jun 2024 18:30:54 +0200 Subject: [PATCH] Support residual evaluation from ceres::Problem (#50) --- _pyceres/core/problem.h | 52 ++++++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/_pyceres/core/problem.h b/_pyceres/core/problem.h index 10d7627..ccee5bf 100644 --- a/_pyceres/core/problem.h +++ b/_pyceres/core/problem.h @@ -9,6 +9,22 @@ namespace py = pybind11; +namespace { + +// Set residual blocks for Ceres::Problem::EvaluateOptions +void SetResidualBlocks( + ceres::Problem::EvaluateOptions& self, + std::vector& residual_block_ids) { + self.residual_blocks.clear(); + self.residual_blocks.reserve(residual_block_ids.size()); + for (auto it = residual_block_ids.begin(); it != residual_block_ids.end(); + ++it) { + self.residual_blocks.push_back(it->id); + } +} + +} // namespace + // Function to create Problem::Options with DO_NOT_TAKE_OWNERSHIP // This is cause we want Python to manage our memory not Ceres ceres::Problem::Options CreateProblemOptions() { @@ -42,15 +58,21 @@ void BindProblem(py::module& m) { py::class_(m, "EvaluateOptions") .def(py::init<>()) - .def("set_parameter_blocks", - [](ceres::Problem::EvaluateOptions& self, - std::vector>& blocks) { - self.parameter_blocks.clear(); - for (auto it = blocks.begin(); it != blocks.end(); ++it) { - py::buffer_info info = it->request(); - self.parameter_blocks.push_back(static_cast(info.ptr)); - } - }) + .def( + "set_parameter_blocks", + [](ceres::Problem::EvaluateOptions& self, + std::vector>& blocks) { + self.parameter_blocks.clear(); + self.parameter_blocks.reserve(blocks.size()); + for (auto it = blocks.begin(); it != blocks.end(); ++it) { + py::buffer_info info = it->request(); + self.parameter_blocks.push_back(static_cast(info.ptr)); + } + }, + py::arg("parameter_blocks")) + .def("set_residual_blocks", + &SetResidualBlocks, + py::arg("residual_block_ids")) .def_readwrite("apply_loss_function", &ceres::Problem::EvaluateOptions::apply_loss_function) .def_readwrite("num_threads", @@ -249,6 +271,18 @@ void BindProblem(py::module& m) { return residuals; }, py::arg("options") = ceres::Problem::EvaluateOptions()) + .def( + "evaluate_residuals", + [](ceres::Problem& self, + std::vector& residual_block_ids) { + ceres::Problem::EvaluateOptions eval_options = + ceres::Problem::EvaluateOptions(); + SetResidualBlocks(eval_options, residual_block_ids); + std::vector residuals; + self.Evaluate(eval_options, nullptr, &residuals, nullptr, nullptr); + return residuals; + }, + py::arg("residual_block_ids")) .def( "evaluate_jacobian", [](ceres::Problem& self,