Skip to content

Commit

Permalink
Implement and test the overloads.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Oct 14, 2023
1 parent d6dccdd commit 3bfa679
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 0 deletions.
25 changes: 25 additions & 0 deletions include/heyoka/math/kepF.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,31 @@ class HEYOKA_DLL_PUBLIC kepF_impl : public func_base

HEYOKA_DLL_PUBLIC expression kepF(expression, expression, expression);

#define HEYOKA_DECLARE_KEPF_OVERLOADS(type) \
HEYOKA_DLL_PUBLIC expression kepF(expression, type, type); \
HEYOKA_DLL_PUBLIC expression kepF(type, expression, type); \
HEYOKA_DLL_PUBLIC expression kepF(type, type, expression); \
HEYOKA_DLL_PUBLIC expression kepF(expression, expression, type); \
HEYOKA_DLL_PUBLIC expression kepF(expression, type, expression); \
HEYOKA_DLL_PUBLIC expression kepF(type, expression, expression)

HEYOKA_DECLARE_KEPF_OVERLOADS(double);
HEYOKA_DECLARE_KEPF_OVERLOADS(long double);

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_DECLARE_KEPF_OVERLOADS(mppp::real128);

#endif

#if defined(HEYOKA_HAVE_REAL)

HEYOKA_DECLARE_KEPF_OVERLOADS(mppp::real);

#endif

#undef HEYOKA_DECLARE_KEPF_OVERLOADS

HEYOKA_END_NAMESPACE

HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::kepF_impl)
Expand Down
43 changes: 43 additions & 0 deletions src/math/kepF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,49 @@ expression kepF(expression h, expression k, expression lam)
return expression{func{detail::kepF_impl{std::move(h), std::move(k), std::move(lam)}}};
}

#define HEYOKA_DEFINE_KEPF_OVERLOADS(type) \
expression kepF(expression h, type k, type lam) \
{ \
return kepF(std::move(h), expression{std::move(k)}, expression{std::move(lam)}); \
} \
expression kepF(type h, expression k, type lam) \
{ \
return kepF(expression{std::move(h)}, std::move(k), expression{std::move(lam)}); \
} \
expression kepF(type h, type k, expression lam) \
{ \
return kepF(expression{std::move(h)}, expression{std::move(k)}, std::move(lam)); \
} \
expression kepF(expression h, expression k, type lam) \
{ \
return kepF(std::move(h), std::move(k), expression{std::move(lam)}); \
} \
expression kepF(expression h, type k, expression lam) \
{ \
return kepF(std::move(h), expression{std::move(k)}, std::move(lam)); \
} \
expression kepF(type h, expression k, expression lam) \
{ \
return kepF(expression{std::move(h)}, std::move(k), std::move(lam)); \
}

HEYOKA_DEFINE_KEPF_OVERLOADS(double)
HEYOKA_DEFINE_KEPF_OVERLOADS(long double)

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_DEFINE_KEPF_OVERLOADS(mppp::real128);

#endif

#if defined(HEYOKA_HAVE_REAL)

HEYOKA_DEFINE_KEPF_OVERLOADS(mppp::real);

#endif

#undef HEYOKA_DEFINE_KEPF_OVERLOADS

HEYOKA_END_NAMESPACE

HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::kepF_impl)
58 changes: 58 additions & 0 deletions test/kepF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,64 @@ TEST_CASE("kepF diff")
}
}

#define HEYOKA_TEST_KEPF_OVERLOAD(type) \
{ \
auto k = kepF("x"_var, static_cast<type>(1.1), static_cast<type>(1.3)); \
REQUIRE(std::get<func>(k.value()).args()[0] == "x"_var); \
REQUIRE(std::get<number>(std::get<func>(k.value()).args()[1].value()) == number{static_cast<type>(1.1)}); \
REQUIRE(std::get<number>(std::get<func>(k.value()).args()[2].value()) == number{static_cast<type>(1.3)}); \
} \
{ \
auto k = kepF(static_cast<type>(1.1), "y"_var, static_cast<type>(1.3)); \
REQUIRE(std::get<number>(std::get<func>(k.value()).args()[0].value()) == number{static_cast<type>(1.1)}); \
REQUIRE(std::get<func>(k.value()).args()[1] == "y"_var); \
REQUIRE(std::get<number>(std::get<func>(k.value()).args()[2].value()) == number{static_cast<type>(1.3)}); \
} \
{ \
auto k = kepF(static_cast<type>(1.1), static_cast<type>(1.3), "z"_var); \
REQUIRE(std::get<number>(std::get<func>(k.value()).args()[0].value()) == number{static_cast<type>(1.1)}); \
REQUIRE(std::get<number>(std::get<func>(k.value()).args()[1].value()) == number{static_cast<type>(1.3)}); \
REQUIRE(std::get<func>(k.value()).args()[2] == "z"_var); \
} \
{ \
auto k = kepF("x"_var, "y"_var, static_cast<type>(1.3)); \
REQUIRE(std::get<func>(k.value()).args()[0] == "x"_var); \
REQUIRE(std::get<func>(k.value()).args()[1] == "y"_var); \
REQUIRE(std::get<number>(std::get<func>(k.value()).args()[2].value()) == number{static_cast<type>(1.3)}); \
} \
{ \
auto k = kepF("x"_var, static_cast<type>(1.3), "z"_var); \
REQUIRE(std::get<func>(k.value()).args()[0] == "x"_var); \
REQUIRE(std::get<number>(std::get<func>(k.value()).args()[1].value()) == number{static_cast<type>(1.3)}); \
REQUIRE(std::get<func>(k.value()).args()[2] == "z"_var); \
} \
{ \
auto k = kepF(static_cast<type>(1.3), "y"_var, "z"_var); \
REQUIRE(std::get<number>(std::get<func>(k.value()).args()[0].value()) == number{static_cast<type>(1.3)}); \
REQUIRE(std::get<func>(k.value()).args()[1] == "y"_var); \
REQUIRE(std::get<func>(k.value()).args()[2] == "z"_var); \
}

TEST_CASE("kepF overloads")
{
HEYOKA_TEST_KEPF_OVERLOAD(double);
HEYOKA_TEST_KEPF_OVERLOAD(long double);

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_TEST_KEPF_OVERLOAD(mppp::real128);

#endif

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_TEST_KEPF_OVERLOAD(mppp::real);

#endif
}

#undef HEYOKA_TEST_KEPF_OVERLOAD

TEST_CASE("kepF s11n")
{
std::stringstream ss;
Expand Down

0 comments on commit 3bfa679

Please sign in to comment.