diff --git a/CHANGELOG.md b/CHANGELOG.md index 168ab9e15..2e8eae1ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - [gar] Rework `RiccatiSolverDense` to not use inner struct `FactorData` - Various changes to `gar` tests and `test_util`, add `LQRKnot::isApprox()` - Add `MemReq` struct to handle requests for single blocks of memory for multiple arrays +- Change `LQRKnotTpl` to work with a single block of allocated memory using `MemReq` ### Removed diff --git a/bindings/python/src/gar/expose-gar.cpp b/bindings/python/src/gar/expose-gar.cpp index 9cd7ab1be..1fcc5dbf7 100644 --- a/bindings/python/src/gar/expose-gar.cpp +++ b/bindings/python/src/gar/expose-gar.cpp @@ -59,8 +59,9 @@ void exposeGAR() { exposeBlockMatrices(); bp::class_("LQRKnot", bp::no_init) - .def(bp::init(("nx"_a, "nu", "nc"))) - .def(bp::init(("nx"_a, "nu", "nc", "nx2"))) + .def(bp::init(("self"_a, "nx", "nu", "nc"))) + .def(bp::init( + ("self"_a, "nx"_a, "nu", "nc", "nx2", "nth"_a = 0))) .def_readonly("nx", &knot_t::nx) .def_readonly("nu", &knot_t::nu) .def_readonly("nc", &knot_t::nc) diff --git a/gar/include/aligator/gar/lqr-problem.hpp b/gar/include/aligator/gar/lqr-problem.hpp index 23cf834bc..abd1c2202 100644 --- a/gar/include/aligator/gar/lqr-problem.hpp +++ b/gar/include/aligator/gar/lqr-problem.hpp @@ -2,6 +2,7 @@ #pragma once #include "aligator/math.hpp" +#include "mem-req.hpp" #include #include @@ -27,35 +28,81 @@ namespace gar { /// template struct LQRKnotTpl { ALIGATOR_DYNAMIC_TYPEDEFS(Scalar); + enum { Alignment = Eigen::AlignedMax }; + using VectorMap = Eigen::Map; + using MatrixMap = Eigen::Map; + + // tag type + struct no_alloc_t { + explicit constexpr no_alloc_t() {} + }; + static constexpr no_alloc_t no_alloc{}; uint nx, nu, nc, nx2, nth; - MatrixXs Q, S, R; - VectorXs q, r; - MatrixXs A, B, E; - VectorXs f; - MatrixXs C, D; - VectorXs d; - - MatrixXs Gth; - MatrixXs Gx; - MatrixXs Gu; - MatrixXs Gv; - VectorXs gamma; + + MatrixMap Q, S, R; + VectorMap q, r; + MatrixMap A, B, E; + VectorMap f; + MatrixMap C, D; + VectorMap d; + + MatrixMap Gth; + MatrixMap Gx; + MatrixMap Gu; + MatrixMap Gv; + VectorMap gamma; LQRKnotTpl(uint nx, uint nu, uint nc, uint nx2, uint nth = 0); LQRKnotTpl(uint nx, uint nu, uint nc) : LQRKnotTpl(nx, nu, nc, nx) {} + void allocate(); + // initialize the matrices. + void initialize(); // reallocates entire buffer for contigousness void addParameterization(uint nth); + + LQRKnotTpl(const LQRKnotTpl &other); + LQRKnotTpl(LQRKnotTpl &&other); + LQRKnotTpl &operator=(const LQRKnotTpl &other); + LQRKnotTpl &operator=(LQRKnotTpl &&other); + + ~LQRKnotTpl(); + + friend void swap(LQRKnotTpl &lhs, LQRKnotTpl &rhs) { + using std::swap; + swap(lhs.nx, rhs.nx); + swap(lhs.nu, rhs.nu); + swap(lhs.nc, rhs.nc); + swap(lhs.nx2, rhs.nx2); + swap(lhs.nth, rhs.nth); + // only swap the memory ptr, do not swap the Eigen::Map objects. + swap(lhs.memory, rhs.memory); + swap(lhs.req, rhs.req); + + lhs.initialize(); + rhs.initialize(); + } + bool isApprox(const LQRKnotTpl &other, Scalar prec = std::numeric_limits::epsilon()) const; friend bool operator==(const LQRKnotTpl &lhs, const LQRKnotTpl &rhs) { return lhs.isApprox(rhs); } + +private: + LQRKnotTpl(no_alloc_t, uint nx, uint nu, uint nc, uint nx2, uint nth); + Scalar *memory; + MemReq req; }; +template LQRKnotTpl::~LQRKnotTpl() { + if (memory) + std::free(memory); +} + template struct LQRProblemTpl { ALIGATOR_DYNAMIC_TYPEDEFS(Scalar); using KnotType = LQRKnotTpl; diff --git a/gar/include/aligator/gar/lqr-problem.hxx b/gar/include/aligator/gar/lqr-problem.hxx index cf269ce84..bac434a41 100644 --- a/gar/include/aligator/gar/lqr-problem.hxx +++ b/gar/include/aligator/gar/lqr-problem.hxx @@ -4,43 +4,163 @@ namespace aligator::gar { +namespace detail { +template +void emplaceMap(Eigen::Map &map, long size, + Scalar *ptr) { + using MapType = Eigen::Map; + new (&map) MapType{ptr, size}; +} + +/// \brief Placement-new a map type using the provided memory pointer. +template +void emplaceMap(Eigen::Map &map, long rows, long cols, + Scalar *ptr) { + using MapType = Eigen::Map; + new (&map) MapType{ptr, rows, cols}; +} +} // namespace detail + +template +LQRKnotTpl::LQRKnotTpl(no_alloc_t, uint nx, uint nu, uint nc, uint nx2, + uint nth) + : nx(nx), nu(nu), nc(nc), nx2(nx2), nth(nth), // + Q(NULL, 0, 0), S(NULL, 0, 0), R(NULL, 0, 0), q(NULL, 0), r(NULL, 0), // + A(NULL, 0, 0), B(NULL, 0, 0), E(NULL, 0, 0), f(NULL, 0), // + C(NULL, 0, 0), D(NULL, 0, 0), d(NULL, 0), // + Gth(NULL, 0, 0), Gx(NULL, 0, 0), Gu(NULL, 0, 0), Gv(NULL, 0, 0), + gamma(NULL, 0), // + memory(NULL), req(Alignment) {} + template LQRKnotTpl::LQRKnotTpl(uint nx, uint nu, uint nc, uint nx2, uint nth) - : nx(nx), nu(nu), nc(nc), nx2(nx2), nth(nth), // - Q(nx, nx), S(nx, nu), R(nu, nu), q(nx), r(nu), // - A(nx2, nx), B(nx2, nu), E(nx2, nx), f(nx2), // - C(nc, nx), D(nc, nu), d(nc), Gth(nth, nth), Gx(nx, nth), Gu(nu, nth), - Gv(nc, nth), gamma(nth) { - Q.setZero(); - S.setZero(); - R.setZero(); - q.setZero(); - r.setZero(); - - A.setZero(); - B.setZero(); - E.setZero(); - f.setZero(); - - C.setZero(); - D.setZero(); - d.setZero(); - - Gth.setZero(); - Gx.setZero(); - Gu.setZero(); - Gv.setZero(); - gamma.setZero(); + : LQRKnotTpl(no_alloc, nx, nu, nc, nx2, nth) { + + this->allocate(); + this->initialize(); +} + +template void LQRKnotTpl::allocate() { + req.addArray(nx, nx) // Q + .addArray(nx, nu) // S + .addArray(nu, nu) // R + .addArray(nx) // q + .addArray(nu) // r + .addArray(nx2, nx) // A + .addArray(nx2, nu) // B + .addArray(nx2, nx2) // E + .addArray(nx2) // f + .addArray(nc, nx) // C + .addArray(nc, nu) // D + .addArray(nc) // d + .addArray(nth, nth) // Gth + .addArray(nx, nth) // Gx + .addArray(nu, nth) // Gu + .addArray(nc, nth) // Gv + .addArray(nth); // gamma + + this->memory = static_cast(req.allocate()); + std::memset(memory, 0, req.totalBytes()); +} + +template void LQRKnotTpl::initialize() { + Scalar *ptr = memory; + detail::emplaceMap(Q, nx, nx, ptr); + req.advance(ptr); + detail::emplaceMap(S, nx, nu, ptr); + req.advance(ptr); + detail::emplaceMap(R, nu, nu, ptr); + req.advance(ptr); + detail::emplaceMap(q, nx, ptr); + req.advance(ptr); + detail::emplaceMap(r, nu, ptr); + req.advance(ptr); + + detail::emplaceMap(A, nx2, nx, ptr); + req.advance(ptr); + detail::emplaceMap(B, nx2, nu, ptr); + req.advance(ptr); + detail::emplaceMap(E, nx2, nx2, ptr); + req.advance(ptr); + detail::emplaceMap(f, nx2, ptr); + req.advance(ptr); + + detail::emplaceMap(C, nc, nx, ptr); + req.advance(ptr); + detail::emplaceMap(D, nc, nu, ptr); + req.advance(ptr); + detail::emplaceMap(d, nc, ptr); + req.advance(ptr); + + detail::emplaceMap(Gth, nth, nth, ptr); + req.advance(ptr); + detail::emplaceMap(Gx, nx, nth, ptr); + req.advance(ptr); + detail::emplaceMap(Gu, nu, nth, ptr); + req.advance(ptr); + detail::emplaceMap(Gv, nc, nth, ptr); + req.advance(ptr); + detail::emplaceMap(gamma, nth, ptr); + req.advance(ptr); + + req.reset(); } template void LQRKnotTpl::addParameterization(uint nth) { - this->nth = nth; - Gth.setZero(nth, nth); - Gx.setZero(nx, nth); - Gu.setZero(nu, nth); - Gv.setZero(nc, nth); - gamma.setZero(nth); + LQRKnotTpl copy(nx, nu, nc, nx2, nth); + copy.Q = Q; + copy.S = S; + copy.R = R; + copy.q = q; + copy.r = r; + + copy.A = A; + copy.B = B; + copy.E = E; + copy.f = f; + + copy.C = C; + copy.D = D; + copy.d = d; + + *this = LQRKnotTpl{copy}; +} + +template +LQRKnotTpl::LQRKnotTpl(const LQRKnotTpl &other) + : LQRKnotTpl(no_alloc, other.nx, other.nu, other.nc, other.nx2, other.nth) { + this->allocate(); + assert(req.totalBytes() == other.req.totalBytes()); + // copy memory over from other + std::memcpy(memory, other.memory, other.req.totalBytes()); + this->initialize(); +} + +template +LQRKnotTpl::LQRKnotTpl(LQRKnotTpl &&other) + : LQRKnotTpl(no_alloc, other.nx, other.nu, other.nc, other.nx2, other.nth) { + // no need to allocate, just bring in the other + // memory buffer + memory = other.memory; + other.memory = NULL; + req = other.req; + this->initialize(); +} + +template +LQRKnotTpl &LQRKnotTpl::operator=(const LQRKnotTpl &other) { + this->~LQRKnotTpl(); + new (this) LQRKnotTpl{other}; + return *this; +} + +template +LQRKnotTpl &LQRKnotTpl::operator=(LQRKnotTpl &&other) { + swap(*this, other); + return *this; } template