Skip to content

Commit

Permalink
test_util.cpp: various fixes
Browse files Browse the repository at this point in the history
+ add fmt::formatter specialization

Update copyright headings
  • Loading branch information
ManifoldFR committed Oct 23, 2024
1 parent 17fedc7 commit fb3cb09
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 28 deletions.
10 changes: 5 additions & 5 deletions tests/gar/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ BOOST_AUTO_TEST_CASE(parallel_manual) {
bool ret = solver_full_horz.forward(xs, us, vs, lbdas);
BOOST_CHECK(ret);
KktError err_full = computeKktError(problem, xs, us, vs, lbdas);
printKktError(err_full, "KKT error (full horz.)");
fmt::println("KKT error (full horz.): {}", err_full);
BOOST_CHECK_LE(err_full.max, EPS);
}

Expand Down Expand Up @@ -170,7 +170,7 @@ BOOST_AUTO_TEST_CASE(parallel_manual) {

KktError err_merged =
computeKktError(problem, xs_merged, us_merged, vs_merged, lbdas_merged);
printKktError(err_merged, "KKT error (merged)");
fmt::println("KKT error (merged) {}", err_merged);
}

/// Randomize some of the parameters of the problem. This simulates something
Expand Down Expand Up @@ -214,7 +214,7 @@ BOOST_AUTO_TEST_CASE(parallel_solver_class) {
refSolver.forward(xs_ref, us_ref, vs_ref, lbdas_ref);
KktError err_ref =
computeKktError(problemRef, xs_ref, us_ref, vs_ref, lbdas_ref, mu, mu);
printKktError(err_ref);
fmt::println("{}", err_ref);
BOOST_CHECK_LE(err_ref.max, tol);
}

Expand All @@ -224,7 +224,7 @@ BOOST_AUTO_TEST_CASE(parallel_solver_class) {
parSolver.backward(mu, mu);
parSolver.forward(xs, us, vs, lbdas);
KktError err = computeKktError(problem, xs, us, vs, lbdas, mu, mu);
printKktError(err);
fmt::println("{}", err);
BOOST_CHECK_LE(err.max, tol);

// TODO: properly test feedback/feedforward gains
Expand Down Expand Up @@ -268,7 +268,7 @@ BOOST_AUTO_TEST_CASE(parallel_solver_class) {
parSolver.backward(mu, mu);
parSolver.forward(xs, us, vs, lbdas);
KktError e = computeKktError(problem, xs, us, vs, lbdas, mu, mu);
printKktError(e);
fmt::println("{}", e);
BOOST_CHECK_LE(e.max, tol);
}
}
8 changes: 4 additions & 4 deletions tests/gar/riccati.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ BOOST_AUTO_TEST_CASE(short_horz_pb) {
// check error
KktError err = computeKktError(prob, xs, us, vs, lbdas);

printKktError(err);
fmt::println("{}", err);

BOOST_CHECK_LE(err.max, 1e-9);

Expand Down Expand Up @@ -134,7 +134,7 @@ BOOST_AUTO_TEST_CASE(random_long_problem) {
fmt::print("Elapsed time (fwd): {:d}\n", t_fwd.count());

KktError err = computeKktError(prob, xs, us, vs, lbdas);
printKktError(err);
fmt::println("{}", err);

BOOST_CHECK_LE(err.max, 1e-9);

Expand All @@ -150,7 +150,7 @@ BOOST_AUTO_TEST_CASE(random_long_problem) {
auto [xsd, usd, vsd, lbdasd] = lqrInitializeSolution(prob);
denseSolver.forward(xsd, usd, vsd, lbdasd);
KktError err = computeKktError(prob, xsd, usd, vsd, lbdasd);
printKktError(err);
fmt::println("{}", err);
BOOST_CHECK_LE(err.max, 1e-9);
}
}
Expand All @@ -176,7 +176,7 @@ BOOST_AUTO_TEST_CASE(parametric) {
fmt::print("e = {}\n", e.transpose());

KktError err = computeKktError(problem, xs, us, vs, lbdas, theta);
printKktError(err);
fmt::println("{}", err);
BOOST_CHECK_LE(err.max, 1e-10);
};

Expand Down
33 changes: 21 additions & 12 deletions tests/gar/test_util.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/// @copyright Copyright (C) 2023 LAAS-CNRS, INRIA
/// @copyright Copyright (C) 2023-2024 LAAS-CNRS, INRIA
#include "./test_util.hpp"
#include "aligator/gar/utils.hpp"

Expand Down Expand Up @@ -26,18 +26,18 @@ knot_t generate_knot(uint nx, uint nu, uint nth, bool singular) {
out.q = VectorXs::NullaryExpr(nx, normal_unary_op{});
out.r = VectorXs::NullaryExpr(nu, normal_unary_op{});

// out.A = MatrixXs::NullaryExpr(nx, nx, normal_unary_op{});
// out.B.setRandom();
// out.E = out.E.NullaryExpr(nx, nx, normal_unary_op{});
// out.E *= 1000;
// out.f = VectorXs::NullaryExpr(nx, normal_unary_op{});
out.A = MatrixXs::NullaryExpr(nx, nx, normal_unary_op{});
out.B.setRandom();
out.E = out.E.NullaryExpr(nx, nx, normal_unary_op{});
out.E *= 1000;
out.f = VectorXs::NullaryExpr(nx, normal_unary_op{});

// if (nth > 0) {
// out.Gx = MatrixXs::NullaryExpr(nx, nth, normal_unary_op{});
// out.Gu = MatrixXs::NullaryExpr(nu, nth, normal_unary_op{});
// out.Gth = sampleWishartDistributedMatrix(nth, nth + 2);
// out.gamma = VectorXs::NullaryExpr(nth, normal_unary_op{});
// }
if (nth > 0) {
out.Gx = MatrixXs::NullaryExpr(nx, nth, normal_unary_op{});
out.Gu = MatrixXs::NullaryExpr(nu, nth, normal_unary_op{});
out.Gth = sampleWishartDistributedMatrix(nth, nth + 2);
out.gamma = VectorXs::NullaryExpr(nth, normal_unary_op{});
}

return out;
}
Expand All @@ -61,6 +61,15 @@ problem_t generate_problem(const ConstVectorRef &x0, uint horz, uint nx,
return prob;
}

auto fmt::formatter<KktError>::format(const KktError &err,
format_context &ctx) const
-> format_context::iterator {
std::string s = fmt::format(
"{{ max: {:.3e}, dual: {:.3e}, cstr: {:.3e}, dyn: {:.3e} }}\n", err.max,
err.dual, err.cstr, err.dyn);
return formatter<std::string>::format(s, ctx);
}

KktError computeKktError(const problem_t &problem, const VectorOfVectors &xs,
const VectorOfVectors &us, const VectorOfVectors &vs,
const VectorOfVectors &lbdas,
Expand Down
12 changes: 5 additions & 7 deletions tests/gar/test_util.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/// @copyright Copyright (C) 2023 LAAS-CNRS, INRIA
/// @copyright Copyright (C) 2023-2024 LAAS-CNRS, INRIA
#pragma once

#include "aligator/gar/lqr-problem.hpp"
Expand All @@ -18,12 +18,10 @@ struct KktError {
double max = std::max({dyn, cstr, dual});
};

inline void printKktError(const KktError &err,
const std::string &msg = "Max KKT error") {
fmt::print("{}: {:.3e}\n", msg, err.max);
fmt::print("> dual: {:.3e}, cstr: {:.3e}, dyn: {:.3e}\n", err.dual, err.cstr,
err.dyn);
}
template <> struct fmt::formatter<KktError> : formatter<std::string> {
auto format(const KktError &err, format_context &ctx) const
-> format_context::iterator;
};

KktError
computeKktError(const problem_t &problem, const VectorOfVectors &xs,
Expand Down

0 comments on commit fb3cb09

Please sign in to comment.