Skip to content

Commit

Permalink
Add NormalError cost function (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe authored May 21, 2024
1 parent f90405e commit 3f34fa4
Showing 1 changed file with 42 additions and 3 deletions.
45 changes: 42 additions & 3 deletions _pyceres/factors/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,55 @@

namespace py = pybind11;

inline Eigen::MatrixXd SqrtInformation(const Eigen::MatrixXd& covariance) {
return covariance.inverse().llt().matrixL();
}

// Mahalanobis squared distance between two parameters.
class NormalError {
public:
explicit NormalError(const Eigen::MatrixXd& covariance)
: sqrt_information_(SqrtInformation(covariance)) {
THROW_CHECK_EQ(covariance.rows(), covariance.cols());
}

static ceres::CostFunction* Create(const Eigen::MatrixXd& covariance) {
auto* cost_function = new ceres::DynamicAutoDiffCostFunction<NormalError>(
new NormalError(covariance));
const int dimension = covariance.rows();
cost_function->AddParameterBlock(dimension);
cost_function->AddParameterBlock(dimension);
cost_function->SetNumResiduals(dimension);
return cost_function;
}

template <typename T>
bool operator()(T const* const* parameters, T* residuals_ptr) const {
const int dimension = sqrt_information_.rows();
for (int i = 0; i < dimension; ++i) {
residuals_ptr[i] = parameters[0][i] - parameters[1][i];
}
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> residuals(residuals_ptr,
dimension);
residuals.applyOnTheLeft(sqrt_information_.template cast<T>());
return true;
}

private:
const Eigen::MatrixXd sqrt_information_;
};

void BindFactors(py::module& m) {
m.def(
"NormalPrior",
[](const Eigen::VectorXd& mean,
const Eigen::Matrix<double, -1, -1>& covariance) {
const Eigen::MatrixXd& covariance) -> ceres::CostFunction* {
THROW_CHECK_EQ(covariance.cols(), mean.size());
THROW_CHECK_EQ(covariance.cols(), covariance.rows());
return new ceres::NormalPrior(covariance.inverse().llt().matrixL(),
mean);
return new ceres::NormalPrior(SqrtInformation(covariance), mean);
},
py::arg("mean"),
py::arg("covariance"));

m.def("NormalError", &NormalError::Create, py::arg("covariance"));
}

0 comments on commit 3f34fa4

Please sign in to comment.