diff --git a/include/heyoka/math/relational.hpp b/include/heyoka/math/relational.hpp index eafb628bd..088d1388d 100644 --- a/include/heyoka/math/relational.hpp +++ b/include/heyoka/math/relational.hpp @@ -9,11 +9,24 @@ #ifndef HEYOKA_MATH_RELATIONAL_HPP #define HEYOKA_MATH_RELATIONAL_HPP +#include + #include #include #include -#include +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + #include #include #include @@ -71,6 +84,38 @@ HEYOKA_DLL_PUBLIC expression gt(expression, expression); HEYOKA_DLL_PUBLIC expression lte(expression, expression); HEYOKA_DLL_PUBLIC expression gte(expression, expression); +#define HEYOKA_DECLARE_REL_OVERLOADS(type) \ + HEYOKA_DLL_PUBLIC expression eq(expression, type); \ + HEYOKA_DLL_PUBLIC expression eq(type, expression); \ + HEYOKA_DLL_PUBLIC expression neq(expression, type); \ + HEYOKA_DLL_PUBLIC expression neq(type, expression); \ + HEYOKA_DLL_PUBLIC expression lt(expression, type); \ + HEYOKA_DLL_PUBLIC expression lt(type, expression); \ + HEYOKA_DLL_PUBLIC expression gt(expression, type); \ + HEYOKA_DLL_PUBLIC expression gt(type, expression); \ + HEYOKA_DLL_PUBLIC expression lte(expression, type); \ + HEYOKA_DLL_PUBLIC expression lte(type, expression); \ + HEYOKA_DLL_PUBLIC expression gte(expression, type); \ + HEYOKA_DLL_PUBLIC expression gte(type, expression); + +HEYOKA_DECLARE_REL_OVERLOADS(float); +HEYOKA_DECLARE_REL_OVERLOADS(double); +HEYOKA_DECLARE_REL_OVERLOADS(long double); + +#if defined(HEYOKA_HAVE_REAL128) + +HEYOKA_DECLARE_REL_OVERLOADS(mppp::real128); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +HEYOKA_DECLARE_REL_OVERLOADS(mppp::real); + +#endif + +#undef HEYOKA_DECLARE_REL_OVERLOADS + HEYOKA_END_NAMESPACE HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::rel_impl) diff --git a/include/heyoka/math/select.hpp b/include/heyoka/math/select.hpp index 7f6b10eea..079378865 100644 --- a/include/heyoka/math/select.hpp +++ b/include/heyoka/math/select.hpp @@ -9,10 +9,23 @@ #ifndef HEYOKA_MATH_SELECT_HPP #define HEYOKA_MATH_SELECT_HPP +#include + #include #include -#include +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + #include #include #include @@ -56,6 +69,32 @@ class HEYOKA_DLL_PUBLIC select_impl : public func_base HEYOKA_DLL_PUBLIC expression select(expression, expression, expression); +#define HEYOKA_DECLARE_SELECT_OVERLOADS(type) \ + HEYOKA_DLL_PUBLIC expression select(expression, type, type); \ + HEYOKA_DLL_PUBLIC expression select(type, expression, type); \ + HEYOKA_DLL_PUBLIC expression select(type, type, expression); \ + HEYOKA_DLL_PUBLIC expression select(expression, expression, type); \ + HEYOKA_DLL_PUBLIC expression select(expression, type, expression); \ + HEYOKA_DLL_PUBLIC expression select(type, expression, expression) + +HEYOKA_DECLARE_SELECT_OVERLOADS(float); +HEYOKA_DECLARE_SELECT_OVERLOADS(double); +HEYOKA_DECLARE_SELECT_OVERLOADS(long double); + +#if defined(HEYOKA_HAVE_REAL128) + +HEYOKA_DECLARE_SELECT_OVERLOADS(mppp::real128); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +HEYOKA_DECLARE_SELECT_OVERLOADS(mppp::real); + +#endif + +#undef HEYOKA_DECLARE_SELECT_OVERLOADS + HEYOKA_END_NAMESPACE HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::select_impl) diff --git a/src/math/relational.cpp b/src/math/relational.cpp index a1c2486e9..ed50d5bf8 100644 --- a/src/math/relational.cpp +++ b/src/math/relational.cpp @@ -6,6 +6,8 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include + #include #include #include @@ -26,7 +28,18 @@ #include #include -#include +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + #include #include #include @@ -355,6 +368,76 @@ HEYOKA_MATH_REL_IMPL(gte) #undef HEYOKA_MATH_REL_IMPL +// NOTE: this macro was copy-pasted from the kepE() overloads, hence +// the weird (but inconsequential) naming of the function arguments. +#define HEYOKA_DEFINE_REL_OVERLOADS(type) \ + expression eq(expression e, type M) \ + { \ + return eq(std::move(e), expression{std::move(M)}); \ + } \ + expression eq(type e, expression M) \ + { \ + return eq(expression{std::move(e)}, std::move(M)); \ + } \ + expression neq(expression e, type M) \ + { \ + return neq(std::move(e), expression{std::move(M)}); \ + } \ + expression neq(type e, expression M) \ + { \ + return neq(expression{std::move(e)}, std::move(M)); \ + } \ + expression lt(expression e, type M) \ + { \ + return lt(std::move(e), expression{std::move(M)}); \ + } \ + expression lt(type e, expression M) \ + { \ + return lt(expression{std::move(e)}, std::move(M)); \ + } \ + expression gt(expression e, type M) \ + { \ + return gt(std::move(e), expression{std::move(M)}); \ + } \ + expression gt(type e, expression M) \ + { \ + return gt(expression{std::move(e)}, std::move(M)); \ + } \ + expression lte(expression e, type M) \ + { \ + return lte(std::move(e), expression{std::move(M)}); \ + } \ + expression lte(type e, expression M) \ + { \ + return lte(expression{std::move(e)}, std::move(M)); \ + } \ + expression gte(expression e, type M) \ + { \ + return gte(std::move(e), expression{std::move(M)}); \ + } \ + expression gte(type e, expression M) \ + { \ + return gte(expression{std::move(e)}, std::move(M)); \ + } + +HEYOKA_DEFINE_REL_OVERLOADS(float) +HEYOKA_DEFINE_REL_OVERLOADS(double) +HEYOKA_DEFINE_REL_OVERLOADS(long double) + +#if defined(HEYOKA_HAVE_REAL128) + +HEYOKA_DEFINE_REL_OVERLOADS(mppp::real128); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +HEYOKA_DEFINE_REL_OVERLOADS(mppp::real); + +#endif + +#undef HEYOKA_DEFINE_REL_OVERLOADS + HEYOKA_END_NAMESPACE // NOLINTNEXTLINE(cert-err58-cpp) diff --git a/src/math/select.cpp b/src/math/select.cpp index 338a2d580..10c487a01 100644 --- a/src/math/select.cpp +++ b/src/math/select.cpp @@ -6,6 +6,8 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include + #include #include #include @@ -23,7 +25,18 @@ #include #include -#include +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + #include #include #include @@ -268,6 +281,52 @@ expression select(expression cond, expression t, expression f) return expression{func{detail::select_impl{std::move(cond), std::move(t), std::move(f)}}}; } +// NOTE: this macro was copy-pasted from the kepDE() overloads, hence +// the weird (but inconsequential) naming of the function arguments. +#define HEYOKA_DEFINE_SELECT_OVERLOADS(type) \ + expression select(expression s0, type c0, type DM) \ + { \ + return select(std::move(s0), expression{std::move(c0)}, expression{std::move(DM)}); \ + } \ + expression select(type s0, expression c0, type DM) \ + { \ + return select(expression{std::move(s0)}, std::move(c0), expression{std::move(DM)}); \ + } \ + expression select(type s0, type c0, expression DM) \ + { \ + return select(expression{std::move(s0)}, expression{std::move(c0)}, std::move(DM)); \ + } \ + expression select(expression s0, expression c0, type DM) \ + { \ + return select(std::move(s0), std::move(c0), expression{std::move(DM)}); \ + } \ + expression select(expression s0, type c0, expression DM) \ + { \ + return select(std::move(s0), expression{std::move(c0)}, std::move(DM)); \ + } \ + expression select(type s0, expression c0, expression DM) \ + { \ + return select(expression{std::move(s0)}, std::move(c0), std::move(DM)); \ + } + +HEYOKA_DEFINE_SELECT_OVERLOADS(float) +HEYOKA_DEFINE_SELECT_OVERLOADS(double) +HEYOKA_DEFINE_SELECT_OVERLOADS(long double) + +#if defined(HEYOKA_HAVE_REAL128) + +HEYOKA_DEFINE_SELECT_OVERLOADS(mppp::real128); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +HEYOKA_DEFINE_SELECT_OVERLOADS(mppp::real); + +#endif + +#undef HEYOKA_DEFINE_SELECT_OVERLOADS + HEYOKA_END_NAMESPACE // NOLINTNEXTLINE(cert-err58-cpp) diff --git a/test/rel.cpp b/test/rel.cpp index b6b43ab17..3b9123113 100644 --- a/test/rel.cpp +++ b/test/rel.cpp @@ -80,6 +80,22 @@ TEST_CASE("basic") REQUIRE(eq(x, y) != neq(x, y)); REQUIRE(lte(x, y) != gte(x, y)); REQUIRE(lte(x, y) == lte(x, y)); + + // Test a couple of numerical overloads too. + REQUIRE(eq(x, 1.) == eq(x, 1_dbl)); + REQUIRE(lte(1.l, par[0]) == lte(1_ldbl, par[0])); + +#if defined(HEYOKA_HAVE_REAL128) + + REQUIRE(lte(mppp::real128{"1.1"}, par[0]) == lte(1.1_f128, par[0])); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + + REQUIRE(lte(mppp::real{"1.1", 14}, par[0]) == lte(expression{mppp::real{"1.1", 14}}, par[0])); + +#endif } TEST_CASE("stream") diff --git a/test/select.cpp b/test/select.cpp index 9390bd6b7..df2afa337 100644 --- a/test/select.cpp +++ b/test/select.cpp @@ -76,6 +76,23 @@ TEST_CASE("basic") auto [x, y, z] = make_vars("x", "y", "z"); REQUIRE(expression{func{detail::select_impl{}}} == expression{func{detail::select_impl{0_dbl, 0_dbl, 0_dbl}}}); + + // A couple of tests for the numeric overloads. + REQUIRE(select(1., x, 2.) == select(1_dbl, x, 2_dbl)); + REQUIRE(select(x, par[0], 2.f) == select(x, par[0], 2_flt)); + +#if defined(HEYOKA_HAVE_REAL128) + + REQUIRE(select(mppp::real128{"3.1"}, x, mppp::real128{"2.1"}) == select(3.1_f128, x, 2.1_f128)); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + + REQUIRE(select(mppp::real{"3.1", 14}, x, mppp::real{"2.1", 14}) + == select(expression{mppp::real{"3.1", 14}}, x, expression{mppp::real{"2.1", 14}})); + +#endif } TEST_CASE("stream")