Skip to content

Commit

Permalink
proxddp : make tryLinearStep un-static
Browse files Browse the repository at this point in the history
remove LQProblem typedef
  • Loading branch information
ManifoldFR committed Oct 18, 2024
1 parent ebfcf4c commit f93a4e6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 20 deletions.
4 changes: 1 addition & 3 deletions include/aligator/solvers/proxddp/solver-proxddp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ template <typename _Scalar> struct SolverProxDDPTpl {
using TrajOptData = TrajOptDataTpl<Scalar>;
using LinesearchOptions = typename Linesearch<Scalar>::Options;
using LinesearchType = proxsuite::nlp::ArmijoLinesearch<Scalar>;
using LQProblem = gar::LQRProblemTpl<Scalar>;
using Filter = FilterTpl<Scalar>;

struct AlmParams {
Expand Down Expand Up @@ -195,8 +194,7 @@ template <typename _Scalar> struct SolverProxDDPTpl {
/// \f$(\bfx \oplus\alpha\delta\bfx, \bfu+\alpha\delta\bfu,
/// \bmlam+\alpha\delta\bmlam)\f$
/// @returns The trajectory cost.
static Scalar tryLinearStep(const Problem &problem, Workspace &workspace,
const Results &results, const Scalar alpha);
Scalar tryLinearStep(const Problem &problem, const Scalar alpha);

/// @brief Policy rollout using the full nonlinear dynamics. The feedback
/// gains need to be computed first. This will evaluate all the terms in the
Expand Down
32 changes: 15 additions & 17 deletions include/aligator/solvers/proxddp/solver-proxddp.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -83,35 +83,34 @@ SolverProxDDPTpl<Scalar>::SolverProxDDPTpl(const Scalar tol,
// C. Forward pass
template <typename Scalar>
Scalar SolverProxDDPTpl<Scalar>::tryLinearStep(const Problem &problem,
Workspace &workspace,
const Results &results,
const Scalar alpha) {
ALIGATOR_TRACY_ZONE_SCOPED;

const std::size_t nsteps = workspace.nsteps;
const std::size_t nsteps = workspace_.nsteps;
assert(results.xs.size() == nsteps + 1);
assert(results.us.size() == nsteps);
assert(results.lams.size() == nsteps + 1);
assert(results.vs.size() == nsteps + 1);

math::vectorMultiplyAdd(results.lams, workspace.dlams, workspace.trial_lams,
math::vectorMultiplyAdd(results_.lams, workspace_.dlams,
workspace_.trial_lams, alpha);
math::vectorMultiplyAdd(results_.vs, workspace_.dvs, workspace_.trial_vs,
alpha);
math::vectorMultiplyAdd(results.vs, workspace.dvs, workspace.trial_vs, alpha);

for (std::size_t i = 0; i < nsteps; i++) {
const StageModel &stage = *problem.stages_[i];
stage.xspace_->integrate(results.xs[i], alpha * workspace.dxs[i],
workspace.trial_xs[i]);
stage.uspace_->integrate(results.us[i], alpha * workspace.dus[i],
workspace.trial_us[i]);
stage.xspace_->integrate(results_.xs[i], alpha * workspace_.dxs[i],
workspace_.trial_xs[i]);
stage.uspace_->integrate(results_.us[i], alpha * workspace_.dus[i],
workspace_.trial_us[i]);
}
const StageModel &stage = *problem.stages_[nsteps - 1];
stage.xspace_next_->integrate(results.xs[nsteps],
alpha * workspace.dxs[nsteps],
workspace.trial_xs[nsteps]);
TrajOptData &prob_data = workspace.problem_data;
stage.xspace_next_->integrate(results_.xs[nsteps],
alpha * workspace_.dxs[nsteps],
workspace_.trial_xs[nsteps]);
TrajOptData &prob_data = workspace_.problem_data;
prob_data.cost_ =
problem.evaluate(workspace.trial_xs, workspace.trial_us, prob_data);
problem.evaluate(workspace_.trial_xs, workspace_.trial_us, prob_data);
return prob_data.cost_;
}

Expand Down Expand Up @@ -163,7 +162,6 @@ void SolverProxDDPTpl<Scalar>::cycleProblem(
linearSolver_->cycleAppend(workspace_.knots[workspace_.nsteps - 1]);
}

/// TODO: REWORK FOR NEW MULTIPLIERS
template <typename Scalar>
void SolverProxDDPTpl<Scalar>::computeMultipliers(
const Problem &problem, const std::vector<VectorXs> &lams,
Expand Down Expand Up @@ -536,7 +534,7 @@ Scalar SolverProxDDPTpl<Scalar>::forwardPass(const Problem &problem,
ALIGATOR_TRACY_ZONE_SCOPED;
switch (rollout_type_) {
case RolloutType::LINEAR:
tryLinearStep(problem, workspace_, results_, alpha);
tryLinearStep(problem, alpha);
break;
case RolloutType::NONLINEAR:
tryNonlinearRollout(problem, alpha);
Expand Down Expand Up @@ -753,7 +751,7 @@ void SolverProxDDPTpl<Scalar>::registerCallback(const std::string &name,
template <typename Scalar> void SolverProxDDPTpl<Scalar>::updateLQSubproblem() {
ALIGATOR_NOMALLOC_SCOPED;
ALIGATOR_TRACY_ZONE_SCOPED;
LQProblem &prob = workspace_.lqr_problem;
gar::LQRProblemTpl<Scalar> &prob = workspace_.lqr_problem;
const TrajOptData &pd = workspace_.problem_data;

using gar::LQRKnotTpl;
Expand Down

0 comments on commit f93a4e6

Please sign in to comment.