From 8555c995fcc8db8e1f078f93385c6c624fed47d0 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Fri, 17 May 2024 18:07:06 +0200 Subject: [PATCH] Add NormalError cost function --- _pyceres/factors/bindings.h | 45 ++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/_pyceres/factors/bindings.h b/_pyceres/factors/bindings.h index e2aca1d..b9c20a5 100644 --- a/_pyceres/factors/bindings.h +++ b/_pyceres/factors/bindings.h @@ -9,16 +9,55 @@ namespace py = pybind11; +inline Eigen::MatrixXd SqrtInformation(const Eigen::MatrixXd& covariance) { + return covariance.inverse().llt().matrixL(); +} + +class NormalError { + public: + explicit NormalError(const Eigen::MatrixXd& covariance) + : sqrt_information_(SqrtInformation(covariance)), + dimension_(covariance.rows()) { + THROW_CHECK_EQ(covariance.rows(), covariance.cols()); + } + + static ceres::CostFunction* Create(const Eigen::MatrixXd& covariance) { + auto* cost_function = new ceres::DynamicAutoDiffCostFunction( + new NormalError(covariance)); + const int dimension = covariance.rows(); + cost_function->AddParameterBlock(dimension); + cost_function->AddParameterBlock(dimension); + cost_function->SetNumResiduals(dimension); + return cost_function; + } + + template + bool operator()(T const* const* parameters, T* residuals_ptr) const { + for (int i = 0; i < dimension_; ++i) { + residuals_ptr[i] = parameters[0][i] - parameters[1][i]; + } + Eigen::Map> residuals(residuals_ptr, + dimension_); + residuals.applyOnTheLeft(sqrt_information_.template cast()); + return true; + } + + private: + const Eigen::MatrixXd sqrt_information_; + const int dimension_; +}; + void BindFactors(py::module& m) { m.def( "NormalPrior", [](const Eigen::VectorXd& mean, - const Eigen::Matrix& covariance) { + const Eigen::MatrixXd& covariance) -> ceres::CostFunction* { THROW_CHECK_EQ(covariance.cols(), mean.size()); THROW_CHECK_EQ(covariance.cols(), covariance.rows()); - return new ceres::NormalPrior(covariance.inverse().llt().matrixL(), - mean); + return new ceres::NormalPrior(SqrtInformation(covariance), mean); }, py::arg("mean"), py::arg("covariance")); + + m.def("NormalError", &NormalError::Create, py::arg("covariance")); }