From 56a9ef9a78bfd259b83521b4b06d49d5c968c450 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 09:37:24 +0000 Subject: [PATCH 01/26] Bump to 6.0.0 for development. --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1db953289..1dcdc9ce5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ if(NOT CMAKE_BUILD_TYPE) FORCE) endif() -project(heyoka VERSION 5.0.0 LANGUAGES CXX C) +project(heyoka VERSION 6.0.0 LANGUAGES CXX C) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/yacma") @@ -329,7 +329,7 @@ if(HEYOKA_WITH_SLEEF) endif() # Setup the heyoka ABI version number. -set(HEYOKA_ABI_VERSION 28) +set(HEYOKA_ABI_VERSION 29) if(HEYOKA_BUILD_STATIC_LIBRARY) # Setup of the heyoka static library. From 8b4a52aa794f38b0344b9f93b9a1bc134a7632e3 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 13:03:37 +0000 Subject: [PATCH 02/26] Add llvm helper for inequality testing. 
--- include/heyoka/detail/llvm_helpers.hpp | 1 + include/heyoka/detail/real_helpers.hpp | 1 + src/detail/llvm_helpers.cpp | 25 +++++++++++++++++++++++++ src/detail/real_helpers.cpp | 9 +++++++++ 4 files changed, 36 insertions(+) diff --git a/include/heyoka/detail/llvm_helpers.hpp b/include/heyoka/detail/llvm_helpers.hpp index 11230fb77..980c53bbd 100644 --- a/include/heyoka/detail/llvm_helpers.hpp +++ b/include/heyoka/detail/llvm_helpers.hpp @@ -130,6 +130,7 @@ HEYOKA_DLL_PUBLIC llvm::Value *llvm_fcmp_ole(llvm_state &, llvm::Value *, llvm:: HEYOKA_DLL_PUBLIC llvm::Value *llvm_fcmp_olt(llvm_state &, llvm::Value *, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_fcmp_ogt(llvm_state &, llvm::Value *, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_fcmp_oeq(llvm_state &, llvm::Value *, llvm::Value *); +HEYOKA_DLL_PUBLIC llvm::Value *llvm_fcmp_one(llvm_state &, llvm::Value *, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_min(llvm_state &, llvm::Value *, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_max(llvm_state &, llvm::Value *, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_min_nan(llvm_state &, llvm::Value *, llvm::Value *); diff --git a/include/heyoka/detail/real_helpers.hpp b/include/heyoka/detail/real_helpers.hpp index 3fffe181e..1fdbb77bf 100644 --- a/include/heyoka/detail/real_helpers.hpp +++ b/include/heyoka/detail/real_helpers.hpp @@ -40,6 +40,7 @@ llvm::Value *llvm_real_fcmp_ole(llvm_state &, llvm::Value *, llvm::Value *); llvm::Value *llvm_real_fcmp_olt(llvm_state &, llvm::Value *, llvm::Value *); llvm::Value *llvm_real_fcmp_ogt(llvm_state &, llvm::Value *, llvm::Value *); llvm::Value *llvm_real_fcmp_oeq(llvm_state &, llvm::Value *, llvm::Value *); +llvm::Value *llvm_real_fcmp_one(llvm_state &, llvm::Value *, llvm::Value *); llvm::Value *llvm_real_ui_to_fp(llvm_state &, llvm::Value *, llvm::Type *); llvm::Value *llvm_real_sgn(llvm_state &, llvm::Value *); diff --git a/src/detail/llvm_helpers.cpp b/src/detail/llvm_helpers.cpp index 
e506d525c..da6bb26f8 100644 --- a/src/detail/llvm_helpers.cpp +++ b/src/detail/llvm_helpers.cpp @@ -2041,6 +2041,31 @@ llvm::Value *llvm_fcmp_oeq(llvm_state &s, llvm::Value *a, llvm::Value *b) } } +llvm::Value *llvm_fcmp_one(llvm_state &s, llvm::Value *a, llvm::Value *b) +{ + // LCOV_EXCL_START + assert(a != nullptr); + assert(b != nullptr); + assert(a->getType() == b->getType()); + // LCOV_EXCL_STOP + + auto &builder = s.builder(); + + auto *fp_t = a->getType(); + + if (fp_t->getScalarType()->isFloatingPointTy()) { + return builder.CreateFCmpONE(a, b); +#if defined(HEYOKA_HAVE_REAL) + } else if (llvm_is_real(fp_t) != 0) { + return llvm_real_fcmp_one(s, a, b); +#endif + } else { + // LCOV_EXCL_START + throw std::invalid_argument(fmt::format("Unable to fcmp_one values of type '{}'", llvm_type_name(fp_t))); + // LCOV_EXCL_STOP + } +} + // Helper to compute sin and cos simultaneously. // NOTE: although there exists a SLEEF function for computing sin/cos // at the same time, we cannot use it directly because it returns a pair diff --git a/src/detail/real_helpers.cpp b/src/detail/real_helpers.cpp index 83055843d..eb05c45e0 100644 --- a/src/detail/real_helpers.cpp +++ b/src/detail/real_helpers.cpp @@ -591,6 +591,15 @@ llvm::Value *llvm_real_fcmp_oeq(llvm_state &s, llvm::Value *a, llvm::Value *b) return s.builder().CreateCall(f, {a, b}); } +llvm::Value *llvm_real_fcmp_one(llvm_state &s, llvm::Value *a, llvm::Value *b) +{ + // Compute a == b. + auto *ret = llvm_real_fcmp_oeq(s, a, b); + + // NOTE: this creates a logical NOT. + return s.builder().CreateICmpEQ(ret, llvm::ConstantInt::getNullValue(ret->getType())); +} + // Convert the input unsigned integral value n to the real type fp_t. llvm::Value *llvm_real_ui_to_fp(llvm_state &s, llvm::Value *n, llvm::Type *fp_t) { From 535d5eae7b925050854e854575d3f56d0ccd45b9 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 15:20:44 +0000 Subject: [PATCH 03/26] Extend llvm_ui_to_fp() to work also on vectors. 
--- src/detail/llvm_helpers.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/detail/llvm_helpers.cpp b/src/detail/llvm_helpers.cpp index da6bb26f8..3d34e31f8 100644 --- a/src/detail/llvm_helpers.cpp +++ b/src/detail/llvm_helpers.cpp @@ -3223,16 +3223,24 @@ llvm::Type *llvm_ext_type(llvm::Type *fp_t) // LCOV_EXCL_STOP } -// Convert the input unsigned integral value n to the floating-point type fp_t. -// Vector types/values are not supported. +// Convert the input unsigned integral value(s) n to the floating-point type fp_t. +// If n is a scalar/vector, then fp_t must also be a scalar/vector type. llvm::Value *llvm_ui_to_fp(llvm_state &s, llvm::Value *n, llvm::Type *fp_t) { assert(n != nullptr); assert(fp_t != nullptr); - assert(!n->getType()->isVectorTy()); - assert(!fp_t->isVectorTy()); - if (fp_t->isFloatingPointTy()) { +#if !defined(NDEBUG) + if (n->getType()->isVectorTy()) { + assert(fp_t->isVectorTy()); + assert(llvm::cast(n->getType())->getNumElements() + == llvm::cast(fp_t)->getNumElements()); + } else { + assert(!fp_t->isVectorTy()); + } +#endif + + if (fp_t->getScalarType()->isFloatingPointTy()) { return s.builder().CreateUIToFP(n, fp_t); #if defined(HEYOKA_HAVE_REAL) } else if (llvm_is_real(fp_t) != 0) { From 43488df2e574b7bc7b5ec3ee483606bc855a2fdd Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 15:23:07 +0000 Subject: [PATCH 04/26] Initial version of the relational operations. 
--- CMakeLists.txt | 1 + include/heyoka/math/relational.hpp | 69 ++++++++ src/math/relational.cpp | 198 ++++++++++++++++++++++ test/CMakeLists.txt | 1 + test/rel.cpp | 261 +++++++++++++++++++++++++++++ 5 files changed, 530 insertions(+) create mode 100644 include/heyoka/math/relational.hpp create mode 100644 src/math/relational.cpp create mode 100644 test/rel.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 1dcdc9ce5..9f962ac2e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -313,6 +313,7 @@ set(HEYOKA_SRC_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/math/prod.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/constants.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/dfun.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/math/relational.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/detail/string_conv.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/detail/logging_impl.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/step_callback.cpp" diff --git a/include/heyoka/math/relational.hpp b/include/heyoka/math/relational.hpp new file mode 100644 index 000000000..2fe7118a5 --- /dev/null +++ b/include/heyoka/math/relational.hpp @@ -0,0 +1,69 @@ +// Copyright 2020, 2021, 2022, 2023, 2024 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#ifndef HEYOKA_MATH_RELATIONAL_HPP +#define HEYOKA_MATH_RELATIONAL_HPP + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +HEYOKA_BEGIN_NAMESPACE + +namespace detail +{ + +enum class rel_op { eq, neq, lt, gt, lte, gte }; + +class HEYOKA_DLL_PUBLIC rel_impl : public func_base +{ + rel_op m_op = rel_op::eq; + + friend class boost::serialization::access; + template + void serialize(Archive &ar, unsigned) + { + ar &boost::serialization::base_object(*this); + ar & m_op; + } + +public: + rel_impl(); + explicit rel_impl(rel_op, expression, expression); + + void to_stream(std::ostringstream &) const; + + [[nodiscard]] std::vector gradient() const; + + [[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector &, llvm::Value *, + llvm::Value *, llvm::Value *, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const; +}; + +} // namespace detail + +HEYOKA_DLL_PUBLIC expression eq(expression, expression); +HEYOKA_DLL_PUBLIC expression neq(expression, expression); +HEYOKA_DLL_PUBLIC expression lt(expression, expression); +HEYOKA_DLL_PUBLIC expression gt(expression, expression); +HEYOKA_DLL_PUBLIC expression lte(expression, expression); +HEYOKA_DLL_PUBLIC expression gte(expression, expression); + +HEYOKA_END_NAMESPACE + +HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::rel_impl) + +#endif diff --git a/src/math/relational.cpp b/src/math/relational.cpp new file mode 100644 index 000000000..42d6bb98e --- /dev/null +++ b/src/math/relational.cpp @@ -0,0 +1,198 @@ +// Copyright 2020, 2021, 2022, 2023, 2024 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +HEYOKA_BEGIN_NAMESPACE + +namespace detail +{ + +namespace +{ + +std::string name_from_op(rel_op op) +{ + assert(op >= rel_op::eq); + assert(op <= rel_op::gte); + + constexpr auto fstr = "rel_{}"; + +#define HEYOKA_MATH_REL_HANDLE_CASE(op) \ + case rel_op::op: \ + return fmt::format(fstr, #op); + + switch (op) { + HEYOKA_MATH_REL_HANDLE_CASE(eq) + HEYOKA_MATH_REL_HANDLE_CASE(neq) + HEYOKA_MATH_REL_HANDLE_CASE(lt) + HEYOKA_MATH_REL_HANDLE_CASE(gt) + HEYOKA_MATH_REL_HANDLE_CASE(lte) + HEYOKA_MATH_REL_HANDLE_CASE(gte) + } + +#undef HEYOKA_MATH_REL_HANDLE_CASE + + // LCOV_EXCL_START + assert(false); + + throw; + // LCOV_EXCL_STOP +} + +} // namespace + +rel_impl::rel_impl() : rel_impl(rel_op::eq, 1_dbl, 1_dbl) {} + +rel_impl::rel_impl(rel_op op, expression a, expression b) + : func_base(name_from_op(op), {std::move(a), std::move(b)}), m_op(op) +{ +} + +void rel_impl::to_stream(std::ostringstream &oss) const +{ + assert(args().size() == 2u); + + const auto &a = args()[0]; + const auto &b = args()[1]; + + oss << '('; + stream_expression(oss, a); + + switch (m_op) { + case rel_op::eq: + oss << " == "; + break; + case rel_op::neq: + oss << " != "; + break; + case rel_op::lt: + oss << " < "; + break; + case rel_op::gt: + oss << " > "; + break; + case rel_op::lte: + oss << " <= "; + break; + default: + assert(m_op == rel_op::gte); + oss << " >= "; + break; + } + + stream_expression(oss, b); + oss << ')'; +} + +// NOLINTNEXTLINE(readability-convert-member-functions-to-static) +std::vector rel_impl::gradient() const +{ + return {0_dbl, 0_dbl}; +} + +namespace +{ + +llvm::Value *rel_eval_impl(llvm_state &s, rel_op op, const std::vector &args) +{ + assert(args.size() == 2u); + + llvm::Value *ret = nullptr; + + switch (op) { + case rel_op::eq: + ret = llvm_fcmp_oeq(s, args[0], args[1]); + break; + case 
rel_op::neq: + ret = llvm_fcmp_one(s, args[0], args[1]); + break; + case rel_op::lt: + ret = llvm_fcmp_olt(s, args[0], args[1]); + break; + case rel_op::gt: + ret = llvm_fcmp_ogt(s, args[0], args[1]); + break; + case rel_op::lte: + ret = llvm_fcmp_ole(s, args[0], args[1]); + break; + case rel_op::gte: + ret = llvm_fcmp_oge(s, args[0], args[1]); + break; + } + + assert(ret != nullptr); + + // NOTE: the LLVM fp comparison primitives return booleans. Thus, + // we need to convert back to the proper (vector) fp type on exit. + // NOTE: we create a UI to FP conversion (rather than SI to FP) + // so that we get either 1 or 0 from the conversion (with SI, true + // would come out as -1). + return llvm_ui_to_fp(s, ret, args[0]->getType()); +} + +} // namespace + +llvm::Value *rel_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t, const std::vector &eval_arr, + llvm::Value *par_ptr, llvm::Value *, llvm::Value *stride, std::uint32_t batch_size, + bool high_accuracy) const +{ + return llvm_eval_helper( + [&s, op = m_op](const std::vector &args, bool) { return rel_eval_impl(s, op, args); }, *this, s, + fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy); +} + +llvm::Function *rel_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t batch_size, + bool high_accuracy) const +{ + return llvm_c_eval_func_helper( + name_from_op(m_op), + [&s, op = m_op](const std::vector &args, bool) { return rel_eval_impl(s, op, args); }, *this, s, + fp_t, batch_size, high_accuracy); +} + +} // namespace detail + +#define HEYOKA_MATH_REL_IMPL(op) \ + expression op(expression a, expression b) \ + { \ + return expression{func{detail::rel_impl{detail::rel_op::op, std::move(a), std::move(b)}}}; \ + } + +HEYOKA_MATH_REL_IMPL(eq) +HEYOKA_MATH_REL_IMPL(neq) +HEYOKA_MATH_REL_IMPL(lt) +HEYOKA_MATH_REL_IMPL(gt) +HEYOKA_MATH_REL_IMPL(lte) +HEYOKA_MATH_REL_IMPL(gte) + +#undef HEYOKA_MATH_REL_IMPL + +HEYOKA_END_NAMESPACE + +// NOLINTNEXTLINE(cert-err58-cpp) 
+HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::rel_impl) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 16468e23b..5b0be4090 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -120,6 +120,7 @@ ADD_HEYOKA_TESTCASE(prod) ADD_HEYOKA_TESTCASE(div) ADD_HEYOKA_TESTCASE(sub) ADD_HEYOKA_TESTCASE(time) +ADD_HEYOKA_TESTCASE(rel) ADD_HEYOKA_TESTCASE(wavy_ramp) ADD_HEYOKA_TESTCASE(dfloat_time) ADD_HEYOKA_TESTCASE(timestep_check) diff --git a/test/rel.cpp b/test/rel.cpp new file mode 100644 index 000000000..eef4d16bd --- /dev/null +++ b/test/rel.cpp @@ -0,0 +1,261 @@ +// Copyright 2020, 2021, 2022, 2023, 2024 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + +#include +#include +#include +#include +#include +#include + +#include "catch.hpp" +#include "test_utils.hpp" + +static std::mt19937 rng; + +using namespace heyoka; +using namespace heyoka_test; + +const auto fp_types = std::tuple{}; + +constexpr bool skip_batch_ld = +#if LLVM_VERSION_MAJOR <= 17 + std::numeric_limits::digits == 64 +#else + false +#endif + ; + +TEST_CASE("stream") +{ + auto [x, y] = make_vars("x", "y"); + + { + std::ostringstream oss; + oss << eq(x, y); + REQUIRE(oss.str() == "(x == y)"); + } + + { + std::ostringstream oss; + oss << neq(x, y); + REQUIRE(oss.str() == "(x != y)"); + } + + { + std::ostringstream oss; + oss << lt(x, y); + REQUIRE(oss.str() == "(x < y)"); + } + + { + std::ostringstream oss; + oss << gt(x, y); + REQUIRE(oss.str() == "(x > y)"); + } + + { + 
std::ostringstream oss; + oss << lte(x, y); + REQUIRE(oss.str() == "(x <= y)"); + } + + { + std::ostringstream oss; + oss << gte(x, y); + REQUIRE(oss.str() == "(x >= y)"); + } +} + +TEST_CASE("diff") +{ + auto [x, y] = make_vars("x", "y"); + + REQUIRE(diff(eq(x, y), "x") == 0_dbl); + REQUIRE(diff(neq(x, y), "y") == 0_dbl); +} + +TEST_CASE("s11n") +{ + std::stringstream ss; + + auto [x, y] = make_vars("x", "y"); + + auto ex = eq(x, y); + + { + boost::archive::binary_oarchive oa(ss); + + oa << ex; + } + + ex = neq(x, y); + + { + boost::archive::binary_iarchive ia(ss); + + ia >> ex; + } + + REQUIRE(ex == eq(x, y)); +} + +TEST_CASE("cfunc") +{ + auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) { + using fp_t = decltype(fp_x); + + auto [x, y] = make_vars("x", "y"); + + std::uniform_real_distribution rdist(-1., 1.); + + auto gen = [&rdist]() { return static_cast(rdist(rng)); }; + + std::vector outs, ins, pars, time; + + for (auto batch_size : {1u, 2u, 4u, 5u}) { + if (batch_size != 1u && std::is_same_v && skip_batch_ld) { + continue; + } + + outs.resize(batch_size * 8u); + ins.resize(batch_size * 2u); + pars.resize(batch_size); + time.resize(batch_size); + + std::generate(ins.begin(), ins.end(), gen); + std::generate(pars.begin(), pars.end(), gen); + std::generate(time.begin(), time.end(), gen); + + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", + {eq(x, y), neq(x, par[0]), lt(y, 1_dbl), gt(x + y, y - x), lte(x * x, heyoka::time), + gte(par[0], .4_dbl), lte(x, x), gte(y, y)}, + {x, y}, kw::batch_size = batch_size, kw::high_accuracy = high_accuracy, + kw::compact_mode = compact_mode); + + if (opt_level == 0u && compact_mode) { + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.rel_eq.")); + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.rel_neq.")); + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.rel_lt.")); + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.rel_gt.")); + 
REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.rel_lte.")); + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.rel_gte.")); + } + + s.compile(); + + auto *cf_ptr + = reinterpret_cast(s.jit_lookup("cfunc")); + + cf_ptr(outs.data(), ins.data(), pars.data(), time.data()); + + for (auto i = 0u; i < batch_size; ++i) { + REQUIRE(outs[i] == (ins[i] == ins[i + batch_size])); + REQUIRE(outs[i + batch_size] == (ins[i] != pars[i])); + REQUIRE(outs[i + 2u * batch_size] == (ins[i + batch_size] < 1)); + REQUIRE(outs[i + 3u * batch_size] == ((ins[i] + ins[i + batch_size]) > (ins[i + batch_size] - ins[i]))); + REQUIRE(outs[i + 4u * batch_size] == ((ins[i] * ins[i]) <= time[i])); + REQUIRE(outs[i + 5u * batch_size] == (pars[i] >= .4)); + REQUIRE(outs[i + 6u * batch_size] == true); + REQUIRE(outs[i + 7u * batch_size] == true); + } + } + }; + + for (auto cm : {false, true}) { + for (auto f : {false, true}) { + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 0, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 1, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 2, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 3, f, cm); }); + } + } +} + +#if defined(HEYOKA_HAVE_REAL) + +TEST_CASE("cfunc_mp") +{ + auto [x, y] = make_vars("x", "y"); + + const auto prec = 237u; + + for (auto compact_mode : {false, true}) { + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", + {eq(x, y), neq(x, par[0]), lt(y, 1_dbl), gt(x + y, y - x), lte(x * x, heyoka::time), + gte(par[0], .4_dbl), lte(x, x), gte(y, y)}, + {x, y}, kw::compact_mode = compact_mode, kw::prec = prec); + + s.compile(); + + auto *cf_ptr + = reinterpret_cast( + s.jit_lookup("cfunc")); + + const std::vector ins{mppp::real{".7", prec}, mppp::real{"-.1", prec}}; + const std::vector pars{mppp::real{"-.1", prec}}; + const std::vector time{mppp::real{".3", prec}}; + std::vector 
outs(8u, mppp::real{0, prec}); + + cf_ptr(outs.data(), ins.data(), pars.data(), time.data()); + + auto i = 0u; + auto batch_size = 1u; + REQUIRE(outs[i] == (ins[i] == ins[i + batch_size])); + REQUIRE(outs[i + batch_size] == (ins[i] != pars[i])); + REQUIRE(outs[i + 2u * batch_size] == (ins[i + batch_size] < 1)); + REQUIRE(outs[i + 3u * batch_size] == ((ins[i] + ins[i + batch_size]) > (ins[i + batch_size] - ins[i]))); + REQUIRE(outs[i + 4u * batch_size] == ((ins[i] * ins[i]) <= time[i])); + REQUIRE(outs[i + 5u * batch_size] == (pars[i] >= .4)); + REQUIRE(outs[i + 6u * batch_size] == true); + REQUIRE(outs[i + 7u * batch_size] == true); + } + } +} + +#endif From c5a7cf0cfe04556af0e363dbf5c4fb190fe16e64 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 15:30:07 +0000 Subject: [PATCH 05/26] Small assert() addition. --- src/detail/llvm_helpers.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/detail/llvm_helpers.cpp b/src/detail/llvm_helpers.cpp index 3d34e31f8..f7a0fcf6b 100644 --- a/src/detail/llvm_helpers.cpp +++ b/src/detail/llvm_helpers.cpp @@ -3230,6 +3230,8 @@ llvm::Value *llvm_ui_to_fp(llvm_state &s, llvm::Value *n, llvm::Type *fp_t) assert(n != nullptr); assert(fp_t != nullptr); + assert(n->getType()->getScalarType()->isIntegerTy()); + #if !defined(NDEBUG) if (n->getType()->isVectorTy()) { assert(fp_t->isVectorTy()); From 1f908c4b458566203bc40f264420aa15f66d46c7 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 15:48:43 +0000 Subject: [PATCH 06/26] Take advantage of the new logicaland/logicalor functions available from LLVM 13. 
--- src/detail/event_detection.cpp | 4 +--- src/detail/llvm_helpers.cpp | 12 ++++-------- src/detail/llvm_helpers_celmec.cpp | 19 +++++++------------ 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/detail/event_detection.cpp b/src/detail/event_detection.cpp index cf2df0a23..f5bf6a8df 100644 --- a/src/detail/event_detection.cpp +++ b/src/detail/event_detection.cpp @@ -748,9 +748,7 @@ llvm::Function *llvm_add_fex_check(llvm_state &s, llvm::Type *fp_t, std::uint32_ // Check if the signs are equal and the low sign is nonzero. auto *cmp1 = builder.CreateICmpEQ(s_lo, s_hi); auto *cmp2 = builder.CreateICmpNE(s_lo, llvm::ConstantInt::get(s_lo->getType(), 0u)); - // NOTE: this is a way of creating a logical AND between cmp1 and cmp2. LLVM 13 has a specific - // function for this. - auto *cmp = builder.CreateSelect(cmp1, cmp2, llvm::ConstantInt::get(cmp1->getType(), 0u)); + auto *cmp = builder.CreateLogicalAnd(cmp1, cmp2); // Extend cmp to int32_t. auto *retval = builder.CreateZExt(cmp, make_vector_type(builder.getInt32Ty(), batch_size)); diff --git a/src/detail/llvm_helpers.cpp b/src/detail/llvm_helpers.cpp index f7a0fcf6b..3bab78bf9 100644 --- a/src/detail/llvm_helpers.cpp +++ b/src/detail/llvm_helpers.cpp @@ -2948,10 +2948,8 @@ llvm::Value *llvm_dl_lt(llvm_state &state, llvm::Value *x_hi, llvm::Value *x_lo, auto *cond1 = llvm_fcmp_olt(state, x_hi, y_hi); auto *cond2 = llvm_fcmp_oeq(state, x_hi, y_hi); auto *cond3 = llvm_fcmp_olt(state, x_lo, y_lo); - // NOTE: this is a logical AND. - auto *cond4 = builder.CreateSelect(cond2, cond3, llvm::ConstantInt::getNullValue(cond3->getType())); - // NOTE: this is a logical OR. 
- auto *cond = builder.CreateSelect(cond1, llvm::ConstantInt::getAllOnesValue(cond4->getType()), cond4); + auto *cond4 = builder.CreateLogicalAnd(cond2, cond3); + auto *cond = builder.CreateLogicalOr(cond1, cond4); return cond; } @@ -2968,10 +2966,8 @@ llvm::Value *llvm_dl_gt(llvm_state &state, llvm::Value *x_hi, llvm::Value *x_lo, auto *cond1 = llvm_fcmp_ogt(state, x_hi, y_hi); auto *cond2 = llvm_fcmp_oeq(state, x_hi, y_hi); auto *cond3 = llvm_fcmp_ogt(state, x_lo, y_lo); - // NOTE: this is a logical AND. - auto *cond4 = builder.CreateSelect(cond2, cond3, llvm::ConstantInt::getNullValue(cond3->getType())); - // NOTE: this is a logical OR. - auto *cond = builder.CreateSelect(cond1, llvm::ConstantInt::getAllOnesValue(cond4->getType()), cond4); + auto *cond4 = builder.CreateLogicalAnd(cond2, cond3); + auto *cond = builder.CreateLogicalOr(cond1, cond4); return cond; } diff --git a/src/detail/llvm_helpers_celmec.cpp b/src/detail/llvm_helpers_celmec.cpp index 6cc032000..3fcb5deeb 100644 --- a/src/detail/llvm_helpers_celmec.cpp +++ b/src/detail/llvm_helpers_celmec.cpp @@ -375,9 +375,7 @@ llvm::Function *llvm_add_inv_kep_E(llvm_state &s, llvm::Type *fp_t, std::uint32_ auto *ecc_is_gte1 = llvm_fcmp_oge(s, ecc_arg, llvm_constantfp(s, tp, 1.)); // Is the eccentricity NaN or out of range? - // NOTE: this is a logical OR. - auto *ecc_invalid = builder.CreateSelect( - ecc_is_nan_or_neg, llvm::ConstantInt::getAllOnesValue(ecc_is_nan_or_neg->getType()), ecc_is_gte1); + auto *ecc_invalid = builder.CreateLogicalOr(ecc_is_nan_or_neg, ecc_is_gte1); // Replace invalid eccentricity values with quiet NaNs. auto *ecc @@ -537,8 +535,7 @@ llvm::Function *llvm_add_inv_kep_E(llvm_state &s, llvm::Type *fp_t, std::uint32_ auto *tol2_check = llvm_fcmp_ogt(s, bsize, tol); // Put them together with a logical AND. 
- auto *tol_check - = builder.CreateSelect(tol1_check, tol2_check, llvm::ConstantInt::get(tol2_check->getType(), 0u)); + auto *tol_check = builder.CreateLogicalAnd(tol1_check, tol2_check); // NOTE: we need OR reduction in batch mode: continue if *any* element of the batch // needs more iterations. auto *tol_cond = (batch_size == 1u) ? tol_check : builder.CreateOrReduce(tol_check); @@ -550,7 +547,7 @@ llvm::Function *llvm_add_inv_kep_E(llvm_state &s, llvm::Type *fp_t, std::uint32_ auto *c_cond = builder.CreateICmpULT(builder.CreateLoad(builder.getInt32Ty(), counter), max_iter); // Combine tolerance check and number of iterations check with a logical AND. - return builder.CreateSelect(c_cond, tol_cond, llvm::ConstantInt::get(tol_cond->getType(), 0u)); + return builder.CreateLogicalAnd(c_cond, tol_cond); }; // Run the loop. @@ -919,8 +916,7 @@ llvm::Function *llvm_add_inv_kep_F(llvm_state &s, llvm::Type *fp_t, std::uint32_ auto *tol2_check = llvm_fcmp_ogt(s, bsize, tol); // Put them together with a logical AND. - auto *tol_check - = builder.CreateSelect(tol1_check, tol2_check, llvm::ConstantInt::get(tol2_check->getType(), 0u)); + auto *tol_check = builder.CreateLogicalAnd(tol1_check, tol2_check); // NOTE: we need OR reduction in batch mode: continue if *any* element of the batch // needs more iterations. auto *tol_cond = (batch_size == 1u) ? tol_check : builder.CreateOrReduce(tol_check); @@ -932,7 +928,7 @@ llvm::Function *llvm_add_inv_kep_F(llvm_state &s, llvm::Type *fp_t, std::uint32_ auto *c_cond = builder.CreateICmpULT(builder.CreateLoad(builder.getInt32Ty(), counter), max_iter); // Combine tolerance check and number of iterations check with a logical AND. - return builder.CreateSelect(c_cond, tol_cond, llvm::ConstantInt::get(tol_cond->getType(), 0u)); + return builder.CreateLogicalAnd(c_cond, tol_cond); }; // Run the loop. 
@@ -1235,8 +1231,7 @@ llvm::Function *llvm_add_inv_kep_DE(llvm_state &s, llvm::Type *fp_t, std::uint32 auto *tol2_check = llvm_fcmp_ogt(s, bsize, tol); // Put them together with a logical AND. - auto *tol_check - = builder.CreateSelect(tol1_check, tol2_check, llvm::ConstantInt::get(tol2_check->getType(), 0u)); + auto *tol_check = builder.CreateLogicalAnd(tol1_check, tol2_check); // NOTE: we need OR reduction in batch mode: continue if *any* element of the batch // needs more iterations. auto *tol_cond = (batch_size == 1u) ? tol_check : builder.CreateOrReduce(tol_check); @@ -1248,7 +1243,7 @@ llvm::Function *llvm_add_inv_kep_DE(llvm_state &s, llvm::Type *fp_t, std::uint32 auto *c_cond = builder.CreateICmpULT(builder.CreateLoad(builder.getInt32Ty(), counter), max_iter); // Combine tolerance check and number of iterations check with a logical AND. - return builder.CreateSelect(c_cond, tol_cond, llvm::ConstantInt::get(tol_cond->getType(), 0u)); + return builder.CreateLogicalAnd(c_cond, tol_cond); }; // Run the loop. From 13ed3b9b2838e2ed6f670a87dce6e6a6abf86b52 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 16:02:05 +0000 Subject: [PATCH 07/26] More tuning of testing tolerances. --- test/taylor_acosh.cpp | 4 ++-- test/taylor_sinhcosh.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/taylor_acosh.cpp b/test/taylor_acosh.cpp index 9a788ae4b..e91b5a28a 100644 --- a/test/taylor_acosh.cpp +++ b/test/taylor_acosh.cpp @@ -374,7 +374,7 @@ TEST_CASE("taylor acosh") if constexpr (!std::is_same_v || !skip_batch_ld) { // Do the batch/scalar comparison. compare_batch_scalar({prime(x) = acosh(expression{number{fp_t{1.625}}}), prime(y) = x + y}, opt_level, - high_accuracy, compact_mode, rng, 1.3f, 5.1f); + high_accuracy, compact_mode, rng, 1.2f, 5.1f, fp_t(10000)); } // Variable tests. @@ -536,7 +536,7 @@ TEST_CASE("taylor acosh") if constexpr (!std::is_same_v || !skip_batch_ld) { // Do the batch/scalar comparison. 
compare_batch_scalar({prime(x) = acosh(y), prime(y) = acosh(x)}, opt_level, high_accuracy, - compact_mode, rng, 1.3f, 5.1f); + compact_mode, rng, 1.2f, 5.1f, fp_t(10000)); } }; diff --git a/test/taylor_sinhcosh.cpp b/test/taylor_sinhcosh.cpp index 16263007b..bb3804f72 100644 --- a/test/taylor_sinhcosh.cpp +++ b/test/taylor_sinhcosh.cpp @@ -352,7 +352,7 @@ TEST_CASE("taylor sinhcosh") // Do the batch/scalar comparison. compare_batch_scalar( {prime(x) = sinh(expression{number{fp_t{2}}}) + cosh(expression{number{fp_t{3}}}), prime(y) = x + y}, - opt_level, high_accuracy, compact_mode, rng, -10.f, 10.f); + opt_level, high_accuracy, compact_mode, rng, -10.f, 10.f, fp_t(10000)); } // Variable tests. @@ -507,7 +507,7 @@ TEST_CASE("taylor sinhcosh") if constexpr (!std::is_same_v || !skip_batch_ld) { // Do the batch/scalar comparison. compare_batch_scalar({prime(x) = sinh(y), prime(y) = cosh(x)}, opt_level, high_accuracy, compact_mode, - rng, -10.f, 10.f); + rng, -10.f, 10.f, fp_t(10000)); } }; From be561ad541def6d3e42f336c6e281056141e3802 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 16:04:06 +0000 Subject: [PATCH 08/26] Small test addition. --- test/rel.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/rel.cpp b/test/rel.cpp index eef4d16bd..a7a41f6ff 100644 --- a/test/rel.cpp +++ b/test/rel.cpp @@ -33,6 +33,7 @@ #endif #include +#include #include #include #include @@ -66,6 +67,13 @@ constexpr bool skip_batch_ld = #endif ; +TEST_CASE("basic") +{ + auto [x, y] = make_vars("x", "y"); + + REQUIRE(expression{func{detail::rel_impl{}}} == eq(1_dbl, 1_dbl)); +} + TEST_CASE("stream") { auto [x, y] = make_vars("x", "y"); From 2d4c0b4b885508f1ca55ccb54874ec5aedd34aa0 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 16:12:50 +0000 Subject: [PATCH 09/26] Another small test addition. 
--- include/heyoka/math/relational.hpp | 2 ++ src/math/relational.cpp | 5 +++++ test/rel.cpp | 6 ++++-- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/include/heyoka/math/relational.hpp b/include/heyoka/math/relational.hpp index 2fe7118a5..6e86b3886 100644 --- a/include/heyoka/math/relational.hpp +++ b/include/heyoka/math/relational.hpp @@ -43,6 +43,8 @@ class HEYOKA_DLL_PUBLIC rel_impl : public func_base rel_impl(); explicit rel_impl(rel_op, expression, expression); + [[nodiscard]] rel_op get_op() const noexcept; + void to_stream(std::ostringstream &) const; [[nodiscard]] std::vector gradient() const; diff --git a/src/math/relational.cpp b/src/math/relational.cpp index 42d6bb98e..0490aff69 100644 --- a/src/math/relational.cpp +++ b/src/math/relational.cpp @@ -73,6 +73,11 @@ rel_impl::rel_impl(rel_op op, expression a, expression b) { } +rel_op rel_impl::get_op() const noexcept +{ + return m_op; +} + void rel_impl::to_stream(std::ostringstream &oss) const { assert(args().size() == 2u); diff --git a/test/rel.cpp b/test/rel.cpp index a7a41f6ff..e4a115867 100644 --- a/test/rel.cpp +++ b/test/rel.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -129,7 +130,7 @@ TEST_CASE("s11n") auto [x, y] = make_vars("x", "y"); - auto ex = eq(x, y); + auto ex = lt(x, y); { boost::archive::binary_oarchive oa(ss); @@ -145,7 +146,8 @@ TEST_CASE("s11n") ia >> ex; } - REQUIRE(ex == eq(x, y)); + REQUIRE(ex == lt(x, y)); + REQUIRE(std::get(ex.value()).extract()->get_op() == detail::rel_op::lt); } TEST_CASE("cfunc") From ac82da31c2e9db8a8ae5c41dd4ccb6cdfc3e970d Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 20:48:38 +0200 Subject: [PATCH 10/26] Back to 5.1.0. 
--- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f962ac2e..fab189deb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ if(NOT CMAKE_BUILD_TYPE) FORCE) endif() -project(heyoka VERSION 6.0.0 LANGUAGES CXX C) +project(heyoka VERSION 5.1.0 LANGUAGES CXX C) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/yacma") From 78703d4a4ae181618c932614297fe2e1ccf11db0 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 20:49:36 +0200 Subject: [PATCH 11/26] Add missing header. --- include/heyoka/math.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/heyoka/math.hpp b/include/heyoka/math.hpp index 9cf6e153c..ba5ba8c2f 100644 --- a/include/heyoka/math.hpp +++ b/include/heyoka/math.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include From 05c7bb62db87e9698e3bad1929bb27db4e02bb62 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Wed, 19 Jun 2024 21:24:53 +0200 Subject: [PATCH 12/26] More small test additions. --- test/rel.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/rel.cpp b/test/rel.cpp index e4a115867..59e7aca0f 100644 --- a/test/rel.cpp +++ b/test/rel.cpp @@ -73,6 +73,11 @@ TEST_CASE("basic") auto [x, y] = make_vars("x", "y"); REQUIRE(expression{func{detail::rel_impl{}}} == eq(1_dbl, 1_dbl)); + + REQUIRE(eq(x, y) == eq(x, y)); + REQUIRE(eq(x, y) != neq(x, y)); + REQUIRE(lte(x, y) != gte(x, y)); + REQUIRE(lte(x, y) == lte(x, y)); } TEST_CASE("stream") From 086ebcb23823dd7c949810836780b4be9d24e719 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Thu, 20 Jun 2024 08:36:20 +0200 Subject: [PATCH 13/26] Add llvm helpers to determine if a FP number is nonzero. 
--- include/heyoka/detail/llvm_helpers.hpp | 1 + include/heyoka/detail/real_helpers.hpp | 1 + src/detail/llvm_helpers.cpp | 26 ++++++++++++++++++++++++++ src/detail/real_helpers.cpp | 16 ++++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/include/heyoka/detail/llvm_helpers.hpp b/include/heyoka/detail/llvm_helpers.hpp index 980c53bbd..3aa72cd88 100644 --- a/include/heyoka/detail/llvm_helpers.hpp +++ b/include/heyoka/detail/llvm_helpers.hpp @@ -131,6 +131,7 @@ HEYOKA_DLL_PUBLIC llvm::Value *llvm_fcmp_olt(llvm_state &, llvm::Value *, llvm:: HEYOKA_DLL_PUBLIC llvm::Value *llvm_fcmp_ogt(llvm_state &, llvm::Value *, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_fcmp_oeq(llvm_state &, llvm::Value *, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_fcmp_one(llvm_state &, llvm::Value *, llvm::Value *); +HEYOKA_DLL_PUBLIC llvm::Value *llvm_fnz(llvm_state &, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_min(llvm_state &, llvm::Value *, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_max(llvm_state &, llvm::Value *, llvm::Value *); HEYOKA_DLL_PUBLIC llvm::Value *llvm_min_nan(llvm_state &, llvm::Value *, llvm::Value *); diff --git a/include/heyoka/detail/real_helpers.hpp b/include/heyoka/detail/real_helpers.hpp index 1fdbb77bf..c64a239de 100644 --- a/include/heyoka/detail/real_helpers.hpp +++ b/include/heyoka/detail/real_helpers.hpp @@ -41,6 +41,7 @@ llvm::Value *llvm_real_fcmp_olt(llvm_state &, llvm::Value *, llvm::Value *); llvm::Value *llvm_real_fcmp_ogt(llvm_state &, llvm::Value *, llvm::Value *); llvm::Value *llvm_real_fcmp_oeq(llvm_state &, llvm::Value *, llvm::Value *); llvm::Value *llvm_real_fcmp_one(llvm_state &, llvm::Value *, llvm::Value *); +llvm::Value *llvm_real_fnz(llvm_state &, llvm::Value *); llvm::Value *llvm_real_ui_to_fp(llvm_state &, llvm::Value *, llvm::Type *); llvm::Value *llvm_real_sgn(llvm_state &, llvm::Value *); diff --git a/src/detail/llvm_helpers.cpp b/src/detail/llvm_helpers.cpp index 3bab78bf9..33a7cecd7 100644 --- 
a/src/detail/llvm_helpers.cpp +++ b/src/detail/llvm_helpers.cpp @@ -2066,6 +2066,32 @@ llvm::Value *llvm_fcmp_one(llvm_state &s, llvm::Value *a, llvm::Value *b) } } +// Check if the input floating-point value(s) x is anything other +// than zero (including NaN). +llvm::Value *llvm_fnz(llvm_state &s, llvm::Value *x) +{ + // LCOV_EXCL_START + assert(x != nullptr); + // LCOV_EXCL_STOP + + auto &builder = s.builder(); + + auto *fp_t = x->getType(); + + if (fp_t->getScalarType()->isFloatingPointTy()) { + return builder.CreateFCmpUNE(x, llvm::ConstantFP::get(x->getType(), 0.)); +#if defined(HEYOKA_HAVE_REAL) + } else if (llvm_is_real(fp_t) != 0) { + return llvm_real_fnz(s, x); +#endif + } else { + // LCOV_EXCL_START + throw std::invalid_argument( + fmt::format("Unable to invoke llvm_fnz() on values of type '{}'", llvm_type_name(fp_t))); + // LCOV_EXCL_STOP + } +} + // Helper to compute sin and cos simultaneously. // NOTE: although there exists a SLEEF function for computing sin/cos // at the same time, we cannot use it directly because it returns a pair diff --git a/src/detail/real_helpers.cpp b/src/detail/real_helpers.cpp index eb05c45e0..05c2a9295 100644 --- a/src/detail/real_helpers.cpp +++ b/src/detail/real_helpers.cpp @@ -600,6 +600,22 @@ llvm::Value *llvm_real_fcmp_one(llvm_state &s, llvm::Value *a, llvm::Value *b) return s.builder().CreateICmpEQ(ret, llvm::ConstantInt::getNullValue(ret->getType())); } +llvm::Value *llvm_real_fnz(llvm_state &s, llvm::Value *x) +{ + // LCOV_EXCL_START + assert(x != nullptr); + // LCOV_EXCL_STOP + + auto &builder = s.builder(); + + // Check if x is zero. + auto *f = real_nary_cmp(s, x->getType(), "mpfr_zero_p", 1u); + auto *ret = builder.CreateCall(f, x); + + // NOTE: this creates a logical NOT. + return builder.CreateICmpEQ(ret, llvm::ConstantInt::getNullValue(ret->getType())); +} + // Convert the input unsigned integral value n to the real type fp_t. 
llvm::Value *llvm_real_ui_to_fp(llvm_state &s, llvm::Value *n, llvm::Type *fp_t) { From 04bdccb74acea3ccddffcc82be11fa22f12f7ceb Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Thu, 20 Jun 2024 08:42:37 +0200 Subject: [PATCH 14/26] Add initial implementation of logical AND. --- CMakeLists.txt | 1 + include/heyoka/math.hpp | 1 + include/heyoka/math/logical.hpp | 56 +++++++ src/math/logical.cpp | 100 ++++++++++++ test/CMakeLists.txt | 1 + test/logical.cpp | 261 ++++++++++++++++++++++++++++++++ 6 files changed, 420 insertions(+) create mode 100644 include/heyoka/math/logical.hpp create mode 100644 src/math/logical.cpp create mode 100644 test/logical.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index fab189deb..c7a2344d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -314,6 +314,7 @@ set(HEYOKA_SRC_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/math/constants.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/dfun.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/relational.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/math/logical.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/detail/string_conv.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/detail/logging_impl.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/step_callback.cpp" diff --git a/include/heyoka/math.hpp b/include/heyoka/math.hpp index ba5ba8c2f..35ae6ed4a 100644 --- a/include/heyoka/math.hpp +++ b/include/heyoka/math.hpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include diff --git a/include/heyoka/math/logical.hpp b/include/heyoka/math/logical.hpp new file mode 100644 index 000000000..87b6fa466 --- /dev/null +++ b/include/heyoka/math/logical.hpp @@ -0,0 +1,56 @@ +// Copyright 2020, 2021, 2022, 2023, 2024 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. 
If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef HEYOKA_MATH_LOGICAL_HPP +#define HEYOKA_MATH_LOGICAL_HPP + +#include +#include + +#include +#include +#include +#include +#include +#include + +HEYOKA_BEGIN_NAMESPACE + +namespace detail +{ + +class HEYOKA_DLL_PUBLIC logical_and_impl : public func_base +{ + friend class boost::serialization::access; + template + void serialize(Archive &ar, unsigned) + { + ar &boost::serialization::base_object(*this); + } + +public: + logical_and_impl(); + explicit logical_and_impl(std::vector); + + [[nodiscard]] std::vector gradient() const; + + [[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector &, llvm::Value *, + llvm::Value *, llvm::Value *, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const; +}; + +} // namespace detail + +HEYOKA_DLL_PUBLIC expression logical_and(std::vector); + +HEYOKA_END_NAMESPACE + +HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::logical_and_impl) + +#endif diff --git a/src/math/logical.cpp b/src/math/logical.cpp new file mode 100644 index 000000000..f40dab8cc --- /dev/null +++ b/src/math/logical.cpp @@ -0,0 +1,100 @@ +// Copyright 2020, 2021, 2022, 2023, 2024 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +HEYOKA_BEGIN_NAMESPACE + +namespace detail +{ + +logical_and_impl::logical_and_impl() : logical_and_impl({1_dbl}) {} + +logical_and_impl::logical_and_impl(std::vector args) : func_base("logical_and", std::move(args)) +{ + assert(!this->args().empty()); +} + +std::vector logical_and_impl::gradient() const +{ + return std::vector(this->args().size(), 0_dbl); +} + +namespace +{ + +llvm::Value *logical_and_eval_impl(llvm_state &s, const std::vector &args) +{ + assert(!args.empty()); + + auto &builder = s.builder(); + + auto *ret = llvm_fnz(s, args[0]); + + for (decltype(args.size()) i = 1; i < args.size(); ++i) { + auto *tmp = llvm_fnz(s, args[i]); + ret = builder.CreateLogicalAnd(ret, tmp); + } + + return llvm_ui_to_fp(s, ret, args[0]->getType()); +} + +} // namespace + +llvm::Value *logical_and_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t, const std::vector &eval_arr, + llvm::Value *par_ptr, llvm::Value *, llvm::Value *stride, + std::uint32_t batch_size, bool high_accuracy) const +{ + return llvm_eval_helper( + [&s](const std::vector &args, bool) { return logical_and_eval_impl(s, args); }, *this, s, fp_t, + eval_arr, par_ptr, stride, batch_size, high_accuracy); +} + +llvm::Function *logical_and_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t batch_size, + bool high_accuracy) const +{ + return llvm_c_eval_func_helper( + "logical_and", [&s](const std::vector &args, bool) { return logical_and_eval_impl(s, args); }, + *this, s, fp_t, batch_size, high_accuracy); +} + +} // namespace detail + +expression logical_and(std::vector args) +{ + if (args.empty()) { + return 1_dbl; + } + + if (args.size() == 1u) { + return std::move(args[0]); + } + + return expression{func{detail::logical_and_impl{std::move(args)}}}; +} + +HEYOKA_END_NAMESPACE + +// NOLINTNEXTLINE(cert-err58-cpp) 
+HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::logical_and_impl) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5b0be4090..0c6f13ceb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -121,6 +121,7 @@ ADD_HEYOKA_TESTCASE(div) ADD_HEYOKA_TESTCASE(sub) ADD_HEYOKA_TESTCASE(time) ADD_HEYOKA_TESTCASE(rel) +ADD_HEYOKA_TESTCASE(logical) ADD_HEYOKA_TESTCASE(wavy_ramp) ADD_HEYOKA_TESTCASE(dfloat_time) ADD_HEYOKA_TESTCASE(timestep_check) diff --git a/test/logical.cpp b/test/logical.cpp new file mode 100644 index 000000000..ecea3a4cf --- /dev/null +++ b/test/logical.cpp @@ -0,0 +1,261 @@ +// Copyright 2020, 2021, 2022, 2023, 2024 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "catch.hpp" +#include "test_utils.hpp" + +static std::mt19937 rng; + +using namespace heyoka; +using namespace heyoka_test; + +const auto fp_types = std::tuple{}; + +constexpr bool skip_batch_ld = +#if LLVM_VERSION_MAJOR <= 17 + std::numeric_limits::digits == 64 +#else + false +#endif + ; + +TEST_CASE("basic") +{ + auto x = make_vars("x"); + + REQUIRE(expression{func{detail::logical_and_impl{}}} == expression{func{detail::logical_and_impl{{1_dbl}}}}); + REQUIRE(logical_and({}) == 1_dbl); + REQUIRE(logical_and({x}) == x); +} + +TEST_CASE("stream") +{ + auto [x, y, z] = make_vars("x", "y", "z"); + + { + std::ostringstream oss; + oss << logical_and({x, y}); + REQUIRE(oss.str() == "logical_and(x, 
y)"); + } + + { + std::ostringstream oss; + oss << logical_and({x, y + z}); + REQUIRE(oss.str() == "logical_and(x, (y + z))"); + } +} + +TEST_CASE("diff") +{ + auto [x, y] = make_vars("x", "y"); + + REQUIRE(diff(logical_and({x, y}), "x") == 0_dbl); + REQUIRE(diff(logical_and({x * y, y - x}), "x") == 0_dbl); +} + +TEST_CASE("s11n logical_and") +{ + std::stringstream ss; + + auto [x, y] = make_vars("x", "y"); + + auto ex = logical_and({x, y}); + + { + boost::archive::binary_oarchive oa(ss); + + oa << ex; + } + + ex = 1_dbl; + + { + boost::archive::binary_iarchive ia(ss); + + ia >> ex; + } + + REQUIRE(ex == logical_and({x, y})); +} + +TEST_CASE("cfunc logical_and") +{ + auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) { + using fp_t = decltype(fp_x); + + auto [x, y] = make_vars("x", "y"); + + std::uniform_int_distribution idist(-3, 3); + + auto gen = [&idist]() { return static_cast(idist(rng)); }; + + std::vector outs, ins, pars, time; + + for (auto batch_size : {1u, 2u, 4u, 5u}) { + if (batch_size != 1u && std::is_same_v && skip_batch_ld) { + continue; + } + + outs.resize(batch_size * 2u); + ins.resize(batch_size * 2u); + pars.resize(batch_size); + time.resize(batch_size); + + std::generate(ins.begin(), ins.end(), gen); + std::generate(pars.begin(), pars.end(), gen); + std::generate(time.begin(), time.end(), gen); + + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", {logical_and({x, y}), logical_and({par[0], heyoka::time, y})}, {x, y}, + kw::batch_size = batch_size, kw::high_accuracy = high_accuracy, + kw::compact_mode = compact_mode); + + if (opt_level == 0u && compact_mode) { + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.logical_and.")); + } + + s.compile(); + + auto *cf_ptr + = reinterpret_cast(s.jit_lookup("cfunc")); + + cf_ptr(outs.data(), ins.data(), pars.data(), time.data()); + + for (auto i = 0u; i < batch_size; ++i) { + REQUIRE(outs[i] == (ins[i] && ins[i + batch_size])); + REQUIRE(outs[i + 
batch_size] == (pars[i] && time[i] && ins[i + batch_size])); + } + } + }; + + for (auto cm : {false, true}) { + for (auto f : {false, true}) { + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 0, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 1, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 2, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 3, f, cm); }); + } + } + + // A test specific for NaN handling. + auto [x, y] = make_vars("x", "y"); + llvm_state s; + std::vector outs, ins{1., std::numeric_limits::quiet_NaN()}; + outs.resize(1); + + add_cfunc(s, "cfunc", {logical_and({x, y})}, {x, y}); + + s.compile(); + + auto *cf_ptr + = reinterpret_cast(s.jit_lookup("cfunc")); + + cf_ptr(outs.data(), ins.data(), nullptr, nullptr); + + REQUIRE(outs[0] == 1.); +} + +#if defined(HEYOKA_HAVE_REAL) + +TEST_CASE("cfunc logical_and mp") +{ + auto [x, y] = make_vars("x", "y"); + + const auto prec = 237u; + + for (auto compact_mode : {false, true}) { + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", {logical_and({x, y}), logical_and({par[0], heyoka::time, y})}, {x, y}, + kw::compact_mode = compact_mode, kw::prec = prec); + + s.compile(); + + auto *cf_ptr + = reinterpret_cast( + s.jit_lookup("cfunc")); + + const std::vector ins{mppp::real{".7", prec}, mppp::real{"-.1", prec}}; + const std::vector pars{mppp::real{"0", prec}}; + const std::vector time{mppp::real{".3", prec}}; + std::vector outs(8u, mppp::real{0, prec}); + + cf_ptr(outs.data(), ins.data(), pars.data(), time.data()); + + auto i = 0u; + auto batch_size = 1u; + REQUIRE(outs[i] == (ins[i] && ins[i + batch_size])); + REQUIRE(static_cast(outs[i])); + REQUIRE(outs[i + batch_size] == (pars[i] && time[i] && ins[i + batch_size])); + REQUIRE(!static_cast(outs[i + batch_size])); + } + } + + // A test specific for NaN handling. 
+ llvm_state s; + std::vector outs, ins{mppp::real{1., prec}, mppp::real{std::numeric_limits::quiet_NaN(), prec}}; + outs.resize(1, mppp::real{0., prec}); + + add_cfunc(s, "cfunc", {logical_and({x, y})}, {x, y}, kw::prec = prec); + + s.compile(); + + auto *cf_ptr = reinterpret_cast( + s.jit_lookup("cfunc")); + + cf_ptr(outs.data(), ins.data(), nullptr, nullptr); + + REQUIRE(outs[0] == 1); +} + +#endif From 8747825a3fab1c628d7ec0f1e269ca28283cd525 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Thu, 20 Jun 2024 08:52:21 +0200 Subject: [PATCH 15/26] Add logical OR implementation. --- include/heyoka/math/logical.hpp | 23 +++++ src/math/logical.cpp | 66 ++++++++++++ test/logical.cpp | 174 ++++++++++++++++++++++++++++++++ 3 files changed, 263 insertions(+) diff --git a/include/heyoka/math/logical.hpp b/include/heyoka/math/logical.hpp index 87b6fa466..3356f14b5 100644 --- a/include/heyoka/math/logical.hpp +++ b/include/heyoka/math/logical.hpp @@ -45,12 +45,35 @@ class HEYOKA_DLL_PUBLIC logical_and_impl : public func_base [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const; }; +class HEYOKA_DLL_PUBLIC logical_or_impl : public func_base +{ + friend class boost::serialization::access; + template + void serialize(Archive &ar, unsigned) + { + ar &boost::serialization::base_object(*this); + } + +public: + logical_or_impl(); + explicit logical_or_impl(std::vector); + + [[nodiscard]] std::vector gradient() const; + + [[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector &, llvm::Value *, + llvm::Value *, llvm::Value *, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const; +}; + } // namespace detail HEYOKA_DLL_PUBLIC expression logical_and(std::vector); +HEYOKA_DLL_PUBLIC expression logical_or(std::vector); HEYOKA_END_NAMESPACE HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::logical_and_impl) 
+HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::logical_or_impl) #endif diff --git a/src/math/logical.cpp b/src/math/logical.cpp index f40dab8cc..b304ba322 100644 --- a/src/math/logical.cpp +++ b/src/math/logical.cpp @@ -79,6 +79,56 @@ llvm::Function *logical_and_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp *this, s, fp_t, batch_size, high_accuracy); } +logical_or_impl::logical_or_impl() : logical_or_impl({1_dbl}) {} + +logical_or_impl::logical_or_impl(std::vector args) : func_base("logical_or", std::move(args)) +{ + assert(!this->args().empty()); +} + +std::vector logical_or_impl::gradient() const +{ + return std::vector(this->args().size(), 0_dbl); +} + +namespace +{ + +llvm::Value *logical_or_eval_impl(llvm_state &s, const std::vector &args) +{ + assert(!args.empty()); + + auto &builder = s.builder(); + + auto *ret = llvm_fnz(s, args[0]); + + for (decltype(args.size()) i = 1; i < args.size(); ++i) { + auto *tmp = llvm_fnz(s, args[i]); + ret = builder.CreateLogicalOr(ret, tmp); + } + + return llvm_ui_to_fp(s, ret, args[0]->getType()); +} + +} // namespace + +llvm::Value *logical_or_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t, const std::vector &eval_arr, + llvm::Value *par_ptr, llvm::Value *, llvm::Value *stride, + std::uint32_t batch_size, bool high_accuracy) const +{ + return llvm_eval_helper( + [&s](const std::vector &args, bool) { return logical_or_eval_impl(s, args); }, *this, s, fp_t, + eval_arr, par_ptr, stride, batch_size, high_accuracy); +} + +llvm::Function *logical_or_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t batch_size, + bool high_accuracy) const +{ + return llvm_c_eval_func_helper( + "logical_or", [&s](const std::vector &args, bool) { return logical_or_eval_impl(s, args); }, + *this, s, fp_t, batch_size, high_accuracy); +} + } // namespace detail expression logical_and(std::vector args) @@ -94,7 +144,23 @@ expression logical_and(std::vector args) return expression{func{detail::logical_and_impl{std::move(args)}}}; 
} +expression logical_or(std::vector args) +{ + if (args.empty()) { + return 0_dbl; + } + + if (args.size() == 1u) { + return std::move(args[0]); + } + + return expression{func{detail::logical_or_impl{std::move(args)}}}; +} + HEYOKA_END_NAMESPACE // NOLINTNEXTLINE(cert-err58-cpp) HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::logical_and_impl) + +// NOLINTNEXTLINE(cert-err58-cpp) +HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::logical_or_impl) diff --git a/test/logical.cpp b/test/logical.cpp index ecea3a4cf..e921edb37 100644 --- a/test/logical.cpp +++ b/test/logical.cpp @@ -74,6 +74,10 @@ TEST_CASE("basic") REQUIRE(expression{func{detail::logical_and_impl{}}} == expression{func{detail::logical_and_impl{{1_dbl}}}}); REQUIRE(logical_and({}) == 1_dbl); REQUIRE(logical_and({x}) == x); + + REQUIRE(expression{func{detail::logical_or_impl{}}} == expression{func{detail::logical_or_impl{{1_dbl}}}}); + REQUIRE(logical_or({}) == 0_dbl); + REQUIRE(logical_or({x}) == x); } TEST_CASE("stream") @@ -91,6 +95,18 @@ TEST_CASE("stream") oss << logical_and({x, y + z}); REQUIRE(oss.str() == "logical_and(x, (y + z))"); } + + { + std::ostringstream oss; + oss << logical_or({x, y}); + REQUIRE(oss.str() == "logical_or(x, y)"); + } + + { + std::ostringstream oss; + oss << logical_or({x, y + z}); + REQUIRE(oss.str() == "logical_or(x, (y + z))"); + } } TEST_CASE("diff") @@ -99,6 +115,9 @@ TEST_CASE("diff") REQUIRE(diff(logical_and({x, y}), "x") == 0_dbl); REQUIRE(diff(logical_and({x * y, y - x}), "x") == 0_dbl); + + REQUIRE(diff(logical_or({x, y}), "x") == 0_dbl); + REQUIRE(diff(logical_or({x * y, y - x}), "x") == 0_dbl); } TEST_CASE("s11n logical_and") @@ -126,6 +145,31 @@ TEST_CASE("s11n logical_and") REQUIRE(ex == logical_and({x, y})); } +TEST_CASE("s11n logical_or") +{ + std::stringstream ss; + + auto [x, y] = make_vars("x", "y"); + + auto ex = logical_or({x, y}); + + { + boost::archive::binary_oarchive oa(ss); + + oa << ex; + } + + ex = 1_dbl; + + { + 
boost::archive::binary_iarchive ia(ss); + + ia >> ex; + } + + REQUIRE(ex == logical_or({x, y})); +} + TEST_CASE("cfunc logical_and") { auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) { @@ -204,6 +248,84 @@ TEST_CASE("cfunc logical_and") REQUIRE(outs[0] == 1.); } +TEST_CASE("cfunc logical_or") +{ + auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) { + using fp_t = decltype(fp_x); + + auto [x, y] = make_vars("x", "y"); + + std::uniform_int_distribution idist(-3, 3); + + auto gen = [&idist]() { return static_cast(idist(rng)); }; + + std::vector outs, ins, pars, time; + + for (auto batch_size : {1u, 2u, 4u, 5u}) { + if (batch_size != 1u && std::is_same_v && skip_batch_ld) { + continue; + } + + outs.resize(batch_size * 2u); + ins.resize(batch_size * 2u); + pars.resize(batch_size); + time.resize(batch_size); + + std::generate(ins.begin(), ins.end(), gen); + std::generate(pars.begin(), pars.end(), gen); + std::generate(time.begin(), time.end(), gen); + + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", {logical_or({x, y}), logical_or({par[0], heyoka::time, y})}, {x, y}, + kw::batch_size = batch_size, kw::high_accuracy = high_accuracy, + kw::compact_mode = compact_mode); + + if (opt_level == 0u && compact_mode) { + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.logical_or.")); + } + + s.compile(); + + auto *cf_ptr + = reinterpret_cast(s.jit_lookup("cfunc")); + + cf_ptr(outs.data(), ins.data(), pars.data(), time.data()); + + for (auto i = 0u; i < batch_size; ++i) { + REQUIRE(outs[i] == (ins[i] || ins[i + batch_size])); + REQUIRE(outs[i + batch_size] == (pars[i] || time[i] || ins[i + batch_size])); + } + } + }; + + for (auto cm : {false, true}) { + for (auto f : {false, true}) { + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 0, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 1, f, cm); }); + tuple_for_each(fp_types, [&tester, f, 
cm](auto x) { tester(x, 2, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 3, f, cm); }); + } + } + + // A test specific for NaN handling. + auto [x, y] = make_vars("x", "y"); + llvm_state s; + std::vector outs, ins{0., std::numeric_limits::quiet_NaN()}; + outs.resize(1); + + add_cfunc(s, "cfunc", {logical_or({x, y})}, {x, y}); + + s.compile(); + + auto *cf_ptr + = reinterpret_cast(s.jit_lookup("cfunc")); + + cf_ptr(outs.data(), ins.data(), nullptr, nullptr); + + REQUIRE(outs[0] == 1.); +} + #if defined(HEYOKA_HAVE_REAL) TEST_CASE("cfunc logical_and mp") @@ -258,4 +380,56 @@ TEST_CASE("cfunc logical_and mp") REQUIRE(outs[0] == 1); } +TEST_CASE("cfunc logical_or mp") +{ + auto [x, y] = make_vars("x", "y"); + + const auto prec = 237u; + + for (auto compact_mode : {false, true}) { + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", {logical_or({x, y}), logical_or({par[0], heyoka::time, y})}, {x, y}, + kw::compact_mode = compact_mode, kw::prec = prec); + + s.compile(); + + auto *cf_ptr + = reinterpret_cast( + s.jit_lookup("cfunc")); + + const std::vector ins{mppp::real{".7", prec}, mppp::real{"-.1", prec}}; + const std::vector pars{mppp::real{"0", prec}}; + const std::vector time{mppp::real{".3", prec}}; + std::vector outs(8u, mppp::real{0, prec}); + + cf_ptr(outs.data(), ins.data(), pars.data(), time.data()); + + auto i = 0u; + auto batch_size = 1u; + REQUIRE(outs[i] == (ins[i] || ins[i + batch_size])); + REQUIRE(static_cast(outs[i])); + REQUIRE(outs[i + batch_size] == (pars[i] || time[i] || ins[i + batch_size])); + REQUIRE(static_cast(outs[i + batch_size])); + } + } + + // A test specific for NaN handling. 
+ llvm_state s; + std::vector outs, ins{mppp::real{0., prec}, mppp::real{std::numeric_limits::quiet_NaN(), prec}}; + outs.resize(1, mppp::real{0., prec}); + + add_cfunc(s, "cfunc", {logical_or({x, y})}, {x, y}, kw::prec = prec); + + s.compile(); + + auto *cf_ptr = reinterpret_cast( + s.jit_lookup("cfunc")); + + cf_ptr(outs.data(), ins.data(), nullptr, nullptr); + + REQUIRE(outs[0] == 1); +} + #endif From 92c57d16eef1315427fd36254c92f8bc36c3618a Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Thu, 20 Jun 2024 09:00:12 +0200 Subject: [PATCH 16/26] Reduce repetition. --- src/math/logical.cpp | 51 +++++++++++++++++--------------------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/src/math/logical.cpp b/src/math/logical.cpp index b304ba322..35153da2d 100644 --- a/src/math/logical.cpp +++ b/src/math/logical.cpp @@ -44,7 +44,11 @@ std::vector logical_and_impl::gradient() const namespace { -llvm::Value *logical_and_eval_impl(llvm_state &s, const std::vector &args) +// NOTE: perhaps it could be worth it to implement this via a pairwise reduction. We cannot +// do this right now because we need to transform args via llvm_fnz() before performing the +// reduction and pairwise_reduce() does not support this pre-transform. 
+template +llvm::Value *logical_andor_eval_impl(llvm_state &s, const std::vector &args) { assert(!args.empty()); @@ -54,7 +58,11 @@ llvm::Value *logical_and_eval_impl(llvm_state &s, const std::vectorgetType()); @@ -67,16 +75,17 @@ llvm::Value *logical_and_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t, const std::uint32_t batch_size, bool high_accuracy) const { return llvm_eval_helper( - [&s](const std::vector &args, bool) { return logical_and_eval_impl(s, args); }, *this, s, fp_t, - eval_arr, par_ptr, stride, batch_size, high_accuracy); + [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args); }, *this, s, + fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy); } llvm::Function *logical_and_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t batch_size, bool high_accuracy) const { return llvm_c_eval_func_helper( - "logical_and", [&s](const std::vector &args, bool) { return logical_and_eval_impl(s, args); }, - *this, s, fp_t, batch_size, high_accuracy); + "logical_and", + [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args); }, *this, s, + fp_t, batch_size, high_accuracy); } logical_or_impl::logical_or_impl() : logical_or_impl({1_dbl}) {} @@ -91,42 +100,22 @@ std::vector logical_or_impl::gradient() const return std::vector(this->args().size(), 0_dbl); } -namespace -{ - -llvm::Value *logical_or_eval_impl(llvm_state &s, const std::vector &args) -{ - assert(!args.empty()); - - auto &builder = s.builder(); - - auto *ret = llvm_fnz(s, args[0]); - - for (decltype(args.size()) i = 1; i < args.size(); ++i) { - auto *tmp = llvm_fnz(s, args[i]); - ret = builder.CreateLogicalOr(ret, tmp); - } - - return llvm_ui_to_fp(s, ret, args[0]->getType()); -} - -} // namespace - llvm::Value *logical_or_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t, const std::vector &eval_arr, llvm::Value *par_ptr, llvm::Value *, llvm::Value *stride, std::uint32_t batch_size, bool high_accuracy) const { return llvm_eval_helper( 
- [&s](const std::vector &args, bool) { return logical_or_eval_impl(s, args); }, *this, s, fp_t, - eval_arr, par_ptr, stride, batch_size, high_accuracy); + [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args); }, *this, + s, fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy); } llvm::Function *logical_or_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t batch_size, bool high_accuracy) const { return llvm_c_eval_func_helper( - "logical_or", [&s](const std::vector &args, bool) { return logical_or_eval_impl(s, args); }, - *this, s, fp_t, batch_size, high_accuracy); + "logical_or", + [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args); }, *this, + s, fp_t, batch_size, high_accuracy); } } // namespace detail From 71e181c08f3944ab8eb06eac08aa6ffa4201d85f Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Thu, 20 Jun 2024 11:17:01 +0200 Subject: [PATCH 17/26] Initial implementation/testing of select(). --- CMakeLists.txt | 1 + include/heyoka/math.hpp | 1 + include/heyoka/math/select.hpp | 56 +++++++++++++++++++++++ src/math/select.cpp | 82 ++++++++++++++++++++++++++++++++++ test/CMakeLists.txt | 1 + 5 files changed, 141 insertions(+) create mode 100644 include/heyoka/math/select.hpp create mode 100644 src/math/select.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index c7a2344d8..bf64efe10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -315,6 +315,7 @@ set(HEYOKA_SRC_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/math/dfun.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/relational.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/logical.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/math/select.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/detail/string_conv.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/detail/logging_impl.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/step_callback.cpp" diff --git a/include/heyoka/math.hpp b/include/heyoka/math.hpp index 35ae6ed4a..648252be7 100644 --- a/include/heyoka/math.hpp +++ 
b/include/heyoka/math.hpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include diff --git a/include/heyoka/math/select.hpp b/include/heyoka/math/select.hpp new file mode 100644 index 000000000..f86f444fe --- /dev/null +++ b/include/heyoka/math/select.hpp @@ -0,0 +1,56 @@ +// Copyright 2020, 2021, 2022, 2023, 2024 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef HEYOKA_MATH_SELECT_HPP +#define HEYOKA_MATH_SELECT_HPP + +#include +#include + +#include +#include +#include +#include +#include +#include + +HEYOKA_BEGIN_NAMESPACE + +namespace detail +{ + +class HEYOKA_DLL_PUBLIC select_impl : public func_base +{ + friend class boost::serialization::access; + template + void serialize(Archive &ar, unsigned) + { + ar &boost::serialization::base_object(*this); + } + +public: + select_impl(); + explicit select_impl(expression, expression, expression); + + [[nodiscard]] std::vector gradient() const; + + [[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector &, llvm::Value *, + llvm::Value *, llvm::Value *, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const; +}; + +} // namespace detail + +HEYOKA_DLL_PUBLIC expression select(expression, expression, expression); + +HEYOKA_END_NAMESPACE + +HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::select_impl) + +#endif diff --git a/src/math/select.cpp b/src/math/select.cpp new file mode 100644 index 000000000..4c1a1a7a4 --- /dev/null +++ b/src/math/select.cpp @@ -0,0 +1,82 @@ +// Copyright 2020, 2021, 2022, 2023, 2024 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is 
part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +HEYOKA_BEGIN_NAMESPACE + +namespace detail +{ + +select_impl::select_impl() : select_impl(0_dbl, 0_dbl, 0_dbl) {} + +select_impl::select_impl(expression cond, expression t, expression f) + : func_base("select", {std::move(cond), std::move(t), std::move(f)}) +{ +} + +std::vector select_impl::gradient() const +{ + return {0_dbl, select(args()[0], 1_dbl, 0_dbl), select(args()[0], 0_dbl, 1_dbl)}; +} + +namespace +{ + +llvm::Value *select_eval_impl(llvm_state &s, const std::vector &args) +{ + assert(args.size() == 3u); + + return s.builder().CreateSelect(llvm_fnz(s, args[0]), args[1], args[2]); +} + +} // namespace + +llvm::Value *select_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t, const std::vector &eval_arr, + llvm::Value *par_ptr, llvm::Value *, llvm::Value *stride, std::uint32_t batch_size, + bool high_accuracy) const +{ + return llvm_eval_helper([&s](const std::vector &args, bool) { return select_eval_impl(s, args); }, + *this, s, fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy); +} + +llvm::Function *select_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t batch_size, + bool high_accuracy) const +{ + return llvm_c_eval_func_helper( + "select", [&s](const std::vector &args, bool) { return select_eval_impl(s, args); }, *this, s, + fp_t, batch_size, high_accuracy); +} + +} // namespace detail + +expression select(expression cond, expression t, expression f) +{ + return expression{func{detail::select_impl{std::move(cond), std::move(t), std::move(f)}}}; +} + +HEYOKA_END_NAMESPACE + +// NOLINTNEXTLINE(cert-err58-cpp) 
+HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::select_impl) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0c6f13ceb..0badf4878 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -122,6 +122,7 @@ ADD_HEYOKA_TESTCASE(sub) ADD_HEYOKA_TESTCASE(time) ADD_HEYOKA_TESTCASE(rel) ADD_HEYOKA_TESTCASE(logical) +ADD_HEYOKA_TESTCASE(select) ADD_HEYOKA_TESTCASE(wavy_ramp) ADD_HEYOKA_TESTCASE(dfloat_time) ADD_HEYOKA_TESTCASE(timestep_check) From ad045aa31396143ef8f6e5f0d4c4cbbc9ad9dc4b Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Thu, 20 Jun 2024 11:18:22 +0200 Subject: [PATCH 18/26] Missing file. --- test/select.cpp | 235 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 test/select.cpp diff --git a/test/select.cpp b/test/select.cpp new file mode 100644 index 000000000..3e6cd921c --- /dev/null +++ b/test/select.cpp @@ -0,0 +1,235 @@ +// Copyright 2020, 2021, 2022, 2023, 2024 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "catch.hpp" +#include "test_utils.hpp" + +static std::mt19937 rng; + +using namespace heyoka; +using namespace heyoka_test; + +const auto fp_types = std::tuple{}; + +constexpr bool skip_batch_ld = +#if LLVM_VERSION_MAJOR <= 17 + std::numeric_limits::digits == 64 +#else + false +#endif + ; + +TEST_CASE("basic") +{ + auto [x, y, z] = make_vars("x", "y", "z"); + + REQUIRE(expression{func{detail::select_impl{}}} == expression{func{detail::select_impl{0_dbl, 0_dbl, 0_dbl}}}); +} + +TEST_CASE("stream") +{ + auto [x, y, z] = make_vars("x", "y", "z"); + + { + std::ostringstream oss; + oss << select(x, y, z); + REQUIRE(oss.str() == "select(x, y, z)"); + } + + { + std::ostringstream oss; + oss << select(x, y + z, y - z); + REQUIRE(oss.str() == "select(x, (y + z), (y - z))"); + } +} + +TEST_CASE("diff") +{ + auto [x, y, z] = make_vars("x", "y", "z"); + + REQUIRE(diff(select(x, y * z, y / z), x) == 0_dbl); + REQUIRE(diff(select(x, y * z, y / z), y) == ((select(x, 1_dbl, 0_dbl) * z) + (select(x, 0_dbl, 1_dbl) / z))); +} + +TEST_CASE("s11n") +{ + std::stringstream ss; + + auto [x, y, z] = make_vars("x", "y", "z"); + + auto ex = select(x, y * z, y / z); + + { + boost::archive::binary_oarchive oa(ss); + + oa << ex; + } + + ex = 1_dbl; + + { + boost::archive::binary_iarchive ia(ss); + + ia >> ex; + } + + REQUIRE(ex == select(x, y * z, y / z)); +} + +TEST_CASE("cfunc") +{ + auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) { + using fp_t = decltype(fp_x); + + auto [x, y] = make_vars("x", "y"); + + std::uniform_real_distribution rdist(-1., 1.); + + auto gen = [&rdist]() { return static_cast(rdist(rng)); }; + + std::vector outs, 
ins, pars, time; + + for (auto batch_size : {1u, 2u, 4u, 5u}) { + if (batch_size != 1u && std::is_same_v && skip_batch_ld) { + continue; + } + + outs.resize(batch_size * 2u); + ins.resize(batch_size * 2u); + pars.resize(batch_size); + time.resize(batch_size); + + std::generate(ins.begin(), ins.end(), gen); + std::generate(pars.begin(), pars.end(), gen); + std::generate(time.begin(), time.end(), gen); + + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", + {select(logical_and({gt(x, par[0]), lt(x, 2. * y)}), x * x, y * y), + select(logical_or({lte(x, 0_dbl), lt(y, heyoka::time)}), x / y, y / x)}, + {x, y}, kw::batch_size = batch_size, kw::high_accuracy = high_accuracy, + kw::compact_mode = compact_mode); + + if (opt_level == 0u && compact_mode) { + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.select.")); + } + + s.compile(); + + auto *cf_ptr + = reinterpret_cast(s.jit_lookup("cfunc")); + + cf_ptr(outs.data(), ins.data(), pars.data(), time.data()); + + for (auto i = 0u; i < batch_size; ++i) { + REQUIRE(outs[i] + == ((ins[i] > pars[i] && ins[i] < 2 * ins[i + batch_size]) + ? ins[i] * ins[i] + : ins[i + batch_size] * ins[i + batch_size])); + REQUIRE(outs[i + batch_size] + == ((ins[i] <= 0 || ins[i + batch_size] < time[i]) ? 
ins[i] / ins[i + batch_size] + : ins[i + batch_size] / ins[i])); + } + } + }; + + for (auto cm : {false, true}) { + for (auto f : {false, true}) { + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 0, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 1, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 2, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 3, f, cm); }); + } + } +} + +#if defined(HEYOKA_HAVE_REAL) + +TEST_CASE("cfunc mp") +{ + auto [x, y] = make_vars("x", "y"); + + const auto prec = 237u; + + for (auto compact_mode : {false, true}) { + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", + {select(logical_and({gt(x, par[0]), lt(x, 2. * y)}), x * x, y * y), + select(logical_or({lte(x, 0_dbl), lt(y, heyoka::time)}), x / y, y / x)}, + {x, y}, kw::compact_mode = compact_mode, kw::prec = prec); + + s.compile(); + + auto *cf_ptr + = reinterpret_cast( + s.jit_lookup("cfunc")); + + const std::vector ins{mppp::real{".7", prec}, mppp::real{"-.1", prec}}; + const std::vector pars{mppp::real{"0", prec}}; + const std::vector time{mppp::real{".3", prec}}; + std::vector outs(8u, mppp::real{0, prec}); + + cf_ptr(outs.data(), ins.data(), pars.data(), time.data()); + + auto i = 0u; + auto batch_size = 1u; + REQUIRE(outs[i] == mppp::real{"-.1", prec} * mppp::real{"-.1", prec}); + REQUIRE(outs[i + batch_size] == mppp::real{".7", prec} / mppp::real{"-.1", prec}); + } + } +} + +#endif From 564ea8bf9b0be76210dfc42209e42627f5fde635 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Thu, 20 Jun 2024 11:19:32 +0200 Subject: [PATCH 19/26] Try to fix warning in test. 
--- test/logical.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/logical.cpp b/test/logical.cpp index e921edb37..6bb59c413 100644 --- a/test/logical.cpp +++ b/test/logical.cpp @@ -67,6 +67,13 @@ constexpr bool skip_batch_ld = #endif ; +#if defined(__GNUC__) || defined(__clang__) + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wfloat-conversion" + +#endif + TEST_CASE("basic") { auto x = make_vars("x"); @@ -433,3 +440,9 @@ TEST_CASE("cfunc logical_or mp") } #endif + +#if defined(__GNUC__) || defined(__clang__) + +#pragma GCC diagnostic pop + +#endif From 8e6ce67b6aa6ab5173f64505d8f0ce1ba0837799 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Thu, 20 Jun 2024 15:52:39 +0200 Subject: [PATCH 20/26] More tweaks to the implementation. --- src/math/logical.cpp | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/math/logical.cpp b/src/math/logical.cpp index 35153da2d..3f1e3eb65 100644 --- a/src/math/logical.cpp +++ b/src/math/logical.cpp @@ -6,8 +6,10 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include #include #include +#include #include #include @@ -44,27 +46,27 @@ std::vector logical_and_impl::gradient() const namespace { -// NOTE: perhaps it could be worth it to implement this via a pairwise reduction. We cannot -// do this right now because we need to transform args via llvm_fnz() before performing the -// reduction and pairwise_reduce() does not support this pre-transform. -template -llvm::Value *logical_andor_eval_impl(llvm_state &s, const std::vector &args) +llvm::Value *logical_andor_eval_impl(llvm_state &s, const std::vector &args, bool is_and) { assert(!args.empty()); auto &builder = s.builder(); - auto *ret = llvm_fnz(s, args[0]); + // Convert the values in args into booleans. 
+ std::vector tmp; + tmp.reserve(args.size()); + std::ranges::transform(args, std::back_inserter(tmp), [&s](auto *v) { return llvm_fnz(s, v); }); - for (decltype(args.size()) i = 1; i < args.size(); ++i) { - auto *tmp = llvm_fnz(s, args[i]); - if constexpr (IsAnd) { - ret = builder.CreateLogicalAnd(ret, tmp); + // Run a pairwise AND/OR reduction on the transformed values. + auto *ret = pairwise_reduce(tmp, [&builder, is_and](auto *a, auto *b) { + if (is_and) { + return builder.CreateLogicalAnd(a, b); } else { - ret = builder.CreateLogicalOr(ret, tmp); + return builder.CreateLogicalOr(a, b); } - } + }); + // Convert back to floating-point. return llvm_ui_to_fp(s, ret, args[0]->getType()); } @@ -75,7 +77,7 @@ llvm::Value *logical_and_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t, const std::uint32_t batch_size, bool high_accuracy) const { return llvm_eval_helper( - [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args); }, *this, s, + [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args, true); }, *this, s, fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy); } @@ -84,7 +86,7 @@ llvm::Function *logical_and_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp { return llvm_c_eval_func_helper( "logical_and", - [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args); }, *this, s, + [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args, true); }, *this, s, fp_t, batch_size, high_accuracy); } @@ -105,7 +107,7 @@ llvm::Value *logical_or_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t, const s std::uint32_t batch_size, bool high_accuracy) const { return llvm_eval_helper( - [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args); }, *this, + [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args, false); }, *this, s, fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy); } @@ -114,7 +116,7 @@ llvm::Function 
*logical_or_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_ { return llvm_c_eval_func_helper( "logical_or", - [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args); }, *this, + [&s](const std::vector &args, bool) { return logical_andor_eval_impl(s, args, false); }, *this, s, fp_t, batch_size, high_accuracy); } From 27af4d0420e76ceb18e825dfefb1ca7fc3df0f33 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Fri, 21 Jun 2024 08:45:51 +0200 Subject: [PATCH 21/26] Implement taylor diff support for the relational ops. --- include/heyoka/math/relational.hpp | 7 ++ src/math/relational.cpp | 158 +++++++++++++++++++++++++++++ test/rel.cpp | 105 +++++++++++++++++++ 3 files changed, 270 insertions(+) diff --git a/include/heyoka/math/relational.hpp b/include/heyoka/math/relational.hpp index 6e86b3886..eafb628bd 100644 --- a/include/heyoka/math/relational.hpp +++ b/include/heyoka/math/relational.hpp @@ -53,6 +53,13 @@ class HEYOKA_DLL_PUBLIC rel_impl : public func_base llvm::Value *, llvm::Value *, std::uint32_t, bool) const; [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector &, + const std::vector &, llvm::Value *, llvm::Value *, + std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, + bool) const; }; } // namespace detail diff --git a/src/math/relational.cpp b/src/math/relational.cpp index 0490aff69..a1c2486e9 100644 --- a/src/math/relational.cpp +++ b/src/math/relational.cpp @@ -7,25 +7,37 @@ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
#include +#include #include #include +#include #include #include +#include #include +#include +#include #include #include +#include +#include #include #include #include #include +#include #include #include #include #include +#include +#include #include +#include +#include HEYOKA_BEGIN_NAMESPACE @@ -180,6 +192,152 @@ llvm::Function *rel_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std: fp_t, batch_size, high_accuracy); } +llvm::Value *rel_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector &deps, + const std::vector &arr, llvm::Value *par_ptr, llvm::Value *, + std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, std::uint32_t batch_size, + bool) const +{ + assert(args().size() == 2u); + assert(deps.empty()); + + // NOTE: we need to do something only at differentiation order 0. + if (order == 0u) { + std::vector tmp; + tmp.reserve(2); + + for (const auto &cur_arg : args()) { + tmp.push_back(std::visit( + [&](const T &v) -> llvm::Value * { + if constexpr (std::same_as) { + // Variable. + return taylor_fetch_diff(arr, uname_to_index(v.name()), 0, n_uvars); + } else if constexpr (is_num_param_v) { + // Number/param. + return taylor_codegen_numparam(s, fp_t, v, par_ptr, batch_size); + } else { + // LCOV_EXCL_START + throw std::invalid_argument( + "An invalid argument type was encountered while trying to build the " + "Taylor derivative of a relational operation"); + // LCOV_EXCL_STOP + } + }, + cur_arg.value())); + } + + return rel_eval_impl(s, m_op, tmp); + } else { + return vector_splat(s.builder(), llvm_codegen(s, fp_t, number{0.}), batch_size); + } +} + +llvm::Function *rel_impl::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t n_uvars, + std::uint32_t batch_size, bool) const +{ + assert(args().size() == 2u); + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + // Fetch the vector floating-point type. 
+ auto *val_t = make_vector_type(fp_t, batch_size); + + // Build the vector of arguments needed to determine the function name. + std::vector> nm_args; + nm_args.reserve(static_cast(args().size())); + for (const auto &arg : args()) { + nm_args.push_back(std::visit( + [](const T &v) -> std::variant { + if constexpr (std::same_as) { + // LCOV_EXCL_START + assert(false); + throw; + // LCOV_EXCL_STOP + } else { + return v; + } + }, + arg.value())); + } + + // Fetch the function name and arguments. + const auto [fname, fargs] + = taylor_c_diff_func_name_args(context, fp_t, name_from_op(m_op), n_uvars, batch_size, nm_args); + + // Try to see if we already created the function. + auto *f = md.getFunction(fname); + + if (f != nullptr) { + return f; + } + + // The function was not created before, do it now. + + // Fetch the current insertion block. + auto *orig_bb = builder.GetInsertBlock(); + + // The return type is val_t. + auto *ft = llvm::FunctionType::get(val_t, fargs, false); + // Create the function + f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &md); + assert(f != nullptr); + + // Fetch the necessary function arguments. + auto *order = f->args().begin(); + auto *diff_arr = f->args().begin() + 2; + auto *par_ptr = f->args().begin() + 3; + auto *operands = f->args().begin() + 5; + + // Create a new basic block to start insertion into. + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + // Create the return value. + auto *retval = builder.CreateAlloca(val_t); + + llvm_if_then_else( + s, builder.CreateICmpEQ(order, builder.getInt32(0)), + [&]() { + // For order zero, evaluate the relational operation. 
+ std::vector vals; + vals.reserve(2); + + for (decltype(args().size()) i = 0; i < args().size(); ++i) { + vals.push_back(std::visit( + [&](const T &v) -> llvm::Value * { + if constexpr (std::same_as) { + return taylor_c_load_diff(s, val_t, diff_arr, n_uvars, order, operands + i); + } else if constexpr (is_num_param_v) { + return taylor_c_diff_numparam_codegen(s, fp_t, v, operands + i, par_ptr, batch_size); + } else { + // LCOV_EXCL_START + throw std::invalid_argument( + "An invalid argument type was encountered while trying to build the " + "Taylor derivative of a relational operation"); + // LCOV_EXCL_STOP + } + }, + args()[i].value())); + } + + builder.CreateStore(rel_eval_impl(s, m_op, vals), retval); + }, + [&]() { + // Otherwise, return zero. + builder.CreateStore(llvm_constantfp(s, val_t, 0.), retval); + }); + + builder.CreateRet(builder.CreateLoad(val_t, retval)); + + // Verify. + s.verify_function(f); + + // Restore the original insertion block. + builder.SetInsertPoint(orig_bb); + + return f; +} + } // namespace detail #define HEYOKA_MATH_REL_IMPL(op) \ diff --git a/test/rel.cpp b/test/rel.cpp index 59e7aca0f..b6b43ab17 100644 --- a/test/rel.cpp +++ b/test/rel.cpp @@ -38,8 +38,10 @@ #include #include #include +#include #include #include +#include #include "catch.hpp" #include "test_utils.hpp" @@ -274,3 +276,106 @@ TEST_CASE("cfunc_mp") } #endif + +TEST_CASE("taylor_adaptive") +{ + auto [x, v] = make_vars("x", "v"); + + for (auto opt_level : {0u, 3u}) { + for (auto cm : {false, true}) { + auto ta1 = taylor_adaptive{ + {prime(x) = v, prime(v) = -sin(x)}, {1.23, 0.}, kw::compact_mode = cm, kw::opt_level = opt_level}; + + auto ta2 = taylor_adaptive{{prime(x) = v, prime(v) = -(1. 
+ gt(x, 1.24_dbl)) * sin(x) + gt(x, 1.24_dbl)}, + {1.23, 0.}, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.rel_gt.var_num.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + + ta1 = taylor_adaptive{ + {prime(x) = v, prime(v) = -2. * sin(x)}, {1.23, 0.}, kw::compact_mode = cm, kw::opt_level = opt_level}; + ta2 = taylor_adaptive{{prime(x) = v, prime(v) = -(1. + lt(x, par[0])) * sin(x)}, + {1.23, 0.}, + kw::compact_mode = cm, + kw::opt_level = opt_level, + kw::pars = {1.24}}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.rel_lt.var_par.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + } + } +} + +TEST_CASE("taylor_adaptive_batch") +{ + auto [x, v] = make_vars("x", "v"); + + for (auto opt_level : {0u, 3u}) { + for (auto cm : {false, true}) { + auto ta1 = taylor_adaptive_batch{{prime(x) = v, prime(v) = -sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + + auto ta2 + = taylor_adaptive_batch{{prime(x) = v, prime(v) = -(1. 
+ gt(x, 1.24_dbl)) * sin(x) + gt(x, 1.24_dbl)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.rel_gt.var_num.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + REQUIRE(ta1.get_state()[2] == approximately(ta2.get_state()[2])); + REQUIRE(ta1.get_state()[3] == approximately(ta2.get_state()[3])); + + ta1 = taylor_adaptive_batch{{prime(x) = v, prime(v) = -2. * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + ta2 = taylor_adaptive_batch{{prime(x) = v, prime(v) = -(1. + lt(x, par[0])) * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level, + kw::pars = {1.24, 1.25}}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.rel_lt.var_par.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + REQUIRE(ta1.get_state()[2] == approximately(ta2.get_state()[2])); + REQUIRE(ta1.get_state()[3] == approximately(ta2.get_state()[3])); + } + } +} From ff9b08153bed9c53a21e31556505b9965215f302 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Sat, 22 Jun 2024 10:50:00 +0200 Subject: [PATCH 22/26] Taylor diff for the logical ops. 
--- include/heyoka/math/logical.hpp | 14 +++ src/math/logical.cpp | 214 ++++++++++++++++++++++++++++++++ test/logical.cpp | 53 ++++++++ 3 files changed, 281 insertions(+) diff --git a/include/heyoka/math/logical.hpp b/include/heyoka/math/logical.hpp index 3356f14b5..6e98b3687 100644 --- a/include/heyoka/math/logical.hpp +++ b/include/heyoka/math/logical.hpp @@ -43,6 +43,13 @@ class HEYOKA_DLL_PUBLIC logical_and_impl : public func_base llvm::Value *, llvm::Value *, std::uint32_t, bool) const; [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector &, + const std::vector &, llvm::Value *, llvm::Value *, + std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, + bool) const; }; class HEYOKA_DLL_PUBLIC logical_or_impl : public func_base @@ -64,6 +71,13 @@ class HEYOKA_DLL_PUBLIC logical_or_impl : public func_base llvm::Value *, llvm::Value *, std::uint32_t, bool) const; [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector &, + const std::vector &, llvm::Value *, llvm::Value *, + std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, + bool) const; }; } // namespace detail diff --git a/src/math/logical.cpp b/src/math/logical.cpp index 3f1e3eb65..883ce63af 100644 --- a/src/math/logical.cpp +++ b/src/math/logical.cpp @@ -8,23 +8,35 @@ #include #include +#include #include #include +#include #include +#include #include +#include +#include #include #include +#include +#include #include #include #include #include +#include #include #include #include 
#include +#include +#include #include +#include +#include HEYOKA_BEGIN_NAMESPACE @@ -90,6 +102,163 @@ llvm::Function *logical_and_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp fp_t, batch_size, high_accuracy); } +llvm::Value *logical_and_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector &deps, + const std::vector &arr, llvm::Value *par_ptr, llvm::Value *, + std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, + std::uint32_t batch_size, bool) const +{ + assert(!args().empty()); + assert(deps.empty()); + + // NOTE: we need to do something only at differentiation order 0. + if (order == 0u) { + std::vector tmp; + tmp.reserve(static_cast(args().size())); + + for (const auto &cur_arg : args()) { + tmp.push_back(std::visit( + [&](const T &v) -> llvm::Value * { + if constexpr (std::same_as) { + // Variable. + return taylor_fetch_diff(arr, uname_to_index(v.name()), 0, n_uvars); + } else if constexpr (is_num_param_v) { + // Number/param. + return taylor_codegen_numparam(s, fp_t, v, par_ptr, batch_size); + } else { + // LCOV_EXCL_START + throw std::invalid_argument( + "An invalid argument type was encountered while trying to build the " + "Taylor derivative of a logical_and()"); + // LCOV_EXCL_STOP + } + }, + cur_arg.value())); + } + + return logical_andor_eval_impl(s, tmp, true); + } else { + return vector_splat(s.builder(), llvm_codegen(s, fp_t, number{0.}), batch_size); + } +} + +namespace +{ + +llvm::Function *taylor_c_diff_func_logical_andor_impl(const func_base &fb, llvm_state &s, llvm::Type *fp_t, + std::uint32_t n_uvars, std::uint32_t batch_size, bool is_and) +{ + assert(!fb.args().empty()); + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + // Fetch the vector floating-point type. + auto *val_t = make_vector_type(fp_t, batch_size); + + // Build the vector of arguments needed to determine the function name. 
+ std::vector> nm_args; + nm_args.reserve(static_cast(fb.args().size())); + for (const auto &arg : fb.args()) { + nm_args.push_back(std::visit( + [](const T &v) -> std::variant { + if constexpr (std::same_as) { + // LCOV_EXCL_START + assert(false); + throw; + // LCOV_EXCL_STOP + } else { + return v; + } + }, + arg.value())); + } + + // Fetch the function name and arguments. + const auto [fname, fargs] = taylor_c_diff_func_name_args(context, fp_t, is_and ? "logical_and" : "logical_or", + n_uvars, batch_size, nm_args); + + // Try to see if we already created the function. + auto *f = md.getFunction(fname); + + if (f != nullptr) { + return f; + } + + // The function was not created before, do it now. + + // Fetch the current insertion block. + auto *orig_bb = builder.GetInsertBlock(); + + // The return type is val_t. + auto *ft = llvm::FunctionType::get(val_t, fargs, false); + // Create the function + f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &md); + assert(f != nullptr); + + // Fetch the necessary function arguments. + auto *order = f->args().begin(); + auto *diff_arr = f->args().begin() + 2; + auto *par_ptr = f->args().begin() + 3; + auto *operands = f->args().begin() + 5; + + // Create a new basic block to start insertion into. + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + // Create the return value. + auto *retval = builder.CreateAlloca(val_t); + + llvm_if_then_else( + s, builder.CreateICmpEQ(order, builder.getInt32(0)), + [&]() { + // For order zero, evaluate the logical operation. 
+ std::vector vals; + vals.reserve(2); + + for (decltype(fb.args().size()) i = 0; i < fb.args().size(); ++i) { + vals.push_back(std::visit( + [&](const T &v) -> llvm::Value * { + if constexpr (std::same_as) { + return taylor_c_load_diff(s, val_t, diff_arr, n_uvars, order, operands + i); + } else if constexpr (is_num_param_v) { + return taylor_c_diff_numparam_codegen(s, fp_t, v, operands + i, par_ptr, batch_size); + } else { + // LCOV_EXCL_START + throw std::invalid_argument( + "An invalid argument type was encountered while trying to build the " + "Taylor derivative of a logical operation"); + // LCOV_EXCL_STOP + } + }, + fb.args()[i].value())); + } + + builder.CreateStore(logical_andor_eval_impl(s, vals, is_and), retval); + }, + [&]() { + // Otherwise, return zero. + builder.CreateStore(llvm_constantfp(s, val_t, 0.), retval); + }); + + builder.CreateRet(builder.CreateLoad(val_t, retval)); + + // Verify. + s.verify_function(f); + + // Restore the original insertion block. + builder.SetInsertPoint(orig_bb); + + return f; +} + +} // namespace + +llvm::Function *logical_and_impl::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t n_uvars, + std::uint32_t batch_size, bool) const +{ + return taylor_c_diff_func_logical_andor_impl(*this, s, fp_t, n_uvars, batch_size, true); +} + logical_or_impl::logical_or_impl() : logical_or_impl({1_dbl}) {} logical_or_impl::logical_or_impl(std::vector args) : func_base("logical_or", std::move(args)) @@ -120,6 +289,51 @@ llvm::Function *logical_or_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_ s, fp_t, batch_size, high_accuracy); } +llvm::Value *logical_or_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector &deps, + const std::vector &arr, llvm::Value *par_ptr, llvm::Value *, + std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, + std::uint32_t batch_size, bool) const +{ + assert(!args().empty()); + assert(deps.empty()); + + // NOTE: we need to do something only at differentiation order 0. 
+ if (order == 0u) { + std::vector tmp; + tmp.reserve(static_cast(args().size())); + + for (const auto &cur_arg : args()) { + tmp.push_back(std::visit( + [&](const T &v) -> llvm::Value * { + if constexpr (std::same_as) { + // Variable. + return taylor_fetch_diff(arr, uname_to_index(v.name()), 0, n_uvars); + } else if constexpr (is_num_param_v) { + // Number/param. + return taylor_codegen_numparam(s, fp_t, v, par_ptr, batch_size); + } else { + // LCOV_EXCL_START + throw std::invalid_argument( + "An invalid argument type was encountered while trying to build the " + "Taylor derivative of a logical_or()"); + // LCOV_EXCL_STOP + } + }, + cur_arg.value())); + } + + return logical_andor_eval_impl(s, tmp, false); + } else { + return vector_splat(s.builder(), llvm_codegen(s, fp_t, number{0.}), batch_size); + } +} + +llvm::Function *logical_or_impl::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t n_uvars, + std::uint32_t batch_size, bool) const +{ + return taylor_c_diff_func_logical_andor_impl(*this, s, fp_t, n_uvars, batch_size, false); +} + } // namespace detail expression logical_and(std::vector args) diff --git a/test/logical.cpp b/test/logical.cpp index 6bb59c413..cb436fabb 100644 --- a/test/logical.cpp +++ b/test/logical.cpp @@ -37,8 +37,11 @@ #include #include #include +#include +#include #include #include +#include #include "catch.hpp" #include "test_utils.hpp" @@ -441,6 +444,56 @@ TEST_CASE("cfunc logical_or mp") #endif +TEST_CASE("taylor_adaptive") +{ + auto [x, v] = make_vars("x", "v"); + + for (auto opt_level : {0u, 3u}) { + for (auto cm : {false, true}) { + auto ta1 = taylor_adaptive{ + {prime(x) = v, prime(v) = -sin(x)}, {1.23, 0.}, kw::compact_mode = cm, kw::opt_level = opt_level}; + + auto ta2 = taylor_adaptive{ + {prime(x) = v, prime(v) = -(1. 
+ logical_and({lt(x, 1.24_dbl), gt(x, 1.24_dbl), 0_dbl})) * sin(x)}, + {1.23, 0.}, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + + if (opt_level == 0u && cm) { + REQUIRE( + boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.logical_and.var_var_num.")); + REQUIRE(!boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.logical_or")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + + ta1 = taylor_adaptive{ + {prime(x) = v, prime(v) = -2. * sin(x)}, {1.23, 0.}, kw::compact_mode = cm, kw::opt_level = opt_level}; + ta2 = taylor_adaptive{ + {prime(x) = v, prime(v) = -(1. + logical_or({lt(x, par[0]), gte(x, par[0]), par[0]})) * sin(x)}, + {1.23, 0.}, + kw::compact_mode = cm, + kw::opt_level = opt_level, + kw::pars = {1.24}}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.logical_or.var_var_par.")); + REQUIRE(!boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.logical_and")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + } + } +} + #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop From c04675c87c57bbc72532574cb38644b782c8ae12 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Sat, 22 Jun 2024 10:54:14 +0200 Subject: [PATCH 23/26] Missing batch testing. 
--- test/logical.cpp | 62 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/test/logical.cpp b/test/logical.cpp index cb436fabb..bf7a39ce9 100644 --- a/test/logical.cpp +++ b/test/logical.cpp @@ -494,6 +494,68 @@ TEST_CASE("taylor_adaptive") } } +TEST_CASE("taylor_adaptive_batch") +{ + auto [x, v] = make_vars("x", "v"); + + for (auto opt_level : {0u, 3u}) { + for (auto cm : {false, true}) { + auto ta1 = taylor_adaptive_batch{{prime(x) = v, prime(v) = -sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + + auto ta2 = taylor_adaptive_batch{ + {prime(x) = v, prime(v) = -(1. + logical_and({lt(x, 1.24_dbl), gt(x, 1.24_dbl), 0_dbl})) * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + + if (opt_level == 0u && cm) { + REQUIRE( + boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.logical_and.var_var_num.")); + REQUIRE(!boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.logical_or")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + REQUIRE(ta1.get_state()[2] == approximately(ta2.get_state()[2])); + REQUIRE(ta1.get_state()[3] == approximately(ta2.get_state()[3])); + + ta1 = taylor_adaptive_batch{{prime(x) = v, prime(v) = -2. * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + ta2 = taylor_adaptive_batch{ + {prime(x) = v, prime(v) = -(1. 
+ logical_or({lt(x, par[0]), gte(x, par[0]), par[0]})) * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level, + kw::pars = {1.24, 1.25}}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.logical_or.var_var_par.")); + REQUIRE(!boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.logical_and")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + REQUIRE(ta1.get_state()[2] == approximately(ta2.get_state()[2])); + REQUIRE(ta1.get_state()[3] == approximately(ta2.get_state()[3])); + } + } +} + #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop From 17495cd9b6c39f9346dedc73bdbdc9b581378bc8 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Sat, 22 Jun 2024 14:54:39 +0200 Subject: [PATCH 24/26] Implement and test taylor diff for select(). 
--- include/heyoka/math/select.hpp | 7 ++ src/math/select.cpp | 192 +++++++++++++++++++++++++++++++++ test/select.cpp | 149 +++++++++++++++++++++++++ 3 files changed, 348 insertions(+) diff --git a/include/heyoka/math/select.hpp b/include/heyoka/math/select.hpp index f86f444fe..7f6b10eea 100644 --- a/include/heyoka/math/select.hpp +++ b/include/heyoka/math/select.hpp @@ -43,6 +43,13 @@ class HEYOKA_DLL_PUBLIC select_impl : public func_base llvm::Value *, llvm::Value *, std::uint32_t, bool) const; [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector &, + const std::vector &, llvm::Value *, llvm::Value *, + std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const; + + [[nodiscard]] llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, + bool) const; }; } // namespace detail diff --git a/src/math/select.cpp b/src/math/select.cpp index 4c1a1a7a4..338a2d580 100644 --- a/src/math/select.cpp +++ b/src/math/select.cpp @@ -7,22 +7,34 @@ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
#include +#include #include +#include #include +#include #include +#include +#include #include #include +#include +#include #include #include #include #include +#include #include #include #include #include +#include +#include #include +#include +#include HEYOKA_BEGIN_NAMESPACE @@ -69,6 +81,186 @@ llvm::Function *select_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, s fp_t, batch_size, high_accuracy); } +llvm::Value *select_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector &deps, + const std::vector &arr, llvm::Value *par_ptr, llvm::Value *, + std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, + std::uint32_t batch_size, bool) const +{ + assert(args().size() == 3u); + assert(deps.empty()); + + std::vector tmp; + tmp.reserve(3); + + // For the condition, we always use the order-0 values. + tmp.push_back(std::visit( + [&](const T &v) -> llvm::Value * { + if constexpr (std::same_as) { + // Variable. + return taylor_fetch_diff(arr, uname_to_index(v.name()), 0, n_uvars); + } else if constexpr (is_num_param_v) { + // Number/param. + return taylor_codegen_numparam(s, fp_t, v, par_ptr, batch_size); + } else { + // LCOV_EXCL_START + throw std::invalid_argument("An invalid argument type was encountered while trying to build the " + "Taylor derivative of select()"); + // LCOV_EXCL_STOP + } + }, + args()[0].value())); + + // For the branches, we use the order-n derivatives. + for (decltype(args().size()) i = 1; i < args().size(); ++i) { + tmp.push_back(std::visit( + [&](const T &v) -> llvm::Value * { + if constexpr (std::same_as) { + // Variable. + return taylor_fetch_diff(arr, uname_to_index(v.name()), order, n_uvars); + } else if constexpr (is_num_param_v) { + // Number/param. 
+ if (order == 0u) { + return taylor_codegen_numparam(s, fp_t, v, par_ptr, batch_size); + } else { + return vector_splat(s.builder(), llvm_codegen(s, fp_t, number{0.}), batch_size); + } + } else { + // LCOV_EXCL_START + throw std::invalid_argument("An invalid argument type was encountered while trying to build the " + "Taylor derivative of select()"); + // LCOV_EXCL_STOP + } + }, + args()[i].value())); + } + + return select_eval_impl(s, tmp); +} + +llvm::Function *select_impl::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t n_uvars, + std::uint32_t batch_size, bool) const +{ + assert(args().size() == 3u); + + auto &md = s.module(); + auto &builder = s.builder(); + auto &context = s.context(); + + // Fetch the vector floating-point type. + auto *val_t = make_vector_type(fp_t, batch_size); + + // Build the vector of arguments needed to determine the function name. + std::vector> nm_args; + nm_args.reserve(static_cast(args().size())); + for (const auto &arg : args()) { + nm_args.push_back(std::visit( + [](const T &v) -> std::variant { + if constexpr (std::same_as) { + // LCOV_EXCL_START + assert(false); + throw; + // LCOV_EXCL_STOP + } else { + return v; + } + }, + arg.value())); + } + + // Fetch the function name and arguments. + const auto [fname, fargs] = taylor_c_diff_func_name_args(context, fp_t, "select", n_uvars, batch_size, nm_args); + + // Try to see if we already created the function. + auto *f = md.getFunction(fname); + + if (f != nullptr) { + return f; + } + + // The function was not created before, do it now. + + // Fetch the current insertion block. + auto *orig_bb = builder.GetInsertBlock(); + + // The return type is val_t. + auto *ft = llvm::FunctionType::get(val_t, fargs, false); + // Create the function + f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &md); + assert(f != nullptr); + + // Fetch the necessary function arguments. 
+ auto *order = f->args().begin(); + auto *diff_arr = f->args().begin() + 2; + auto *par_ptr = f->args().begin() + 3; + auto *operands = f->args().begin() + 5; + + // Create a new basic block to start insertion into. + builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f)); + + std::vector vals; + vals.reserve(3); + + // For the condition, we always use the order-0 values. + vals.push_back(std::visit( + [&](const T &v) -> llvm::Value * { + if constexpr (std::same_as) { + return taylor_c_load_diff(s, val_t, diff_arr, n_uvars, builder.getInt32(0), operands); + } else if constexpr (is_num_param_v) { + return taylor_c_diff_numparam_codegen(s, fp_t, v, operands, par_ptr, batch_size); + } else { + // LCOV_EXCL_START + throw std::invalid_argument("An invalid argument type was encountered while trying to build the " + "Taylor derivative of select()"); + // LCOV_EXCL_STOP + } + }, + args()[0].value())); + + // For the branches, we use the order-n derivatives. + for (decltype(args().size()) i = 1; i < args().size(); ++i) { + vals.push_back(std::visit( + [&](const T &v) -> llvm::Value * { + if constexpr (std::same_as) { + return taylor_c_load_diff(s, val_t, diff_arr, n_uvars, order, operands + i); + } else if constexpr (is_num_param_v) { + // Create the return value. + auto *retval = builder.CreateAlloca(val_t); + + llvm_if_then_else( + s, builder.CreateICmpEQ(order, builder.getInt32(0)), + [&]() { + // If the order is zero, run the codegen. + builder.CreateStore( + taylor_c_diff_numparam_codegen(s, fp_t, v, operands + i, par_ptr, batch_size), retval); + }, + [&]() { + // Otherwise, return zero. 
+ builder.CreateStore(vector_splat(builder, llvm_codegen(s, fp_t, number{0.}), batch_size), + retval); + }); + + return builder.CreateLoad(val_t, retval); + } else { + // LCOV_EXCL_START + throw std::invalid_argument("An invalid argument type was encountered while trying to build the " + "Taylor derivative of select()"); + // LCOV_EXCL_STOP + } + }, + args()[i].value())); + } + + builder.CreateRet(select_eval_impl(s, vals)); + + // Verify. + s.verify_function(f); + + // Restore the original insertion block. + builder.SetInsertPoint(orig_bb); + + return f; +} + } // namespace detail expression select(expression cond, expression t, expression f) diff --git a/test/select.cpp b/test/select.cpp index 3e6cd921c..9390bd6b7 100644 --- a/test/select.cpp +++ b/test/select.cpp @@ -39,8 +39,10 @@ #include #include #include +#include #include #include +#include #include "catch.hpp" #include "test_utils.hpp" @@ -233,3 +235,150 @@ TEST_CASE("cfunc mp") } #endif + +TEST_CASE("taylor_adaptive") +{ + auto [x, v] = make_vars("x", "v"); + + for (auto opt_level : {0u, 3u}) { + for (auto cm : {false, true}) { + auto ta1 = taylor_adaptive{ + {prime(x) = v, prime(v) = -sin(x)}, {1.23, 0.}, kw::compact_mode = cm, kw::opt_level = opt_level}; + + auto ta2 = taylor_adaptive{ + {prime(x) = v, prime(v) = -(1. + select(gt(x, 1.24_dbl), 1. - par[0], 0_dbl)) * sin(x)}, + {1.23, 0.}, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.select.var_var_num.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + + ta1 = taylor_adaptive{ + {prime(x) = v, prime(v) = -2. * sin(x)}, {1.23, 0.}, kw::compact_mode = cm, kw::opt_level = opt_level}; + ta2 = taylor_adaptive{{prime(x) = v, prime(v) = -(1. 
+ select(lt(x, 1.24_dbl), par[0], 0_dbl)) * sin(x)}, + {1.23, 0.}, + kw::compact_mode = cm, + kw::opt_level = opt_level, + kw::pars = {1.}}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.select.var_par_num.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + + ta1 = taylor_adaptive{ + {prime(x) = v, prime(v) = -2. * sin(x)}, {1.23, 0.}, kw::compact_mode = cm, kw::opt_level = opt_level}; + ta2 = taylor_adaptive{{prime(x) = v, prime(v) = -(1. + select(par[0], par[0], 0_dbl)) * sin(x)}, + {1.23, 0.}, + kw::compact_mode = cm, + kw::opt_level = opt_level, + kw::pars = {1.}}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.select.par_par_num.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + } + } +} + +TEST_CASE("taylor_adaptive_batch") +{ + auto [x, v] = make_vars("x", "v"); + + for (auto opt_level : {0u, 3u}) { + for (auto cm : {false, true}) { + auto ta1 = taylor_adaptive_batch{{prime(x) = v, prime(v) = -sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + + auto ta2 = taylor_adaptive_batch{ + {prime(x) = v, prime(v) = -(1. + select(gt(x, 1.24_dbl), 1. 
- par[0], 0_dbl)) * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.select.var_var_num.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + REQUIRE(ta1.get_state()[2] == approximately(ta2.get_state()[2])); + REQUIRE(ta1.get_state()[3] == approximately(ta2.get_state()[3])); + + ta1 = taylor_adaptive_batch{{prime(x) = v, prime(v) = -2. * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + ta2 = taylor_adaptive_batch{ + {prime(x) = v, prime(v) = -(1. + select(lt(x, 1.24_dbl), par[0], 0_dbl)) * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level, + kw::pars = {1., 1.}}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.select.var_par_num.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + REQUIRE(ta1.get_state()[2] == approximately(ta2.get_state()[2])); + REQUIRE(ta1.get_state()[3] == approximately(ta2.get_state()[3])); + + ta1 = taylor_adaptive_batch{{prime(x) = v, prime(v) = -2. * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level}; + ta2 = taylor_adaptive_batch{{prime(x) = v, prime(v) = -(1. 
+ select(par[0], par[0], 0_dbl)) * sin(x)}, + {1.23, 1.22, 0., 0.}, + 2u, + kw::compact_mode = cm, + kw::opt_level = opt_level, + kw::pars = {1., 1.}}; + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(ta2.get_llvm_state().get_ir(), "heyoka.taylor_c_diff.select.par_par_num.")); + } + + ta1.propagate_until(5.); + ta2.propagate_until(5.); + + REQUIRE(ta1.get_state()[0] == approximately(ta2.get_state()[0])); + REQUIRE(ta1.get_state()[1] == approximately(ta2.get_state()[1])); + REQUIRE(ta1.get_state()[2] == approximately(ta2.get_state()[2])); + REQUIRE(ta1.get_state()[3] == approximately(ta2.get_state()[3])); + } + } +} From 97190027d8a6b8582659a3744a11514c7c09b330 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Sat, 22 Jun 2024 16:15:13 +0200 Subject: [PATCH 25/26] Add numerical overloads for select() and the relational ops. --- include/heyoka/math/relational.hpp | 47 ++++++++++++++++- include/heyoka/math/select.hpp | 41 +++++++++++++- src/math/relational.cpp | 85 +++++++++++++++++++++++++++++- src/math/select.cpp | 61 ++++++++++++++++++++- test/rel.cpp | 16 ++++++ test/select.cpp | 17 ++++++ 6 files changed, 263 insertions(+), 4 deletions(-) diff --git a/include/heyoka/math/relational.hpp b/include/heyoka/math/relational.hpp index eafb628bd..088d1388d 100644 --- a/include/heyoka/math/relational.hpp +++ b/include/heyoka/math/relational.hpp @@ -9,11 +9,24 @@ #ifndef HEYOKA_MATH_RELATIONAL_HPP #define HEYOKA_MATH_RELATIONAL_HPP +#include + #include #include #include -#include +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + #include #include #include @@ -71,6 +84,38 @@ HEYOKA_DLL_PUBLIC expression gt(expression, expression); HEYOKA_DLL_PUBLIC expression lte(expression, expression); HEYOKA_DLL_PUBLIC expression gte(expression, expression); +#define HEYOKA_DECLARE_REL_OVERLOADS(type) \ + HEYOKA_DLL_PUBLIC expression eq(expression, type); \ + HEYOKA_DLL_PUBLIC expression eq(type, expression); \ 
+ HEYOKA_DLL_PUBLIC expression neq(expression, type); \ + HEYOKA_DLL_PUBLIC expression neq(type, expression); \ + HEYOKA_DLL_PUBLIC expression lt(expression, type); \ + HEYOKA_DLL_PUBLIC expression lt(type, expression); \ + HEYOKA_DLL_PUBLIC expression gt(expression, type); \ + HEYOKA_DLL_PUBLIC expression gt(type, expression); \ + HEYOKA_DLL_PUBLIC expression lte(expression, type); \ + HEYOKA_DLL_PUBLIC expression lte(type, expression); \ + HEYOKA_DLL_PUBLIC expression gte(expression, type); \ + HEYOKA_DLL_PUBLIC expression gte(type, expression); + +HEYOKA_DECLARE_REL_OVERLOADS(float); +HEYOKA_DECLARE_REL_OVERLOADS(double); +HEYOKA_DECLARE_REL_OVERLOADS(long double); + +#if defined(HEYOKA_HAVE_REAL128) + +HEYOKA_DECLARE_REL_OVERLOADS(mppp::real128); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +HEYOKA_DECLARE_REL_OVERLOADS(mppp::real); + +#endif + +#undef HEYOKA_DECLARE_REL_OVERLOADS + HEYOKA_END_NAMESPACE HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::rel_impl) diff --git a/include/heyoka/math/select.hpp b/include/heyoka/math/select.hpp index 7f6b10eea..079378865 100644 --- a/include/heyoka/math/select.hpp +++ b/include/heyoka/math/select.hpp @@ -9,10 +9,23 @@ #ifndef HEYOKA_MATH_SELECT_HPP #define HEYOKA_MATH_SELECT_HPP +#include + #include #include -#include +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + #include #include #include @@ -56,6 +69,32 @@ class HEYOKA_DLL_PUBLIC select_impl : public func_base HEYOKA_DLL_PUBLIC expression select(expression, expression, expression); +#define HEYOKA_DECLARE_SELECT_OVERLOADS(type) \ + HEYOKA_DLL_PUBLIC expression select(expression, type, type); \ + HEYOKA_DLL_PUBLIC expression select(type, expression, type); \ + HEYOKA_DLL_PUBLIC expression select(type, type, expression); \ + HEYOKA_DLL_PUBLIC expression select(expression, expression, type); \ + HEYOKA_DLL_PUBLIC expression select(expression, type, expression); \ + HEYOKA_DLL_PUBLIC expression select(type, 
expression, expression) + +HEYOKA_DECLARE_SELECT_OVERLOADS(float); +HEYOKA_DECLARE_SELECT_OVERLOADS(double); +HEYOKA_DECLARE_SELECT_OVERLOADS(long double); + +#if defined(HEYOKA_HAVE_REAL128) + +HEYOKA_DECLARE_SELECT_OVERLOADS(mppp::real128); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +HEYOKA_DECLARE_SELECT_OVERLOADS(mppp::real); + +#endif + +#undef HEYOKA_DECLARE_SELECT_OVERLOADS + HEYOKA_END_NAMESPACE HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::select_impl) diff --git a/src/math/relational.cpp b/src/math/relational.cpp index a1c2486e9..ed50d5bf8 100644 --- a/src/math/relational.cpp +++ b/src/math/relational.cpp @@ -6,6 +6,8 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include + #include #include #include @@ -26,7 +28,18 @@ #include #include -#include +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + #include #include #include @@ -355,6 +368,76 @@ HEYOKA_MATH_REL_IMPL(gte) #undef HEYOKA_MATH_REL_IMPL +// NOTE: this macro was copy-pasted from the kepE() overloads, hence +// the weird (but inconsequential) naming of the function arguments. 
+#define HEYOKA_DEFINE_REL_OVERLOADS(type) \ + expression eq(expression e, type M) \ + { \ + return eq(std::move(e), expression{std::move(M)}); \ + } \ + expression eq(type e, expression M) \ + { \ + return eq(expression{std::move(e)}, std::move(M)); \ + } \ + expression neq(expression e, type M) \ + { \ + return neq(std::move(e), expression{std::move(M)}); \ + } \ + expression neq(type e, expression M) \ + { \ + return neq(expression{std::move(e)}, std::move(M)); \ + } \ + expression lt(expression e, type M) \ + { \ + return lt(std::move(e), expression{std::move(M)}); \ + } \ + expression lt(type e, expression M) \ + { \ + return lt(expression{std::move(e)}, std::move(M)); \ + } \ + expression gt(expression e, type M) \ + { \ + return gt(std::move(e), expression{std::move(M)}); \ + } \ + expression gt(type e, expression M) \ + { \ + return gt(expression{std::move(e)}, std::move(M)); \ + } \ + expression lte(expression e, type M) \ + { \ + return lte(std::move(e), expression{std::move(M)}); \ + } \ + expression lte(type e, expression M) \ + { \ + return lte(expression{std::move(e)}, std::move(M)); \ + } \ + expression gte(expression e, type M) \ + { \ + return gte(std::move(e), expression{std::move(M)}); \ + } \ + expression gte(type e, expression M) \ + { \ + return gte(expression{std::move(e)}, std::move(M)); \ + } + +HEYOKA_DEFINE_REL_OVERLOADS(float) +HEYOKA_DEFINE_REL_OVERLOADS(double) +HEYOKA_DEFINE_REL_OVERLOADS(long double) + +#if defined(HEYOKA_HAVE_REAL128) + +HEYOKA_DEFINE_REL_OVERLOADS(mppp::real128); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +HEYOKA_DEFINE_REL_OVERLOADS(mppp::real); + +#endif + +#undef HEYOKA_DEFINE_REL_OVERLOADS + HEYOKA_END_NAMESPACE // NOLINTNEXTLINE(cert-err58-cpp) diff --git a/src/math/select.cpp b/src/math/select.cpp index 338a2d580..10c487a01 100644 --- a/src/math/select.cpp +++ b/src/math/select.cpp @@ -6,6 +6,8 @@ // Public License v. 2.0. 
If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include + #include #include #include @@ -23,7 +25,18 @@ #include #include -#include +#if defined(HEYOKA_HAVE_REAL128) + +#include + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +#include + +#endif + #include #include #include @@ -268,6 +281,52 @@ expression select(expression cond, expression t, expression f) return expression{func{detail::select_impl{std::move(cond), std::move(t), std::move(f)}}}; } +// NOTE: this macro was copy-pasted from the kepDE() overloads, hence +// the weird (but inconsequential) naming of the function arguments. +#define HEYOKA_DEFINE_SELECT_OVERLOADS(type) \ + expression select(expression s0, type c0, type DM) \ + { \ + return select(std::move(s0), expression{std::move(c0)}, expression{std::move(DM)}); \ + } \ + expression select(type s0, expression c0, type DM) \ + { \ + return select(expression{std::move(s0)}, std::move(c0), expression{std::move(DM)}); \ + } \ + expression select(type s0, type c0, expression DM) \ + { \ + return select(expression{std::move(s0)}, expression{std::move(c0)}, std::move(DM)); \ + } \ + expression select(expression s0, expression c0, type DM) \ + { \ + return select(std::move(s0), std::move(c0), expression{std::move(DM)}); \ + } \ + expression select(expression s0, type c0, expression DM) \ + { \ + return select(std::move(s0), expression{std::move(c0)}, std::move(DM)); \ + } \ + expression select(type s0, expression c0, expression DM) \ + { \ + return select(expression{std::move(s0)}, std::move(c0), std::move(DM)); \ + } + +HEYOKA_DEFINE_SELECT_OVERLOADS(float) +HEYOKA_DEFINE_SELECT_OVERLOADS(double) +HEYOKA_DEFINE_SELECT_OVERLOADS(long double) + +#if defined(HEYOKA_HAVE_REAL128) + +HEYOKA_DEFINE_SELECT_OVERLOADS(mppp::real128); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + +HEYOKA_DEFINE_SELECT_OVERLOADS(mppp::real); + +#endif + +#undef HEYOKA_DEFINE_SELECT_OVERLOADS + HEYOKA_END_NAMESPACE // 
NOLINTNEXTLINE(cert-err58-cpp) diff --git a/test/rel.cpp b/test/rel.cpp index b6b43ab17..3b9123113 100644 --- a/test/rel.cpp +++ b/test/rel.cpp @@ -80,6 +80,22 @@ TEST_CASE("basic") REQUIRE(eq(x, y) != neq(x, y)); REQUIRE(lte(x, y) != gte(x, y)); REQUIRE(lte(x, y) == lte(x, y)); + + // Test a couple of numerical overloads too. + REQUIRE(eq(x, 1.) == eq(x, 1_dbl)); + REQUIRE(lte(1.l, par[0]) == lte(1_ldbl, par[0])); + +#if defined(HEYOKA_HAVE_REAL128) + + REQUIRE(lte(mppp::real128{"1.1"}, par[0]) == lte(1.1_f128, par[0])); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + + REQUIRE(lte(mppp::real{"1.1", 14}, par[0]) == lte(expression{mppp::real{"1.1", 14}}, par[0])); + +#endif } TEST_CASE("stream") diff --git a/test/select.cpp b/test/select.cpp index 9390bd6b7..df2afa337 100644 --- a/test/select.cpp +++ b/test/select.cpp @@ -76,6 +76,23 @@ TEST_CASE("basic") auto [x, y, z] = make_vars("x", "y", "z"); REQUIRE(expression{func{detail::select_impl{}}} == expression{func{detail::select_impl{0_dbl, 0_dbl, 0_dbl}}}); + + // A couple of tests for the numeric overloads. + REQUIRE(select(1., x, 2.) == select(1_dbl, x, 2_dbl)); + REQUIRE(select(x, par[0], 2.f) == select(x, par[0], 2_flt)); + +#if defined(HEYOKA_HAVE_REAL128) + + REQUIRE(select(mppp::real128{"3.1"}, x, mppp::real128{"2.1"}) == select(3.1_f128, x, 2.1_f128)); + +#endif + +#if defined(HEYOKA_HAVE_REAL) + + REQUIRE(select(mppp::real{"3.1", 14}, x, mppp::real{"2.1", 14}) + == select(expression{mppp::real{"3.1", 14}}, x, expression{mppp::real{"2.1", 14}})); + +#endif } TEST_CASE("stream") From 766b5850714421f49f156a7dc65d712b36dc93d9 Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Sat, 22 Jun 2024 16:18:53 +0200 Subject: [PATCH 26/26] Update changelog. 
--- doc/changelog.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index 6ccebf907..ffca39d5a 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,17 @@ Changelog ========= +5.1.0 (unreleased) +------------------ + +New +~~~ + +- Add the ``select()`` primitive to the expression system + (`#432 <https://github.com/bluescarni/heyoka/pull/432>`__). +- Add relational and logical operators to the expression system + (`#432 <https://github.com/bluescarni/heyoka/pull/432>`__). + 5.0.0 (2024-06-13) ------------------