Skip to content

Commit

Permalink
Add NormalError cost function
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe committed May 17, 2024
1 parent f90405e commit 8555c99
Showing 1 changed file with 42 additions and 3 deletions.
45 changes: 42 additions & 3 deletions _pyceres/factors/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,55 @@

namespace py = pybind11;

inline Eigen::MatrixXd SqrtInformation(const Eigen::MatrixXd& covariance) {
return covariance.inverse().llt().matrixL();
}

class NormalError {
public:
explicit NormalError(const Eigen::MatrixXd& covariance)
: sqrt_information_(SqrtInformation(covariance)),
dimension_(covariance.rows()) {
THROW_CHECK_EQ(covariance.rows(), covariance.cols());
}

static ceres::CostFunction* Create(const Eigen::MatrixXd& covariance) {
auto* cost_function = new ceres::DynamicAutoDiffCostFunction<NormalError>(
new NormalError(covariance));
const int dimension = covariance.rows();
cost_function->AddParameterBlock(dimension);
cost_function->AddParameterBlock(dimension);
cost_function->SetNumResiduals(dimension);
return cost_function;
}

template <typename T>
bool operator()(T const* const* parameters, T* residuals_ptr) const {
for (int i = 0; i < dimension_; ++i) {
residuals_ptr[i] = parameters[0][i] - parameters[1][i];
}
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> residuals(residuals_ptr,
dimension_);
residuals.applyOnTheLeft(sqrt_information_.template cast<T>());
return true;
}

private:
const Eigen::MatrixXd sqrt_information_;
const int dimension_;
};

void BindFactors(py::module& m) {
m.def(
"NormalPrior",
[](const Eigen::VectorXd& mean,
const Eigen::Matrix<double, -1, -1>& covariance) {
const Eigen::MatrixXd& covariance) -> ceres::CostFunction* {
THROW_CHECK_EQ(covariance.cols(), mean.size());
THROW_CHECK_EQ(covariance.cols(), covariance.rows());
return new ceres::NormalPrior(covariance.inverse().llt().matrixL(),
mean);
return new ceres::NormalPrior(SqrtInformation(covariance), mean);
},
py::arg("mean"),
py::arg("covariance"));

m.def("NormalError", &NormalError::Create, py::arg("covariance"));
}

0 comments on commit 8555c99

Please sign in to comment.