Skip to content

Commit

Permalink
[python/gar] Separate out how utils are exposed
Browse files Browse the repository at this point in the history
  • Loading branch information
ManifoldFR committed Oct 19, 2024
1 parent d09bdc6 commit c7cae5b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
28 changes: 3 additions & 25 deletions bindings/python/src/gar/expose-gar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "aligator/python/blk-matrix.hpp"
#include "aligator/gar/lqr-problem.hpp"
#include "aligator/gar/riccati-base.hpp"
#include "aligator/gar/utils.hpp"

#include "aligator/python/utils.hpp"
#include "aligator/python/visitors.hpp"
Expand All @@ -24,17 +23,6 @@ using context::VectorXs;

using knot_vec_t = lqr_t::KnotVector;

bp::dict lqr_sol_initialize_wrap(const lqr_t &problem) {
bp::dict out;
auto ss = lqrInitializeSolution(problem);
auto &[xs, us, vs, lbdas] = ss;
out["xs"] = xs;
out["us"] = us;
out["vs"] = vs;
out["lbdas"] = lbdas;
return out;
}

static void exposeBlockMatrices() {
BlkMatrixPythonVisitor<BlkMatrix<MatrixXs, 2, 2>>::expose("BlockMatrix22");
BlkMatrixPythonVisitor<BlkMatrix<VectorXs, 4, 1>>::expose("BlockVector4");
Expand All @@ -61,6 +49,8 @@ void exposeParallelSolver();
void exposeDenseSolver();
// fwd-declare exposeProxRiccati()
void exposeProxRiccati();
// fwd-declare exposeGarUtils()
void exposeGarUtils();

void exposeGAR() {

Expand Down Expand Up @@ -130,19 +120,7 @@ void exposeGAR() {
.def("forward", &riccati_base_t::forward,
("self"_a, "xs", "us", "vs", "lbdas", "theta"_a = std::nullopt));

bp::def(
"lqrDenseMatrix",
+[](const lqr_t &problem, Scalar mudyn, Scalar mueq) {
auto mat_rhs = lqrDenseMatrix(problem, mudyn, mueq);
return bp::make_tuple(std::get<0>(mat_rhs), std::get<1>(mat_rhs));
},
("problem"_a, "mudyn", "mueq"));

bp::def("lqrCreateSparseMatrix", lqrCreateSparseMatrix<Scalar>,
("problem"_a, "mudyn", "mueq", "mat", "rhs", "update"),
"Create or update a sparse matrix from an LQRProblem.");

bp::def("lqrInitializeSolution", lqr_sol_initialize_wrap, ("problem"_a));
exposeGarUtils();

#ifdef ALIGATOR_WITH_CHOLMOD
exposeCholmodSolver();
Expand Down
37 changes: 37 additions & 0 deletions bindings/python/src/gar/expose-utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "aligator/python/fwd.hpp"
#include "aligator/gar/utils.hpp"

namespace aligator::python {
using namespace gar;

using context::Scalar;
using lqr_t = LQRProblemTpl<context::Scalar>;

bp::dict lqr_sol_initialize_wrap(const lqr_t &problem) {
bp::dict out;
auto ss = lqrInitializeSolution(problem);
auto &[xs, us, vs, lbdas] = ss;
out["xs"] = xs;
out["us"] = us;
out["vs"] = vs;
out["lbdas"] = lbdas;
return out;
}

void exposeGarUtils() {

bp::def(
"lqrDenseMatrix",
+[](const lqr_t &problem, Scalar mudyn, Scalar mueq) {
auto mat_rhs = lqrDenseMatrix(problem, mudyn, mueq);
return bp::make_tuple(std::get<0>(mat_rhs), std::get<1>(mat_rhs));
},
("problem"_a, "mudyn", "mueq"));

bp::def("lqrCreateSparseMatrix", lqrCreateSparseMatrix<Scalar>,
("problem"_a, "mudyn", "mueq", "mat", "rhs", "update"),
"Create or update a sparse matrix from an LQRProblem.");

bp::def("lqrInitializeSolution", lqr_sol_initialize_wrap, ("problem"_a));
}
} // namespace aligator::python

0 comments on commit c7cae5b

Please sign in to comment.