Skip to content

Commit

Permalink
Add numerical overloads for select() and the relational ops.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Jun 22, 2024
1 parent 17495cd commit 9719002
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 4 deletions.
47 changes: 46 additions & 1 deletion include/heyoka/math/relational.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,24 @@
#ifndef HEYOKA_MATH_RELATIONAL_HPP
#define HEYOKA_MATH_RELATIONAL_HPP

#include <heyoka/config.hpp>

#include <cstdint>
#include <sstream>
#include <vector>

#include <heyoka/config.hpp>
#if defined(HEYOKA_HAVE_REAL128)

#include <mp++/real128.hpp>

#endif

#if defined(HEYOKA_HAVE_REAL)

#include <mp++/real.hpp>

#endif

#include <heyoka/detail/fwd_decl.hpp>
#include <heyoka/detail/llvm_fwd.hpp>
#include <heyoka/detail/visibility.hpp>
Expand Down Expand Up @@ -71,6 +84,38 @@ HEYOKA_DLL_PUBLIC expression gt(expression, expression);
HEYOKA_DLL_PUBLIC expression lte(expression, expression);
HEYOKA_DLL_PUBLIC expression gte(expression, expression);

#define HEYOKA_DECLARE_REL_OVERLOADS(type) \
HEYOKA_DLL_PUBLIC expression eq(expression, type); \
HEYOKA_DLL_PUBLIC expression eq(type, expression); \
HEYOKA_DLL_PUBLIC expression neq(expression, type); \
HEYOKA_DLL_PUBLIC expression neq(type, expression); \
HEYOKA_DLL_PUBLIC expression lt(expression, type); \
HEYOKA_DLL_PUBLIC expression lt(type, expression); \
HEYOKA_DLL_PUBLIC expression gt(expression, type); \
HEYOKA_DLL_PUBLIC expression gt(type, expression); \
HEYOKA_DLL_PUBLIC expression lte(expression, type); \
HEYOKA_DLL_PUBLIC expression lte(type, expression); \
HEYOKA_DLL_PUBLIC expression gte(expression, type); \
HEYOKA_DLL_PUBLIC expression gte(type, expression);

HEYOKA_DECLARE_REL_OVERLOADS(float);
HEYOKA_DECLARE_REL_OVERLOADS(double);
HEYOKA_DECLARE_REL_OVERLOADS(long double);

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_DECLARE_REL_OVERLOADS(mppp::real128);

#endif

#if defined(HEYOKA_HAVE_REAL)

HEYOKA_DECLARE_REL_OVERLOADS(mppp::real);

#endif

#undef HEYOKA_DECLARE_REL_OVERLOADS

HEYOKA_END_NAMESPACE

HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::rel_impl)
Expand Down
41 changes: 40 additions & 1 deletion include/heyoka/math/select.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,23 @@
#ifndef HEYOKA_MATH_SELECT_HPP
#define HEYOKA_MATH_SELECT_HPP

#include <heyoka/config.hpp>

#include <cstdint>
#include <vector>

#include <heyoka/config.hpp>
#if defined(HEYOKA_HAVE_REAL128)

#include <mp++/real128.hpp>

#endif

#if defined(HEYOKA_HAVE_REAL)

#include <mp++/real.hpp>

#endif

#include <heyoka/detail/fwd_decl.hpp>
#include <heyoka/detail/llvm_fwd.hpp>
#include <heyoka/detail/visibility.hpp>
Expand Down Expand Up @@ -56,6 +69,32 @@ class HEYOKA_DLL_PUBLIC select_impl : public func_base

HEYOKA_DLL_PUBLIC expression select(expression, expression, expression);

#define HEYOKA_DECLARE_SELECT_OVERLOADS(type) \
HEYOKA_DLL_PUBLIC expression select(expression, type, type); \
HEYOKA_DLL_PUBLIC expression select(type, expression, type); \
HEYOKA_DLL_PUBLIC expression select(type, type, expression); \
HEYOKA_DLL_PUBLIC expression select(expression, expression, type); \
HEYOKA_DLL_PUBLIC expression select(expression, type, expression); \
HEYOKA_DLL_PUBLIC expression select(type, expression, expression)

HEYOKA_DECLARE_SELECT_OVERLOADS(float);
HEYOKA_DECLARE_SELECT_OVERLOADS(double);
HEYOKA_DECLARE_SELECT_OVERLOADS(long double);

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_DECLARE_SELECT_OVERLOADS(mppp::real128);

#endif

#if defined(HEYOKA_HAVE_REAL)

HEYOKA_DECLARE_SELECT_OVERLOADS(mppp::real);

#endif

#undef HEYOKA_DECLARE_SELECT_OVERLOADS

HEYOKA_END_NAMESPACE

HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::select_impl)
Expand Down
85 changes: 84 additions & 1 deletion src/math/relational.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include <heyoka/config.hpp>

#include <cassert>
#include <concepts>
#include <cstdint>
Expand All @@ -26,7 +28,18 @@
#include <llvm/IR/Type.h>
#include <llvm/IR/Value.h>

#include <heyoka/config.hpp>
#if defined(HEYOKA_HAVE_REAL128)

#include <mp++/real128.hpp>

#endif

#if defined(HEYOKA_HAVE_REAL)

#include <mp++/real.hpp>

#endif

#include <heyoka/detail/llvm_helpers.hpp>
#include <heyoka/detail/string_conv.hpp>
#include <heyoka/expression.hpp>
Expand Down Expand Up @@ -355,6 +368,76 @@ HEYOKA_MATH_REL_IMPL(gte)

#undef HEYOKA_MATH_REL_IMPL

// NOTE: this macro was copy-pasted from the kepE() overloads, hence
// the weird (but inconsequential) naming of the function arguments.
#define HEYOKA_DEFINE_REL_OVERLOADS(type) \
expression eq(expression e, type M) \
{ \
return eq(std::move(e), expression{std::move(M)}); \
} \
expression eq(type e, expression M) \
{ \
return eq(expression{std::move(e)}, std::move(M)); \
} \
expression neq(expression e, type M) \
{ \
return neq(std::move(e), expression{std::move(M)}); \
} \
expression neq(type e, expression M) \
{ \
return neq(expression{std::move(e)}, std::move(M)); \
} \
expression lt(expression e, type M) \
{ \
return lt(std::move(e), expression{std::move(M)}); \
} \
expression lt(type e, expression M) \
{ \
return lt(expression{std::move(e)}, std::move(M)); \
} \
expression gt(expression e, type M) \
{ \
return gt(std::move(e), expression{std::move(M)}); \
} \
expression gt(type e, expression M) \
{ \
return gt(expression{std::move(e)}, std::move(M)); \
} \
expression lte(expression e, type M) \
{ \
return lte(std::move(e), expression{std::move(M)}); \
} \
expression lte(type e, expression M) \
{ \
return lte(expression{std::move(e)}, std::move(M)); \
} \
expression gte(expression e, type M) \
{ \
return gte(std::move(e), expression{std::move(M)}); \
} \
expression gte(type e, expression M) \
{ \
return gte(expression{std::move(e)}, std::move(M)); \
}

HEYOKA_DEFINE_REL_OVERLOADS(float)
HEYOKA_DEFINE_REL_OVERLOADS(double)
HEYOKA_DEFINE_REL_OVERLOADS(long double)

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_DEFINE_REL_OVERLOADS(mppp::real128);

#endif

#if defined(HEYOKA_HAVE_REAL)

HEYOKA_DEFINE_REL_OVERLOADS(mppp::real);

#endif

#undef HEYOKA_DEFINE_REL_OVERLOADS

HEYOKA_END_NAMESPACE

// NOLINTNEXTLINE(cert-err58-cpp)
Expand Down
61 changes: 60 additions & 1 deletion src/math/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include <heyoka/config.hpp>

#include <cassert>
#include <concepts>
#include <cstdint>
Expand All @@ -23,7 +25,18 @@
#include <llvm/IR/Type.h>
#include <llvm/IR/Value.h>

#include <heyoka/config.hpp>
#if defined(HEYOKA_HAVE_REAL128)

#include <mp++/real128.hpp>

#endif

#if defined(HEYOKA_HAVE_REAL)

#include <mp++/real.hpp>

#endif

#include <heyoka/detail/llvm_helpers.hpp>
#include <heyoka/detail/string_conv.hpp>
#include <heyoka/expression.hpp>
Expand Down Expand Up @@ -268,6 +281,52 @@ expression select(expression cond, expression t, expression f)
return expression{func{detail::select_impl{std::move(cond), std::move(t), std::move(f)}}};
}

// NOTE: this macro was copy-pasted from the kepDE() overloads, hence
// the weird (but inconsequential) naming of the function arguments.
#define HEYOKA_DEFINE_SELECT_OVERLOADS(type) \
expression select(expression s0, type c0, type DM) \
{ \
return select(std::move(s0), expression{std::move(c0)}, expression{std::move(DM)}); \
} \
expression select(type s0, expression c0, type DM) \
{ \
return select(expression{std::move(s0)}, std::move(c0), expression{std::move(DM)}); \
} \
expression select(type s0, type c0, expression DM) \
{ \
return select(expression{std::move(s0)}, expression{std::move(c0)}, std::move(DM)); \
} \
expression select(expression s0, expression c0, type DM) \
{ \
return select(std::move(s0), std::move(c0), expression{std::move(DM)}); \
} \
expression select(expression s0, type c0, expression DM) \
{ \
return select(std::move(s0), expression{std::move(c0)}, std::move(DM)); \
} \
expression select(type s0, expression c0, expression DM) \
{ \
return select(expression{std::move(s0)}, std::move(c0), std::move(DM)); \
}

HEYOKA_DEFINE_SELECT_OVERLOADS(float)
HEYOKA_DEFINE_SELECT_OVERLOADS(double)
HEYOKA_DEFINE_SELECT_OVERLOADS(long double)

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_DEFINE_SELECT_OVERLOADS(mppp::real128);

#endif

#if defined(HEYOKA_HAVE_REAL)

HEYOKA_DEFINE_SELECT_OVERLOADS(mppp::real);

#endif

#undef HEYOKA_DEFINE_SELECT_OVERLOADS

HEYOKA_END_NAMESPACE

// NOLINTNEXTLINE(cert-err58-cpp)
Expand Down
16 changes: 16 additions & 0 deletions test/rel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,22 @@ TEST_CASE("basic")
REQUIRE(eq(x, y) != neq(x, y));
REQUIRE(lte(x, y) != gte(x, y));
REQUIRE(lte(x, y) == lte(x, y));

// Test a couple of numerical overloads too.
REQUIRE(eq(x, 1.) == eq(x, 1_dbl));
REQUIRE(lte(1.l, par[0]) == lte(1_ldbl, par[0]));

#if defined(HEYOKA_HAVE_REAL128)

REQUIRE(lte(mppp::real128{"1.1"}, par[0]) == lte(1.1_f128, par[0]));

#endif

#if defined(HEYOKA_HAVE_REAL)

REQUIRE(lte(mppp::real{"1.1", 14}, par[0]) == lte(expression{mppp::real{"1.1", 14}}, par[0]));

#endif
}

TEST_CASE("stream")
Expand Down
17 changes: 17 additions & 0 deletions test/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ TEST_CASE("basic")
auto [x, y, z] = make_vars("x", "y", "z");

REQUIRE(expression{func{detail::select_impl{}}} == expression{func{detail::select_impl{0_dbl, 0_dbl, 0_dbl}}});

// A couple of tests for the numeric overloads.
REQUIRE(select(1., x, 2.) == select(1_dbl, x, 2_dbl));
REQUIRE(select(x, par[0], 2.f) == select(x, par[0], 2_flt));

#if defined(HEYOKA_HAVE_REAL128)

REQUIRE(select(mppp::real128{"3.1"}, x, mppp::real128{"2.1"}) == select(3.1_f128, x, 2.1_f128));

#endif

#if defined(HEYOKA_HAVE_REAL)

REQUIRE(select(mppp::real{"3.1", 14}, x, mppp::real{"2.1", 14})
== select(expression{mppp::real{"3.1", 14}}, x, expression{mppp::real{"2.1", 14}}));

#endif
}

TEST_CASE("stream")
Expand Down

0 comments on commit 9719002

Please sign in to comment.