From fb3cb0953cfeebff28c7872eece11c3d6ce48a4d Mon Sep 17 00:00:00 2001 From: ManifoldFR Date: Thu, 17 Oct 2024 11:14:59 +0200 Subject: [PATCH] test_util.cpp: various fixes + add fmt::formatter specialization Update copyright headings --- tests/gar/parallel.cpp | 10 +++++----- tests/gar/riccati.cpp | 8 ++++---- tests/gar/test_util.cpp | 33 +++++++++++++++++++++------------ tests/gar/test_util.hpp | 12 +++++------- 4 files changed, 35 insertions(+), 28 deletions(-) diff --git a/tests/gar/parallel.cpp b/tests/gar/parallel.cpp index 74d8ba88f..add388643 100644 --- a/tests/gar/parallel.cpp +++ b/tests/gar/parallel.cpp @@ -80,7 +80,7 @@ BOOST_AUTO_TEST_CASE(parallel_manual) { bool ret = solver_full_horz.forward(xs, us, vs, lbdas); BOOST_CHECK(ret); KktError err_full = computeKktError(problem, xs, us, vs, lbdas); - printKktError(err_full, "KKT error (full horz.)"); + fmt::println("KKT error (full horz.): {}", err_full); BOOST_CHECK_LE(err_full.max, EPS); } @@ -170,7 +170,7 @@ BOOST_AUTO_TEST_CASE(parallel_manual) { KktError err_merged = computeKktError(problem, xs_merged, us_merged, vs_merged, lbdas_merged); - printKktError(err_merged, "KKT error (merged)"); + fmt::println("KKT error (merged) {}", err_merged); } /// Randomize some of the parameters of the problem. This simulates something @@ -214,7 +214,7 @@ BOOST_AUTO_TEST_CASE(parallel_solver_class) { refSolver.forward(xs_ref, us_ref, vs_ref, lbdas_ref); KktError err_ref = computeKktError(problemRef, xs_ref, us_ref, vs_ref, lbdas_ref, mu, mu); - printKktError(err_ref); + fmt::println("{}", err_ref); BOOST_CHECK_LE(err_ref.max, tol); } @@ -224,7 +224,7 @@ BOOST_AUTO_TEST_CASE(parallel_solver_class) { parSolver.backward(mu, mu); parSolver.forward(xs, us, vs, lbdas); KktError err = computeKktError(problem, xs, us, vs, lbdas, mu, mu); - printKktError(err); + fmt::println("{}", err); BOOST_CHECK_LE(err.max, tol); // TODO: properly test feedback/feedforward gains @@ -268,7 +268,7 @@ BOOST_AUTO_TEST_CASE(parallel_solver_class) { parSolver.backward(mu, mu); parSolver.forward(xs, us, vs, lbdas); KktError e = computeKktError(problem, xs, us, vs, lbdas, mu, mu); - printKktError(e); + fmt::println("{}", e); BOOST_CHECK_LE(e.max, tol); } } diff --git a/tests/gar/riccati.cpp b/tests/gar/riccati.cpp index 62a89600c..8d516ba04 100644 --- a/tests/gar/riccati.cpp +++ b/tests/gar/riccati.cpp @@ -79,7 +79,7 @@ BOOST_AUTO_TEST_CASE(short_horz_pb) { // check error KktError err = computeKktError(prob, xs, us, vs, lbdas); - printKktError(err); + fmt::println("{}", err); BOOST_CHECK_LE(err.max, 1e-9); @@ -134,7 +134,7 @@ BOOST_AUTO_TEST_CASE(random_long_problem) { fmt::print("Elapsed time (fwd): {:d}\n", t_fwd.count()); KktError err = computeKktError(prob, xs, us, vs, lbdas); - printKktError(err); + fmt::println("{}", err); BOOST_CHECK_LE(err.max, 1e-9); @@ -150,7 +150,7 @@ BOOST_AUTO_TEST_CASE(random_long_problem) { auto [xsd, usd, vsd, lbdasd] = lqrInitializeSolution(prob); denseSolver.forward(xsd, usd, vsd, lbdasd); KktError err = computeKktError(prob, xsd, usd, vsd, lbdasd); - printKktError(err); + fmt::println("{}", err); BOOST_CHECK_LE(err.max, 1e-9); } } @@ -176,7 +176,7 @@ BOOST_AUTO_TEST_CASE(parametric) { fmt::print("e = {}\n", e.transpose()); KktError err = computeKktError(problem, xs, us, vs, lbdas, theta); - printKktError(err); + fmt::println("{}", err); BOOST_CHECK_LE(err.max, 1e-10); }; diff --git a/tests/gar/test_util.cpp b/tests/gar/test_util.cpp index 39c9a0113..c2610be0f 100644 --- a/tests/gar/test_util.cpp +++ b/tests/gar/test_util.cpp @@ -1,4 +1,4 @@ -/// @copyright Copyright (C) 2023 LAAS-CNRS, INRIA +/// @copyright Copyright (C) 2023-2024 LAAS-CNRS, INRIA #include "./test_util.hpp" #include "aligator/gar/utils.hpp" @@ -26,18 +26,18 @@ knot_t generate_knot(uint nx, uint nu, uint nth, bool singular) { out.q = VectorXs::NullaryExpr(nx, normal_unary_op{}); out.r = VectorXs::NullaryExpr(nu, normal_unary_op{}); - // out.A = MatrixXs::NullaryExpr(nx, nx, normal_unary_op{}); - // out.B.setRandom(); - // out.E = out.E.NullaryExpr(nx, nx, normal_unary_op{}); - // out.E *= 1000; - // out.f = VectorXs::NullaryExpr(nx, normal_unary_op{}); + out.A = MatrixXs::NullaryExpr(nx, nx, normal_unary_op{}); + out.B.setRandom(); + out.E = out.E.NullaryExpr(nx, nx, normal_unary_op{}); + out.E *= 1000; + out.f = VectorXs::NullaryExpr(nx, normal_unary_op{}); - // if (nth > 0) { - // out.Gx = MatrixXs::NullaryExpr(nx, nth, normal_unary_op{}); - // out.Gu = MatrixXs::NullaryExpr(nu, nth, normal_unary_op{}); - // out.Gth = sampleWishartDistributedMatrix(nth, nth + 2); - // out.gamma = VectorXs::NullaryExpr(nth, normal_unary_op{}); - // } + if (nth > 0) { + out.Gx = MatrixXs::NullaryExpr(nx, nth, normal_unary_op{}); + out.Gu = MatrixXs::NullaryExpr(nu, nth, normal_unary_op{}); + out.Gth = sampleWishartDistributedMatrix(nth, nth + 2); + out.gamma = VectorXs::NullaryExpr(nth, normal_unary_op{}); + } return out; } @@ -61,6 +61,15 @@ problem_t generate_problem(const ConstVectorRef &x0, uint horz, uint nx, return prob; } +auto fmt::formatter::format(const KktError &err, + format_context &ctx) const + -> format_context::iterator { + std::string s = fmt::format( + "{{ max: {:.3e}, dual: {:.3e}, cstr: {:.3e}, dyn: {:.3e} }}\n", err.max, + err.dual, err.cstr, err.dyn); + return formatter::format(s, ctx); +} + KktError computeKktError(const problem_t &problem, const VectorOfVectors &xs, const VectorOfVectors &us, const VectorOfVectors &vs, const VectorOfVectors &lbdas, diff --git a/tests/gar/test_util.hpp b/tests/gar/test_util.hpp index c24c74df0..d28395f9a 100644 --- a/tests/gar/test_util.hpp +++ b/tests/gar/test_util.hpp @@ -1,4 +1,4 @@ -/// @copyright Copyright (C) 2023 LAAS-CNRS, INRIA +/// @copyright Copyright (C) 2023-2024 LAAS-CNRS, INRIA #pragma once #include "aligator/gar/lqr-problem.hpp" @@ -18,12 +18,10 @@ struct KktError { double max = std::max({dyn, cstr, dual}); }; -inline void printKktError(const KktError &err, - const std::string &msg = "Max KKT error") { - fmt::print("{}: {:.3e}\n", msg, err.max); - fmt::print("> dual: {:.3e}, cstr: {:.3e}, dyn: {:.3e}\n", err.dual, err.cstr, - err.dyn); -} +template <> struct fmt::formatter : formatter { + auto format(const KktError &err, format_context &ctx) const + -> format_context::iterator; +}; KktError computeKktError(const problem_t &problem, const VectorOfVectors &xs,