Skip to content

Commit

Permalink
Fix, test additions.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Sep 8, 2024
1 parent f78b662 commit ee0b28e
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 1 deletion.
15 changes: 15 additions & 0 deletions include/heyoka/taylor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ class HEYOKA_DLL_PUBLIC_INLINE_CLASS taylor_adaptive : public detail::taylor_ada
}
};

// Deduction guides to enable CTAD when the initial state is passed via std::initializer_list.
template <typename T, typename... KwArgs>
requires(!igor::has_unnamed_arguments<KwArgs...>())
explicit taylor_adaptive(std::vector<std::pair<expression, expression>>, std::initializer_list<T>,
Expand Down Expand Up @@ -952,6 +953,13 @@ class HEYOKA_DLL_PUBLIC_INLINE_CLASS taylor_adaptive_batch
{
finalise_ctor(std::move(sys), std::move(state), batch_size, kw_args...);
}
template <typename... KwArgs>
requires(!igor::has_unnamed_arguments<KwArgs...>())
explicit taylor_adaptive_batch(std::vector<std::pair<expression, expression>> sys, std::uint32_t batch_size,
const KwArgs &...kw_args)
: taylor_adaptive_batch(std::move(sys), std::vector<T>{}, batch_size, kw_args...)
{
}
template <typename... KwArgs>
requires(!igor::has_unnamed_arguments<KwArgs...>())
explicit taylor_adaptive_batch(var_ode_sys sys, std::vector<T> state, std::uint32_t batch_size,
Expand All @@ -960,6 +968,12 @@ class HEYOKA_DLL_PUBLIC_INLINE_CLASS taylor_adaptive_batch
{
finalise_ctor(std::move(sys), std::move(state), batch_size, kw_args...);
}
template <typename... KwArgs>
requires(!igor::has_unnamed_arguments<KwArgs...>())
explicit taylor_adaptive_batch(var_ode_sys sys, std::uint32_t batch_size, const KwArgs &...kw_args)
: taylor_adaptive_batch(std::move(sys), std::vector<T>{}, batch_size, kw_args...)
{
}

taylor_adaptive_batch(const taylor_adaptive_batch &);
taylor_adaptive_batch(taylor_adaptive_batch &&) noexcept;
Expand Down Expand Up @@ -1123,6 +1137,7 @@ class HEYOKA_DLL_PUBLIC_INLINE_CLASS taylor_adaptive_batch
[[nodiscard]] const std::vector<std::tuple<taylor_outcome, T, T, std::size_t>> &get_propagate_res() const;
};

// Deduction guides to enable CTAD when the initial state is passed via std::initializer_list.
template <typename T, typename... KwArgs>
requires(!igor::has_unnamed_arguments<KwArgs...>())
explicit taylor_adaptive_batch(std::vector<std::pair<expression, expression>>, std::initializer_list<T>, std::uint32_t,
Expand Down
2 changes: 1 addition & 1 deletion src/taylor_adaptive_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ void taylor_adaptive_batch<T>::finalise_ctor_impl(sys_t vsys, std::vector<T> sta
if (state.empty()) {
// NOTE: we will perform further initialisation for the variational quantities
// at a later stage, if needed.
state.resize(boost::numeric_cast<decltype(state.size())>(n_orig_sv), static_cast<T>(0));
state.resize(boost::safe_numerics::safe<decltype(state.size())>(n_orig_sv) * m_batch_size, static_cast<T>(0));
}

// Assign the state.
Expand Down
16 changes: 16 additions & 0 deletions test/taylor_adaptive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include <heyoka/math/sum.hpp>
#include <heyoka/math/time.hpp>
#include <heyoka/model/nbody.hpp>
#include <heyoka/model/pendulum.hpp>
#include <heyoka/number.hpp>
#include <heyoka/s11n.hpp>
#include <heyoka/step_callback.hpp>
Expand Down Expand Up @@ -2525,3 +2526,18 @@ TEST_CASE("invalid initial state")
Message("Inconsistent sizes detected in the initialization of an adaptive Taylor "
"integrator: the state vector has a dimension of 1, while the number of equations is 2"));
}

TEST_CASE("empty init state")
{
const auto dyn = model::pendulum();

{
auto ta = taylor_adaptive<double>{dyn};
REQUIRE(ta.get_state() == std::vector{0., 0.});
}

{
auto ta = taylor_adaptive<double>{var_ode_sys(dyn, var_args::vars)};
REQUIRE(ta.get_state() == std::vector{0.0, 0.0, 1.0, 0.0, 0.0, 1.0});
}
}
16 changes: 16 additions & 0 deletions test/taylor_adaptive_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <heyoka/math/sin.hpp>
#include <heyoka/math/time.hpp>
#include <heyoka/model/nbody.hpp>
#include <heyoka/model/pendulum.hpp>
#include <heyoka/s11n.hpp>
#include <heyoka/step_callback.hpp>
#include <heyoka/taylor.hpp>
Expand Down Expand Up @@ -2217,3 +2218,18 @@ TEST_CASE("invalid initial state")
"integrator: the state vector has a dimension of 1 and a batch size of 2, "
"while the number of equations is 2"));
}

TEST_CASE("empty init state")
{
const auto dyn = model::pendulum();

{
auto ta = taylor_adaptive_batch<double>{dyn, 2u};
REQUIRE(ta.get_state() == std::vector{0., 0., 0., 0.});
}

{
auto ta = taylor_adaptive_batch<double>{var_ode_sys(dyn, var_args::vars), 2u};
REQUIRE(ta.get_state() == std::vector{0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0});
}
}
20 changes: 20 additions & 0 deletions test/taylor_adaptive_mp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <heyoka/llvm_state.hpp>
#include <heyoka/math/sin.hpp>
#include <heyoka/math/time.hpp>
#include <heyoka/model/pendulum.hpp>
#include <heyoka/s11n.hpp>
#include <heyoka/taylor.hpp>

Expand Down Expand Up @@ -1158,3 +1159,22 @@ TEST_CASE("s11n")
}
}
}

TEST_CASE("empty init state")
{
const auto dyn = model::pendulum();

const auto prec = 23;

{
auto ta = taylor_adaptive<mppp::real>{dyn, kw::prec = prec};
REQUIRE(ta.get_state() == std::vector<mppp::real>{0., 0.});
REQUIRE(std::ranges::all_of(ta.get_state(), [&](const auto &val) { return val.get_prec() == prec; }));
}

{
auto ta = taylor_adaptive<mppp::real>{var_ode_sys(dyn, var_args::vars), kw::prec = prec};
REQUIRE(ta.get_state() == std::vector<mppp::real>{0.0, 0.0, 1.0, 0.0, 0.0, 1.0});
REQUIRE(std::ranges::all_of(ta.get_state(), [&](const auto &val) { return val.get_prec() == prec; }));
}
}

0 comments on commit ee0b28e

Please sign in to comment.