Skip to content

Commit

Permalink
Fix bisection bounds, initial testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Oct 14, 2023
1 parent 96b0e61 commit 1873b3f
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 9 deletions.
39 changes: 30 additions & 9 deletions src/detail/llvm_helpers_celmec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,9 @@ llvm::Function *llvm_add_inv_kep_F(llvm_state &s, llvm::Type *fp_t, std::uint32_
// Clamp the initial guess to the [0, 2pi) range.
// NOTE: in case ig ends up being NaN (because the arguments are nan or for whatever
// other reason), then ig will remain NaN.
const auto twopi_num = inv_kep_E_dl_twopi_like(s, fp_t).first;
auto *lb = llvm_constantfp(s, tp, 0.);
auto *ub = llvm_codegen(s, tp, nextafter(inv_kep_E_dl_twopi_like(s, fp_t).first, number_like(s, fp_t, 0.)));
auto *ub = llvm_codegen(s, tp, nextafter(twopi_num, number_like(s, fp_t, 0.)));
ig = llvm_clamp(s, ig, lb, ub);

// Store the initial guess in the storage for the return value. This will hold the
Expand Down Expand Up @@ -832,6 +833,11 @@ llvm::Function *llvm_add_inv_kep_F(llvm_state &s, llvm::Type *fp_t, std::uint32_
return builder.CreateSelect(c_cond, tol_cond, llvm::ConstantInt::get(tol_cond->getType(), 0u));
};

// Compute the bisection bounds - i.e., the bounds which are guaranteed
// to contain the root. These are [-1, 2pi + 1).
auto *lb_bisec = llvm_constantfp(s, tp, -1.);
auto *ub_bisec = llvm_codegen(s, tp, nextafter(twopi_num + number_like(s, fp_t, 1.), number_like(s, fp_t, 0.)));

// Run the loop.
llvm_while_loop(s, loop_cond, [&, one_c = llvm_constantfp(s, tp, 1.)]() {
// Compute the new value via the Newton-Raphson formula.
Expand All @@ -841,17 +847,16 @@ llvm::Function *llvm_add_inv_kep_F(llvm_state &s, llvm::Type *fp_t, std::uint32_
auto *fdiv = llvm_fdiv(s, builder.CreateLoad(tp, fF), diff);
auto *new_val = llvm_fsub(s, old_val, fdiv);

// Bisect if new_val > ub.
// NOTE: '>' is fine here, ub is the maximum allowed value.
auto *bcheck = llvm_fcmp_ogt(s, new_val, ub);
// Bisect if new_val > ub_bisec.
auto *bcheck = llvm_fcmp_ogt(s, new_val, ub_bisec);
new_val = builder.CreateSelect(
bcheck, llvm_fmul(s, llvm_codegen(s, tp, number_like(s, fp_t, 1. / 2)), llvm_fadd(s, old_val, ub)),
bcheck, llvm_fmul(s, llvm_codegen(s, tp, number_like(s, fp_t, 1. / 2)), llvm_fadd(s, old_val, ub_bisec)),
new_val);

// Bisect if new_val < lb.
bcheck = llvm_fcmp_olt(s, new_val, lb);
// Bisect if new_val < lb_bisec.
bcheck = llvm_fcmp_olt(s, new_val, lb_bisec);
new_val = builder.CreateSelect(
bcheck, llvm_fmul(s, llvm_codegen(s, tp, number_like(s, fp_t, 1. / 2)), llvm_fadd(s, old_val, lb)),
bcheck, llvm_fmul(s, llvm_codegen(s, tp, number_like(s, fp_t, 1. / 2)), llvm_fadd(s, old_val, lb_bisec)),
new_val);

// Store the new value.
Expand Down Expand Up @@ -889,8 +894,24 @@ llvm::Function *llvm_add_inv_kep_F(llvm_state &s, llvm::Type *fp_t, std::uint32_
},
[]() {});

// Load the result.
llvm::Value *ret = builder.CreateLoad(tp, retval);

// Codegen 2pi, used below.
auto *twopi_const = llvm_codegen(s, tp, twopi_num);

// Reduce the result to the standard trigonometric range [0, 2pi).
// NOTE: this reduction will not change ret if it is NaN.
// Is ret < 0?
auto *ret_lt_0 = llvm_fcmp_olt(s, ret, llvm_constantfp(s, tp, 0.));
ret = builder.CreateSelect(ret_lt_0, llvm_fadd(s, twopi_const, ret), ret);

// Is ret >= 2pi?
auto *ret_ge_2pi = llvm_fcmp_oge(s, ret, twopi_const);
ret = builder.CreateSelect(ret_ge_2pi, llvm_fsub(s, ret, twopi_const), ret);

// Return the result.
builder.CreateRet(builder.CreateLoad(tp, retval));
builder.CreateRet(ret);

// Verify.
s.verify_function(f);
Expand Down
174 changes: 174 additions & 0 deletions test/kepF.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Copyright 2020, 2021, 2022, 2023 Francesco Biscani ([email protected]), Dario Izzo ([email protected])
//
// This file is part of the heyoka library.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include <heyoka/config.hpp>

#include <cmath>
#include <initializer_list>
#include <limits>
#include <random>
#include <tuple>
#include <type_traits>
#include <vector>

#include <boost/algorithm/string/predicate.hpp>

#include <llvm/Config/llvm-config.h>

#if defined(HEYOKA_HAVE_REAL128)

#include <mp++/real128.hpp>

#endif

#if defined(HEYOKA_HAVE_REAL)

#include <mp++/real.hpp>

#endif

#include <heyoka/expression.hpp>
#include <heyoka/llvm_state.hpp>
#include <heyoka/math/kepF.hpp>

#include "catch.hpp"
#include "test_utils.hpp"

static std::mt19937 rng;

using namespace heyoka;
using namespace heyoka_test;

const auto fp_types = std::tuple<double
#if !defined(HEYOKA_ARCH_PPC)
,
long double
#endif
#if defined(HEYOKA_HAVE_REAL128)
,
mppp::real128
#endif
>{};

constexpr bool skip_batch_ld =
#if LLVM_VERSION_MAJOR >= 13 && LLVM_VERSION_MAJOR <= 17
std::numeric_limits<long double>::digits == 64
#else
false
#endif
;

TEST_CASE("cfunc")
{
auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) {
using fp_t = decltype(fp_x);

auto [h, k, lam] = make_vars("h", "k", "lam");

std::uniform_real_distribution<double> lam_dist(-1e5, 1e5), h_dist(std::nextafter(-1., 0.), 1.);

auto generate_hk = [&h_dist]() {
// Generate h.
auto h_val = h_dist(rng);

// Generate a k such that h**2+k**2<1.
const auto max_abs_k = std::sqrt(1. - h_val * h_val);
std::uniform_real_distribution<double> k_dist(std::nextafter(-max_abs_k, 0.), max_abs_k);
auto k_val = static_cast<fp_t>(k_dist(rng));

return std::make_pair(static_cast<fp_t>(h_val), std::move(k_val));
};

std::vector<fp_t> outs, ins, pars;

for (auto batch_size : {1u, 2u, 4u, 5u}) {
if (batch_size != 1u && std::is_same_v<fp_t, long double> && skip_batch_ld) {
continue;
}

outs.resize(batch_size * 3u);
ins.resize(batch_size * 3u);
pars.resize(batch_size * 2u);

for (auto i = 0u; i < batch_size; ++i) {
// Generate the hs and ks.
auto [hval, kval] = generate_hk();
// Generate the lam.
auto lamval = lam_dist(rng);

ins[i] = hval;
ins[i + batch_size] = kval;
ins[i + 2u * batch_size] = lamval;

// Generate another pair of hs and ks for the pars.
std::tie(hval, kval) = generate_hk();
pars[i] = hval;
pars[i + batch_size] = kval;
}

llvm_state s{kw::opt_level = opt_level};

add_cfunc<fp_t>(s, "cfunc", {kepF(h, k, lam), kepF(par[0], par[1], lam), kepF(.5_dbl, .3_dbl, lam)},
kw::batch_size = batch_size, kw::high_accuracy = high_accuracy,
kw::compact_mode = compact_mode);

if (opt_level == 0u && compact_mode) {
REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.kepF."));
}

s.compile();

auto *cf_ptr
= reinterpret_cast<void (*)(fp_t *, const fp_t *, const fp_t *, const fp_t *)>(s.jit_lookup("cfunc"));

cf_ptr(outs.data(), ins.data(), pars.data(), nullptr);

for (auto i = 0u; i < batch_size; ++i) {
using std::isnan;
using std::cos;
using std::sin;

// First output.
REQUIRE(!isnan(outs[i]));
auto Fval = outs[i];
auto hval = ins[i];
auto kval = ins[i + batch_size];
auto lamval = ins[i + 2u * batch_size];
REQUIRE(cos(lamval) == approximately(cos(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
REQUIRE(sin(lamval) == approximately(sin(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));

// Second output.
REQUIRE(!isnan(outs[i + batch_size]));
Fval = outs[i + batch_size];
hval = pars[i];
kval = pars[i + batch_size];
lamval = ins[i + 2u * batch_size];
REQUIRE(cos(lamval) == approximately(cos(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
REQUIRE(sin(lamval) == approximately(sin(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));

// Third output.
REQUIRE(!isnan(outs[i + batch_size * 2u]));
Fval = outs[i + batch_size * 2u];
hval = .5;
kval = .3;
lamval = ins[i + 2u * batch_size];
REQUIRE(cos(lamval) == approximately(cos(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
REQUIRE(sin(lamval) == approximately(sin(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
}
}
};

for (auto cm : {false, true}) {
for (auto f : {false, true}) {
tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 0, f, cm); });
tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 1, f, cm); });
tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 2, f, cm); });
tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 3, f, cm); });
}
}
}

0 comments on commit 1873b3f

Please sign in to comment.