diff --git a/include/heyoka/math/kepF.hpp b/include/heyoka/math/kepF.hpp index 84a362872..a76dc7067 100644 --- a/include/heyoka/math/kepF.hpp +++ b/include/heyoka/math/kepF.hpp @@ -42,15 +42,18 @@ class HEYOKA_DLL_PUBLIC kepF_impl : public func_base { friend class boost::serialization::access; template - void serialize(Archive &ar, unsigned) - { - ar &boost::serialization::base_object(*this); - } + HEYOKA_DLL_LOCAL void serialize(Archive &, unsigned); + + template + HEYOKA_DLL_LOCAL expression diff_impl(funcptr_map &, const T &) const; public: kepF_impl(); explicit kepF_impl(expression, expression, expression); + expression diff(funcptr_map &, const std::string &) const; + expression diff(funcptr_map &, const param &) const; + [[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector &, llvm::Value *, llvm::Value *, llvm::Value *, std::uint32_t, bool) const; diff --git a/src/math/kepE.cpp b/src/math/kepE.cpp index 02bcd2cc7..d1539f00c 100644 --- a/src/math/kepE.cpp +++ b/src/math/kepE.cpp @@ -55,6 +55,7 @@ #include #include #include +#include #include #include #include diff --git a/src/math/kepF.cpp b/src/math/kepF.cpp index 4d665ab0a..b26a4f49c 100644 --- a/src/math/kepF.cpp +++ b/src/math/kepF.cpp @@ -51,6 +51,7 @@ #include #include #include +#include #include #include #include @@ -67,6 +68,38 @@ kepF_impl::kepF_impl(expression h, expression k, expression lam) { } +template +void kepF_impl::serialize(Archive &ar, unsigned) +{ + ar &boost::serialization::base_object(*this); +} + +template +expression kepF_impl::diff_impl(funcptr_map &func_map, const T &s) const +{ + assert(args().size() == 3u); + + const auto &h = args()[0]; + const auto &k = args()[1]; + const auto &lam = args()[2]; + + const expression F{func{*this}}; + + return (detail::diff(func_map, k, s) * sin(F) - detail::diff(func_map, h, s) * cos(F) + + detail::diff(func_map, lam, s)) + / (1_dbl - h * sin(F) - k * cos(F)); +} + +expression kepF_impl::diff(funcptr_map &func_map, const std::string &s) const +{ + return diff_impl(func_map, s); +} + +expression kepF_impl::diff(funcptr_map &func_map, const param &p) const +{ + return diff_impl(func_map, p); +} + namespace { diff --git a/test/kepF.cpp b/test/kepF.cpp index 5cdb1439e..59999d32d 100644 --- a/test/kepF.cpp +++ b/test/kepF.cpp @@ -33,7 +33,9 @@ #include #include +#include #include +#include #include "catch.hpp" #include "test_utils.hpp" @@ -62,6 +64,46 @@ constexpr bool skip_batch_ld = #endif ; +TEST_CASE("kepF def ctor") +{ + detail::kepF_impl k; + + REQUIRE(k.args().size() == 3u); + REQUIRE(k.args()[0] == 0_dbl); + REQUIRE(k.args()[1] == 0_dbl); + REQUIRE(k.args()[2] == 0_dbl); +} + +TEST_CASE("kepF diff") +{ + auto [x, y, z] = make_vars("x", "y", "z"); + + { + REQUIRE(diff(kepF(x, y, z), x) + == -cos(kepF(x, y, z)) / (1_dbl - x * sin(kepF(x, y, z)) - y * cos(kepF(x, y, z)))); + REQUIRE(diff(kepF(x, y, z), y) + == sin(kepF(x, y, z)) / (1_dbl - x * sin(kepF(x, y, z)) - y * cos(kepF(x, y, z)))); + REQUIRE(diff(kepF(x, y, z), z) == 1_dbl / (1_dbl - x * sin(kepF(x, y, z)) - y * cos(kepF(x, y, z)))); + auto F = kepF(x * x, x * y, x * z); + REQUIRE(diff(F, x) == (y * sin(F) - 2_dbl * x * cos(F) + z) / (1_dbl - x * x * sin(F) - x * y * cos(F))); + REQUIRE(diff(F, y) == (x * sin(F)) / (1_dbl - x * x * sin(F) - x * y * cos(F))); + } + + { + REQUIRE(diff(kepF(par[0], y, z), par[0]) + == -cos(kepF(par[0], y, z)) / (1_dbl - par[0] * sin(kepF(par[0], y, z)) - y * cos(kepF(par[0], y, z)))); + REQUIRE(diff(kepF(x, par[1], z), par[1]) + == sin(kepF(x, par[1], z)) / (1_dbl - x * sin(kepF(x, par[1], z)) - par[1] * cos(kepF(x, par[1], z)))); + REQUIRE(diff(kepF(x, y, par[2]), par[2]) + == 1_dbl / (1_dbl - x * sin(kepF(x, y, par[2])) - y * cos(kepF(x, y, par[2])))); + auto F = kepF(par[0] * par[0], par[0] * par[1], par[0] * par[2]); + REQUIRE(diff(F, par[0]) + == (par[1] * sin(F) - 2_dbl * par[0] * cos(F) + par[2]) + / (1_dbl - par[0] * par[0] * sin(F) - par[0] * par[1] * cos(F))); + REQUIRE(diff(F, par[1]) == (par[0] * sin(F)) / (1_dbl - par[0] * par[0] * sin(F) - par[0] * par[1] * cos(F))); + } +} + TEST_CASE("cfunc") { using std::isnan;