Skip to content

Commit

Permalink
Testing tweaks/additions.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Oct 14, 2023
1 parent 1873b3f commit cae1daa
Showing 1 changed file with 55 additions and 7 deletions.
62 changes: 55 additions & 7 deletions test/kepF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,16 @@ constexpr bool skip_batch_ld =

TEST_CASE("cfunc")
{
using std::isnan;

auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) {
using fp_t = decltype(fp_x);

auto eps_close = [](const fp_t &a, const fp_t &b) {
using std::abs;
return abs(a - b) <= std::numeric_limits<fp_t>::epsilon() * 10;
};

auto [h, k, lam] = make_vars("h", "k", "lam");

std::uniform_real_distribution<double> lam_dist(-1e5, 1e5), h_dist(std::nextafter(-1., 0.), 1.);
Expand Down Expand Up @@ -129,7 +136,6 @@ TEST_CASE("cfunc")
cf_ptr(outs.data(), ins.data(), pars.data(), nullptr);

for (auto i = 0u; i < batch_size; ++i) {
using std::isnan;
using std::cos;
using std::sin;

Expand All @@ -139,26 +145,26 @@ TEST_CASE("cfunc")
auto hval = ins[i];
auto kval = ins[i + batch_size];
auto lamval = ins[i + 2u * batch_size];
REQUIRE(cos(lamval) == approximately(cos(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
REQUIRE(sin(lamval) == approximately(sin(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
REQUIRE(eps_close(cos(lamval), cos(Fval + hval * cos(Fval) - kval * sin(Fval))));
REQUIRE(eps_close(sin(lamval), sin(Fval + hval * cos(Fval) - kval * sin(Fval))));

// Second output.
REQUIRE(!isnan(outs[i + batch_size]));
Fval = outs[i + batch_size];
hval = pars[i];
kval = pars[i + batch_size];
lamval = ins[i + 2u * batch_size];
REQUIRE(cos(lamval) == approximately(cos(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
REQUIRE(sin(lamval) == approximately(sin(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
REQUIRE(eps_close(cos(lamval), cos(Fval + hval * cos(Fval) - kval * sin(Fval))));
REQUIRE(eps_close(sin(lamval), sin(Fval + hval * cos(Fval) - kval * sin(Fval))));

// Third output.
REQUIRE(!isnan(outs[i + batch_size * 2u]));
Fval = outs[i + batch_size * 2u];
hval = .5;
kval = .3;
lamval = ins[i + 2u * batch_size];
REQUIRE(cos(lamval) == approximately(cos(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
REQUIRE(sin(lamval) == approximately(sin(Fval + hval * cos(Fval) - kval * sin(Fval)), fp_t(10000)));
REQUIRE(eps_close(cos(lamval), cos(Fval + hval * cos(Fval) - kval * sin(Fval))));
REQUIRE(eps_close(sin(lamval), sin(Fval + hval * cos(Fval) - kval * sin(Fval))));
}
}
};
Expand All @@ -171,4 +177,46 @@ TEST_CASE("cfunc")
tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 3, f, cm); });
}
}

// Check nan/invalid values handling.
auto [h, k, lam] = make_vars("h", "k", "lam");

llvm_state s;

add_cfunc<double>(s, "cfunc", {kepF(h, k, lam)});

s.compile();

auto *cf_ptr
= reinterpret_cast<void (*)(double *, const double *, const double *, const double *)>(s.jit_lookup("cfunc"));

double out = 0;
double ins[3] = {.1, .2, std::numeric_limits<double>::quiet_NaN()};
cf_ptr(&out, ins, nullptr, nullptr);

REQUIRE(isnan(out));

ins[0] = std::numeric_limits<double>::quiet_NaN();
ins[1] = .2;
ins[2] = 1.;

cf_ptr(&out, ins, nullptr, nullptr);

REQUIRE(isnan(out));

ins[0] = .2;
ins[1] = std::numeric_limits<double>::quiet_NaN();
ins[2] = 1.;

cf_ptr(&out, ins, nullptr, nullptr);

REQUIRE(isnan(out));

ins[0] = .2;
ins[1] = 1.;
ins[2] = 1.;

cf_ptr(&out, ins, nullptr, nullptr);

REQUIRE(isnan(out));
}

0 comments on commit cae1daa

Please sign in to comment.