From 0fd4d44a290e3bc243c870d16fb7a123c60a9b98 Mon Sep 17 00:00:00 2001 From: B1ueber2y Date: Thu, 20 Jun 2024 23:21:59 +0200 Subject: [PATCH] add support for partial evaluation --- _pyceres/core/problem.h | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/_pyceres/core/problem.h b/_pyceres/core/problem.h index dd86d74..54131f5 100644 --- a/_pyceres/core/problem.h +++ b/_pyceres/core/problem.h @@ -40,11 +40,15 @@ void BindProblem(py::module& m) { .def_readwrite("disable_all_safety_checks", &options::disable_all_safety_checks); - // TODO: bind Problem::Evaluate py::class_(m, "EvaluateOptions") .def(py::init<>()) - // Doesn't make sense to wrap this as you can't see the pointers in python - //.def_readwrite("parameter_blocks",&ceres::Problem::EvaluateOptions) + .def("set_parameter_blocks", [](ceres::Problem::EvaluateOptions& self, std::vector>& blocks) { + self.parameter_blocks.clear(); + for (auto it = blocks.begin(); it != blocks.end(); ++it) { + py::buffer_info info = it->request(); + self.parameter_blocks.push_back(static_cast(info.ptr)); + } + }) .def_readwrite("apply_loss_function", &ceres::Problem::EvaluateOptions::apply_loss_function) .def_readwrite("num_threads", @@ -234,11 +238,20 @@ void BindProblem(py::module& m) { [](ceres::Problem& self, ResidualBlockIDWrapper& residual_block_id) { self.RemoveResidualBlock(residual_block_id.id); }) + .def("evaluate_residuals", + [](ceres::Problem& self, + const ceres::Problem::EvaluateOptions& options) { + std::vector residuals; + self.Evaluate(options, nullptr, &residuals, nullptr, nullptr); + return residuals; + }, + py::arg("options") = ceres::Problem::EvaluateOptions()) .def("evaluate_jacobian", [](ceres::Problem& self, const ceres::Problem::EvaluateOptions& options) { ceres::CRSMatrix jacobian; self.Evaluate(options, nullptr, nullptr, nullptr, &jacobian); return jacobian; - }); + }, + py::arg("options") = ceres::Problem::EvaluateOptions()); }