diff --git a/doc/changelog.rst b/doc/changelog.rst index 8a7bb1e96..695db4cb3 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -7,8 +7,11 @@ Changelog New ~~~ -- Implement ``ReLU`` and its derivative in the expression - system (`#356 `__). +- Implement (leaky) ``ReLU`` and its derivative in the expression + system (`#357 `__, + `#356 `__). +- Add feed-forward neural network model + (`#355 `__). - Implement the eccentric longitude :math:`F` in the expression system (`#352 `__). - Implement the delta eccentric anomaly :math:`\Delta E` in the expression diff --git a/include/heyoka/math/relu.hpp b/include/heyoka/math/relu.hpp index 63482dcfe..9512b46e8 100644 --- a/include/heyoka/math/relu.hpp +++ b/include/heyoka/math/relu.hpp @@ -10,6 +10,7 @@ #define HEYOKA_MATH_RELU_HPP #include +#include #include #include @@ -26,16 +27,23 @@ namespace detail class HEYOKA_DLL_PUBLIC relu_impl : public func_base { + double m_slope; + friend class boost::serialization::access; template void serialize(Archive &ar, unsigned) { ar &boost::serialization::base_object(*this); + ar & m_slope; } public: relu_impl(); - explicit relu_impl(expression); + explicit relu_impl(expression, double); + + [[nodiscard]] double get_slope() const noexcept; + + void to_stream(std::ostringstream &) const; [[nodiscard]] expression normalise() const; @@ -55,16 +63,23 @@ class HEYOKA_DLL_PUBLIC relu_impl : public func_base class HEYOKA_DLL_PUBLIC relup_impl : public func_base { + double m_slope; + friend class boost::serialization::access; template void serialize(Archive &ar, unsigned) { ar &boost::serialization::base_object(*this); + ar & m_slope; } public: relup_impl(); - explicit relup_impl(expression); + explicit relup_impl(expression, double); + + [[nodiscard]] double get_slope() const noexcept; + + void to_stream(std::ostringstream &) const; [[nodiscard]] expression normalise() const; @@ -84,9 +99,27 @@ class HEYOKA_DLL_PUBLIC relup_impl : public func_base } // namespace detail -HEYOKA_DLL_PUBLIC expression relu(expression); +HEYOKA_DLL_PUBLIC expression relu(expression, double = 0); + +HEYOKA_DLL_PUBLIC expression relup(expression, double = 0); + +class HEYOKA_DLL_PUBLIC leaky_relu +{ + double m_slope; + +public: + explicit leaky_relu(double); + expression operator()(expression) const; +}; -HEYOKA_DLL_PUBLIC expression relup(expression); +class HEYOKA_DLL_PUBLIC leaky_relup +{ + double m_slope; + +public: + explicit leaky_relup(double); + expression operator()(expression) const; +}; HEYOKA_END_NAMESPACE diff --git a/src/func.cpp b/src/func.cpp index 3c8304af4..4b13924e5 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -833,6 +833,8 @@ std::pair> llvm_c_eval_func_name_args(llv std::uint32_t batch_size, const std::vector &args) { + assert(std::find(name.begin(), name.end(), '.') == name.end()); + // Fetch the vector floating-point type. auto *val_t = make_vector_type(fp_t, batch_size); diff --git a/src/llvm_state.cpp b/src/llvm_state.cpp index d84652112..e48739be0 100644 --- a/src/llvm_state.cpp +++ b/src/llvm_state.cpp @@ -29,7 +29,6 @@ #include #include #include -#include #include #include diff --git a/src/math/relu.cpp b/src/math/relu.cpp index 15889729d..9fcba480b 100644 --- a/src/math/relu.cpp +++ b/src/math/relu.cpp @@ -6,8 +6,11 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include #include +#include #include +#include #include #include #include @@ -42,30 +45,92 @@ HEYOKA_BEGIN_NAMESPACE namespace detail { -relu_impl::relu_impl() : relu_impl(0_dbl) {} +namespace +{ + +// Checker for the slope parameter of the leaky ReLU. +void relu_slope_check(double slope) +{ + if (!std::isfinite(slope) || slope < 0) { + throw std::invalid_argument(fmt::format("The slope parameter for a leaky ReLU must be finite and non-negative, " + "but the value {} was provided instead", + slope)); + } +} + +// Helper to build a unique name for a relu/relup function, depending +// on the slope value. +std::string relu_name(const char *base, double slope) +{ + if (slope == 0) { + return base; + } else { + // NOTE: we print the slope value in hex format, then we replace + // the decimal point '.' with an underscore '_' (as the '.' is used + // as a separator in the name mangling scheme for compact mode functions). + auto ret = fmt::format("{}_{:a}", base, slope); + std::replace(ret.begin(), ret.end(), '.', '_'); + + return ret; + } +} + +} // namespace -relu_impl::relu_impl(expression ex) : func_base("relu", std::vector{std::move(ex)}) {} +relu_impl::relu_impl() : relu_impl(0_dbl, 0.) {} + +relu_impl::relu_impl(expression ex, double slope) + : func_base(relu_name("relu", slope), std::vector{std::move(ex)}), m_slope(slope) +{ + relu_slope_check(slope); +} + +double relu_impl::get_slope() const noexcept +{ + return m_slope; +} + +void relu_impl::to_stream(std::ostringstream &oss) const +{ + assert(args().size() == 1u); + + if (m_slope == 0) { + oss << "relu("; + stream_expression(oss, args()[0]); + oss << ')'; + } else { + oss << "leaky_relu("; + stream_expression(oss, args()[0]); + oss << fmt::format(", {})", m_slope); + } +} [[nodiscard]] expression relu_impl::normalise() const { assert(args().size() == 1u); - return relu(args()[0]); + return relu(args()[0], m_slope); } [[nodiscard]] std::vector relu_impl::gradient() const { assert(args().size() == 1u); - return {relup(args()[0])}; + return {relup(args()[0], m_slope)}; } namespace { // LLVM implementation of relu. -llvm::Value *llvm_relu(llvm_state &s, llvm::Value *x) +llvm::Value *llvm_relu(llvm_state &s, llvm::Value *x, double slope) { auto *zero_c = llvm_constantfp(s, x->getType(), 0.); - return s.builder().CreateSelect(llvm_fcmp_ogt(s, x, zero_c), x, zero_c); + + if (slope == 0) { + return s.builder().CreateSelect(llvm_fcmp_ogt(s, x, zero_c), x, zero_c); + } else { + auto *slope_c = llvm_constantfp(s, x->getType(), slope); + return s.builder().CreateSelect(llvm_fcmp_ogt(s, x, zero_c), x, llvm_fmul(s, slope_c, x)); + } } } // namespace @@ -76,9 +141,9 @@ llvm::Value *llvm_relu(llvm_state &s, llvm::Value *x) bool high_accuracy) const { return llvm_eval_helper( - [&s](const std::vector &args, bool) { + [&](const std::vector &args, bool) { assert(args.size() == 1u); - return llvm_relu(s, args[0]); + return llvm_relu(s, args[0], m_slope); }, *this, s, fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy); } @@ -87,10 +152,10 @@ llvm::Function *relu_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std bool high_accuracy) const { return llvm_c_eval_func_helper( - "relu", + get_name(), [&](const std::vector &args, bool) { assert(args.size() == 1u); - return llvm_relu(s, args[0]); + return llvm_relu(s, args[0], m_slope); }, *this, s, fp_t, batch_size, high_accuracy); } @@ -103,10 +168,10 @@ template , int> = 0> llvm::Value *taylor_diff_relu_impl(llvm_state &s, llvm::Type *fp_t, const relu_impl &, const std::vector &, const U &num, const std::vector &, llvm::Value *par_ptr, std::uint32_t, std::uint32_t order, std::uint32_t, - std::uint32_t batch_size) + std::uint32_t batch_size, double slope) { if (order == 0u) { - return llvm_relu(s, taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size)); + return llvm_relu(s, taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size), slope); } else { return vector_splat(s.builder(), llvm_constantfp(s, fp_t, 0.), batch_size); } @@ -115,7 +180,8 @@ llvm::Value *taylor_diff_relu_impl(llvm_state &s, llvm::Type *fp_t, const relu_i // Derivative of relu(variable). llvm::Value *taylor_diff_relu_impl(llvm_state &s, llvm::Type *, const relu_impl &, const std::vector &, const variable &var, const std::vector &arr, llvm::Value *, - std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, std::uint32_t) + std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, std::uint32_t, + double slope) { const auto u_idx = uname_to_index(var.name()); @@ -123,7 +189,13 @@ llvm::Value *taylor_diff_relu_impl(llvm_state &s, llvm::Type *, const relu_impl auto *u_order = taylor_fetch_diff(arr, u_idx, order, n_uvars); auto *zero_c = llvm_constantfp(s, u_zero->getType(), 0.); - return s.builder().CreateSelect(llvm_fcmp_ogt(s, u_zero, zero_c), u_order, zero_c); + + if (slope == 0) { + return s.builder().CreateSelect(llvm_fcmp_ogt(s, u_zero, zero_c), u_order, zero_c); + } else { + auto *slope_c = llvm_constantfp(s, u_zero->getType(), slope); + return s.builder().CreateSelect(llvm_fcmp_ogt(s, u_zero, zero_c), u_order, llvm_fmul(s, slope_c, u_order)); + } } // LCOV_EXCL_START @@ -132,7 +204,7 @@ llvm::Value *taylor_diff_relu_impl(llvm_state &s, llvm::Type *, const relu_impl template , int> = 0> llvm::Value *taylor_diff_relu_impl(llvm_state &, llvm::Type *, const relu_impl &, const std::vector &, const U &, const std::vector &, llvm::Value *, std::uint32_t, - std::uint32_t, std::uint32_t, std::uint32_t) + std::uint32_t, std::uint32_t, std::uint32_t, double) { throw std::invalid_argument( "An invalid argument type was encountered while trying to build the Taylor derivative of a relu"); @@ -140,12 +212,14 @@ llvm::Value *taylor_diff_relu_impl(llvm_state &, llvm::Type *, const relu_impl & // LCOV_EXCL_STOP -llvm::Value *taylor_diff_relu(llvm_state &s, llvm::Type *fp_t, const relu_impl &f, - const std::vector &deps, const std::vector &arr, - llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, - std::uint32_t batch_size) +} // namespace + +llvm::Value *relu_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector &deps, + const std::vector &arr, llvm::Value *par_ptr, llvm::Value *, + std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, + std::uint32_t batch_size, bool) const { - assert(f.args().size() == 1u); + assert(args().size() == 1u); // LCOV_EXCL_START if (!deps.empty()) { @@ -158,19 +232,10 @@ llvm::Value *taylor_diff_relu(llvm_state &s, llvm::Type *fp_t, const relu_impl & return std::visit( [&](const auto &v) { - return taylor_diff_relu_impl(s, fp_t, f, deps, v, arr, par_ptr, n_uvars, order, idx, batch_size); + return taylor_diff_relu_impl(s, fp_t, *this, deps, v, arr, par_ptr, n_uvars, order, idx, batch_size, + m_slope); }, - f.args()[0].value()); -} - -} // namespace - -llvm::Value *relu_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector &deps, - const std::vector &arr, llvm::Value *par_ptr, llvm::Value *, - std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, - std::uint32_t batch_size, bool) const -{ - return taylor_diff_relu(s, fp_t, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size); + args()[0].value()); } namespace @@ -178,25 +243,27 @@ namespace // Derivative of relu(number). template , int> = 0> -llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &s, llvm::Type *fp_t, const relu_impl &, const U &num, - std::uint32_t n_uvars, std::uint32_t batch_size) +llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &s, llvm::Type *fp_t, const relu_impl &r, const U &num, + // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) + std::uint32_t n_uvars, std::uint32_t batch_size, double slope) { return taylor_c_diff_func_numpar( - s, fp_t, n_uvars, batch_size, "relu", 0, - [&s](const auto &args) { + s, fp_t, n_uvars, batch_size, r.get_name(), 0, + [&](const auto &args) { // LCOV_EXCL_START assert(args.size() == 1u); assert(args[0] != nullptr); // LCOV_EXCL_STOP - return llvm_relu(s, args[0]); + return llvm_relu(s, args[0], slope); }, num); } // Derivative of relu(variable). -llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &s, llvm::Type *fp_t, const relu_impl &, const variable &var, - std::uint32_t n_uvars, std::uint32_t batch_size) +llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &s, llvm::Type *fp_t, const relu_impl &r, const variable &var, + // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) + std::uint32_t n_uvars, std::uint32_t batch_size, double slope) { auto &module = s.module(); auto &builder = s.builder(); @@ -206,7 +273,7 @@ llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &s, llvm::Type *fp_t, co auto *val_t = make_vector_type(fp_t, batch_size); // Fetch the function name and arguments. - const auto na_pair = taylor_c_diff_func_name_args(context, fp_t, "relu", n_uvars, batch_size, {var}); + const auto na_pair = taylor_c_diff_func_name_args(context, fp_t, r.get_name(), n_uvars, batch_size, {var}); const auto &fname = na_pair.first; const auto &fargs = na_pair.second; @@ -243,7 +310,12 @@ llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &s, llvm::Type *fp_t, co auto *zero_c = llvm_constantfp(s, u_zero->getType(), 0.); - builder.CreateRet(builder.CreateSelect(llvm_fcmp_ogt(s, u_zero, zero_c), u_ord, zero_c)); + if (slope == 0) { + builder.CreateRet(builder.CreateSelect(llvm_fcmp_ogt(s, u_zero, zero_c), u_ord, zero_c)); + } else { + auto *slope_c = llvm_constantfp(s, u_zero->getType(), slope); + builder.CreateRet(builder.CreateSelect(llvm_fcmp_ogt(s, u_zero, zero_c), u_ord, llvm_fmul(s, slope_c, u_ord))); + } // Verify. s.verify_function(f); @@ -259,7 +331,7 @@ llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &s, llvm::Type *fp_t, co // All the other cases. template , int> = 0> llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &, llvm::Type *, const relu_impl &, const U &, std::uint32_t, - std::uint32_t) + std::uint32_t, double) { throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative " "of a relu in compact mode"); @@ -267,31 +339,50 @@ llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &, llvm::Type *, const r // LCOV_EXCL_STOP -llvm::Function *taylor_c_diff_func_relu(llvm_state &s, llvm::Type *fp_t, const relu_impl &fn, std::uint32_t n_uvars, - std::uint32_t batch_size) +} // namespace + +llvm::Function *relu_impl::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t n_uvars, + std::uint32_t batch_size, bool) const { - assert(fn.args().size() == 1u); + assert(args().size() == 1u); - return std::visit([&](const auto &v) { return taylor_c_diff_func_relu_impl(s, fp_t, fn, v, n_uvars, batch_size); }, - fn.args()[0].value()); + return std::visit( + [&](const auto &v) { return taylor_c_diff_func_relu_impl(s, fp_t, *this, v, n_uvars, batch_size, m_slope); }, + args()[0].value()); } -} // namespace +relup_impl::relup_impl() : relup_impl(0_dbl, 0.) {} -llvm::Function *relu_impl::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t n_uvars, - std::uint32_t batch_size, bool) const +relup_impl::relup_impl(expression ex, double slope) + : func_base(relu_name("relup", slope), std::vector{std::move(ex)}), m_slope(slope) +{ + relu_slope_check(slope); +} + +double relup_impl::get_slope() const noexcept { - return taylor_c_diff_func_relu(s, fp_t, *this, n_uvars, batch_size); + return m_slope; } -relup_impl::relup_impl() : relup_impl(0_dbl) {} +void relup_impl::to_stream(std::ostringstream &oss) const +{ + assert(args().size() == 1u); -relup_impl::relup_impl(expression ex) : func_base("relup", std::vector{std::move(ex)}) {} + if (m_slope == 0) { + oss << "relup("; + stream_expression(oss, args()[0]); + oss << ')'; + } else { + oss << "leaky_relup("; + stream_expression(oss, args()[0]); + oss << fmt::format(", {})", m_slope); + } +} [[nodiscard]] expression relup_impl::normalise() const { assert(args().size() == 1u); - return relup(args()[0]); + return relup(args()[0], m_slope); } [[nodiscard]] std::vector relup_impl::gradient() const @@ -304,11 +395,17 @@ namespace { // LLVM implementation of relup. -llvm::Value *llvm_relup(llvm_state &s, llvm::Value *x) +llvm::Value *llvm_relup(llvm_state &s, llvm::Value *x, double slope) { auto *zero_c = llvm_constantfp(s, x->getType(), 0.); auto *one_c = llvm_constantfp(s, x->getType(), 1.); - return s.builder().CreateSelect(llvm_fcmp_ogt(s, x, zero_c), one_c, zero_c); + + if (slope == 0) { + return s.builder().CreateSelect(llvm_fcmp_ogt(s, x, zero_c), one_c, zero_c); + } else { + auto *slope_c = llvm_constantfp(s, x->getType(), slope); + return s.builder().CreateSelect(llvm_fcmp_ogt(s, x, zero_c), one_c, slope_c); + } } } // namespace @@ -319,9 +416,9 @@ llvm::Value *llvm_relup(llvm_state &s, llvm::Value *x) bool high_accuracy) const { return llvm_eval_helper( - [&s](const std::vector &args, bool) { + [&](const std::vector &args, bool) { assert(args.size() == 1u); - return llvm_relup(s, args[0]); + return llvm_relup(s, args[0], m_slope); }, *this, s, fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy); } @@ -330,10 +427,10 @@ llvm::Function *relup_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, st bool high_accuracy) const { return llvm_c_eval_func_helper( - "relup", + get_name(), [&](const std::vector &args, bool) { assert(args.size() == 1u); - return llvm_relup(s, args[0]); + return llvm_relup(s, args[0], m_slope); }, *this, s, fp_t, batch_size, high_accuracy); } @@ -346,10 +443,10 @@ template , int> = 0> llvm::Value *taylor_diff_relup_impl(llvm_state &s, llvm::Type *fp_t, const relup_impl &, const std::vector &, const U &num, const std::vector &, llvm::Value *par_ptr, std::uint32_t, - std::uint32_t order, std::uint32_t, std::uint32_t batch_size) + std::uint32_t order, std::uint32_t, std::uint32_t batch_size, double slope) { if (order == 0u) { - return llvm_relup(s, taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size)); + return llvm_relup(s, taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size), slope); } else { return vector_splat(s.builder(), llvm_constantfp(s, fp_t, 0.), batch_size); } @@ -359,13 +456,14 @@ llvm::Value *taylor_diff_relup_impl(llvm_state &s, llvm::Type *fp_t, const relup llvm::Value *taylor_diff_relup_impl(llvm_state &s, llvm::Type *, const relup_impl &, const std::vector &, const variable &var, const std::vector &arr, llvm::Value *, // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) - std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, std::uint32_t) + std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, std::uint32_t, + double slope) { const auto u_idx = uname_to_index(var.name()); auto *u_zero = taylor_fetch_diff(arr, u_idx, 0, n_uvars); if (order == 0u) { - return llvm_relup(s, u_zero); + return llvm_relup(s, u_zero, slope); } else { return llvm_constantfp(s, u_zero->getType(), 0.); } @@ -377,7 +475,7 @@ llvm::Value *taylor_diff_relup_impl(llvm_state &s, llvm::Type *, const relup_imp template , int> = 0> llvm::Value *taylor_diff_relup_impl(llvm_state &, llvm::Type *, const relup_impl &, const std::vector &, const U &, const std::vector &, llvm::Value *, std::uint32_t, - std::uint32_t, std::uint32_t, std::uint32_t) + std::uint32_t, std::uint32_t, std::uint32_t, double) { throw std::invalid_argument( "An invalid argument type was encountered while trying to build the Taylor derivative of a relup"); @@ -385,12 +483,14 @@ llvm::Value *taylor_diff_relup_impl(llvm_state &, llvm::Type *, const relup_impl // LCOV_EXCL_STOP -llvm::Value *taylor_diff_relup(llvm_state &s, llvm::Type *fp_t, const relup_impl &f, - const std::vector &deps, const std::vector &arr, - llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, - std::uint32_t batch_size) +} // namespace + +llvm::Value *relup_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector &deps, + const std::vector &arr, llvm::Value *par_ptr, llvm::Value *, + std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, + std::uint32_t batch_size, bool) const { - assert(f.args().size() == 1u); + assert(args().size() == 1u); // LCOV_EXCL_START if (!deps.empty()) { @@ -403,19 +503,10 @@ llvm::Value *taylor_diff_relup(llvm_state &s, llvm::Type *fp_t, const relup_impl return std::visit( [&](const auto &v) { - return taylor_diff_relup_impl(s, fp_t, f, deps, v, arr, par_ptr, n_uvars, order, idx, batch_size); + return taylor_diff_relup_impl(s, fp_t, *this, deps, v, arr, par_ptr, n_uvars, order, idx, batch_size, + m_slope); }, - f.args()[0].value()); -} - -} // namespace - -llvm::Value *relup_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector &deps, - const std::vector &arr, llvm::Value *par_ptr, llvm::Value *, - std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx, - std::uint32_t batch_size, bool) const -{ - return taylor_diff_relup(s, fp_t, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size); + args()[0].value()); } namespace @@ -423,25 +514,27 @@ namespace // Derivative of relup(number). template , int> = 0> -llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &s, llvm::Type *fp_t, const relup_impl &, const U &num, - std::uint32_t n_uvars, std::uint32_t batch_size) +llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &s, llvm::Type *fp_t, const relup_impl &r, const U &num, + // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) + std::uint32_t n_uvars, std::uint32_t batch_size, double slope) { return taylor_c_diff_func_numpar( - s, fp_t, n_uvars, batch_size, "relup", 0, - [&s](const auto &args) { + s, fp_t, n_uvars, batch_size, r.get_name(), 0, + [&](const auto &args) { // LCOV_EXCL_START assert(args.size() == 1u); assert(args[0] != nullptr); // LCOV_EXCL_STOP - return llvm_relup(s, args[0]); + return llvm_relup(s, args[0], slope); }, num); } // Derivative of relup(variable). -llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &s, llvm::Type *fp_t, const relup_impl &, const variable &var, - std::uint32_t n_uvars, std::uint32_t batch_size) +llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &s, llvm::Type *fp_t, const relup_impl &r, const variable &var, + // NOLINTNEXTLINE(bugprone-easily-swappable-parameters) + std::uint32_t n_uvars, std::uint32_t batch_size, double slope) { auto &module = s.module(); auto &builder = s.builder(); @@ -451,7 +544,7 @@ llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &s, llvm::Type *fp_t, c auto *val_t = make_vector_type(fp_t, batch_size); // Fetch the function name and arguments. - const auto na_pair = taylor_c_diff_func_name_args(context, fp_t, "relup", n_uvars, batch_size, {var}); + const auto na_pair = taylor_c_diff_func_name_args(context, fp_t, r.get_name(), n_uvars, batch_size, {var}); const auto &fname = na_pair.first; const auto &fargs = na_pair.second; @@ -489,7 +582,8 @@ llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &s, llvm::Type *fp_t, c s, builder.CreateICmpEQ(ord, builder.getInt32(0)), [&]() { // For order 0, invoke the function on the order 0 of var_idx. - auto *ret = llvm_relup(s, taylor_c_load_diff(s, val_t, diff_ptr, n_uvars, builder.getInt32(0), var_idx)); + auto *ret + = llvm_relup(s, taylor_c_load_diff(s, val_t, diff_ptr, n_uvars, builder.getInt32(0), var_idx), slope); // NOLINTNEXTLINE(readability-suspicious-call-argument) builder.CreateStore(ret, retval); }, @@ -515,7 +609,7 @@ llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &s, llvm::Type *fp_t, c // All the other cases. template , int> = 0> llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &, llvm::Type *, const relup_impl &, const U &, std::uint32_t, - std::uint32_t) + std::uint32_t, double) { throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative " "of a relup in compact mode"); @@ -523,45 +617,64 @@ llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &, llvm::Type *, const // LCOV_EXCL_STOP -llvm::Function *taylor_c_diff_func_relup(llvm_state &s, llvm::Type *fp_t, const relup_impl &fn, std::uint32_t n_uvars, - std::uint32_t batch_size) -{ - assert(fn.args().size() == 1u); - - return std::visit([&](const auto &v) { return taylor_c_diff_func_relup_impl(s, fp_t, fn, v, n_uvars, batch_size); }, - fn.args()[0].value()); -} - } // namespace llvm::Function *relup_impl::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t n_uvars, std::uint32_t batch_size, bool) const { - return taylor_c_diff_func_relup(s, fp_t, *this, n_uvars, batch_size); + assert(args().size() == 1u); + + return std::visit( + [&](const auto &v) { return taylor_c_diff_func_relup_impl(s, fp_t, *this, v, n_uvars, batch_size, m_slope); }, + args()[0].value()); } } // namespace detail -expression relu(expression x) +expression relu(expression x, double slope) { + detail::relu_slope_check(slope); + // Fold relu(number) to its value. if (const auto *num_ptr = std::get_if(&x.value())) { - return std::visit([](const auto &x) { return expression{x > 0 ? x : 0.}; }, num_ptr->value()); + return std::visit([slope](const auto &x) { return expression{x > 0 ? x : slope * x}; }, num_ptr->value()); } else { - return expression{func{detail::relu_impl{std::move(x)}}}; + return expression{func{detail::relu_impl{std::move(x), slope}}}; } } -expression relup(expression x) +expression relup(expression x, double slope) { + detail::relu_slope_check(slope); + // Fold relup(number) to its value. if (const auto *num_ptr = std::get_if(&x.value())) { - return std::visit([](const auto &x) { return expression{x > 0 ? 1. : 0.}; }, num_ptr->value()); + return std::visit([slope](const auto &x) { return expression{x > 0 ? 1. : slope}; }, num_ptr->value()); } else { - return expression{func{detail::relup_impl{std::move(x)}}}; + return expression{func{detail::relup_impl{std::move(x), slope}}}; } } +leaky_relu::leaky_relu(double slope) : m_slope(slope) +{ + detail::relu_slope_check(slope); +} + +expression leaky_relu::operator()(expression x) const +{ + return relu(std::move(x), m_slope); +} + +leaky_relup::leaky_relup(double slope) : m_slope(slope) +{ + detail::relu_slope_check(slope); +} + +expression leaky_relup::operator()(expression x) const +{ + return relup(std::move(x), m_slope); +} + HEYOKA_END_NAMESPACE HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::relu_impl) diff --git a/src/taylor_01.cpp b/src/taylor_01.cpp index 31e577e89..6eb23ca0b 100644 --- a/src/taylor_01.cpp +++ b/src/taylor_01.cpp @@ -110,6 +110,7 @@ taylor_c_diff_func_name_args(llvm::LLVMContext &context, llvm::Type *fp_t, const const std::vector> &args, std::uint32_t n_hidden_deps) { + assert(std::find(name.begin(), name.end(), '.') == name.end()); assert(fp_t != nullptr); assert(n_uvars > 0u); diff --git a/test/relu.cpp b/test/relu.cpp index 083b6f30d..5e57682ec 100644 --- a/test/relu.cpp +++ b/test/relu.cpp @@ -11,12 +11,16 @@ #include #include #include +#include #include #include +#include #include #include +#include + #include #if defined(HEYOKA_HAVE_REAL128) @@ -32,6 +36,7 @@ #endif #include +#include #include #include #include @@ -64,15 +69,15 @@ constexpr bool skip_batch_ld = ; template -T cpp_relu(T x) +T cpp_relu(T x, T slope = 0) { - return x > 0 ? x : T(0); + return x > 0 ? x : slope * x; } template -T cpp_relup(T x) +T cpp_relup(T x, T slope = 0) { - return x > 0 ? T(1) : T(0); + return x > 0 ? T(1) : slope; } TEST_CASE("def ctor") @@ -82,6 +87,8 @@ TEST_CASE("def ctor") REQUIRE(k.args().size() == 1u); REQUIRE(k.args()[0] == 0_dbl); + REQUIRE(k.get_name() == "relu"); + REQUIRE(k.get_slope() == 0.); } { @@ -89,9 +96,90 @@ TEST_CASE("def ctor") REQUIRE(k.args().size() == 1u); REQUIRE(k.args()[0] == 0_dbl); + REQUIRE(k.get_name() == "relup"); + REQUIRE(k.get_slope() == 0.); + } +} + +TEST_CASE("stream op") +{ + { + auto ex = relu("x"_var); + std::ostringstream oss; + oss << ex; + REQUIRE(oss.str() == "relu(x)"); + } + + { + auto ex = relup("x"_var); + std::ostringstream oss; + oss << ex; + REQUIRE(oss.str() == "relup(x)"); + } + + { + auto ex = relu("x"_var, 1.); + std::ostringstream oss; + oss << ex; + REQUIRE(oss.str() == "leaky_relu(x, 1)"); + } + + { + auto ex = relup("x"_var, 1.); + std::ostringstream oss; + oss << ex; + REQUIRE(oss.str() == "leaky_relup(x, 1)"); + } +} + +TEST_CASE("leaky wrappers") +{ + auto [x, y] = make_vars("x", "y"); + + REQUIRE(leaky_relu(.01)(x) == relu(x, 0.01)); + REQUIRE(leaky_relup(.01)(y) == relup(y, 0.01)); +} + +TEST_CASE("names") +{ + { + auto ex = relu("x"_var); + REQUIRE(std::get(ex.value()).get_name() == "relu"); + } + + { + auto ex = relup("x"_var); + REQUIRE(std::get(ex.value()).get_name() == "relup"); + } + + { + auto ex = relu("x"_var, 1.); + REQUIRE(std::get(ex.value()).get_name() != "relu"); + REQUIRE(boost::starts_with(std::get(ex.value()).get_name(), "relu_0x")); + REQUIRE(std::get(ex.value()).extract()->get_slope() == 1); + } + + { + auto ex = relup("x"_var, 1.); + REQUIRE(std::get(ex.value()).get_name() != "relup"); + REQUIRE(boost::starts_with(std::get(ex.value()).get_name(), "relup_0x")); + REQUIRE(std::get(ex.value()).extract()->get_slope() == 1); } } +TEST_CASE("invalid slopes") +{ + using Catch::Matchers::Message; + + REQUIRE_THROWS_MATCHES(relu("x"_var, -1.), std::invalid_argument, + Message("The slope parameter for a leaky ReLU must be finite and non-negative, " + "but the value -1 was provided instead")); + REQUIRE_THROWS_MATCHES(relup("x"_var, std::numeric_limits::quiet_NaN()), std::invalid_argument, + Message(fmt::format("The slope parameter for a leaky ReLU must be finite and non-negative, " + "but the value {} was provided instead", + std::numeric_limits::quiet_NaN()))); +} + TEST_CASE("normalise") { { @@ -100,11 +188,23 @@ TEST_CASE("normalise") REQUIRE(ex == .1_dbl); } + { + auto ex = relu(fix(-.1_dbl), 0.01); + ex = normalise(unfix(ex)); + REQUIRE(ex == expression(-.1 * 0.01)); + } + { auto ex = relup(fix(-.1_dbl)); ex = normalise(unfix(ex)); REQUIRE(ex == 0_dbl); } + + { + auto ex = relup(fix(-.1_dbl), 0.01); + ex = normalise(unfix(ex)); + REQUIRE(ex == 0.01_dbl); + } } TEST_CASE("diff") @@ -112,15 +212,25 @@ TEST_CASE("diff") auto [x, y] = make_vars("x", "y"); REQUIRE(diff(relu(x), x) == relup(x)); + REQUIRE(diff(relu(x, 0.01), x) == relup(x, 0.01)); REQUIRE(diff(relup(x), x) == 0_dbl); + REQUIRE(diff(relup(x, 0.01), x) == 0_dbl); REQUIRE(diff(relu(x * y), x) == y * relup(x * y)); REQUIRE(diff(relu(x * y), par[0]) == 0_dbl); REQUIRE(diff(relu(x * par[0]), par[0]) == x * relup(x * par[0])); + REQUIRE(diff(relu(x * y, 0.02), x) == y * relup(x * y, 0.02)); + REQUIRE(diff(relu(x * y, 0.03), par[0]) == 0_dbl); + REQUIRE(diff(relu(x * par[0], 0.04), par[0]) == x * relup(x * par[0], 0.04)); + REQUIRE(diff(relup(x * y), x) == 0_dbl); REQUIRE(diff(relup(x * y), par[0]) == 0_dbl); REQUIRE(diff(relup(x * par[0]), par[0]) == 0_dbl); + + REQUIRE(diff(relup(x * y, 0.01), x) == 0_dbl); + REQUIRE(diff(relup(x * y, 0.02), par[0]) == 0_dbl); + REQUIRE(diff(relup(x * par[0], 0.03), par[0]) == 0_dbl); } TEST_CASE("constant fold") @@ -130,6 +240,12 @@ TEST_CASE("constant fold") REQUIRE(relup(1.1_dbl) == 1_dbl); REQUIRE(relup(-1.1_dbl) == 0_dbl); + + REQUIRE(relu(1.1_dbl, 0.01) == 1.1_dbl); + REQUIRE(relu(-1.1_dbl, 0.01) == expression{-1.1 * 0.01}); + + REQUIRE(relup(1.1_dbl, 0.01) == 1_dbl); + REQUIRE(relup(-1.1_dbl, 0.01) == 0.01_dbl); } TEST_CASE("s11n") @@ -158,6 +274,30 @@ TEST_CASE("s11n") REQUIRE(ex == relu(x + y)); } + { + std::stringstream ss; + + auto [x, y] = make_vars("x", "y"); + + auto ex = relu(x + y, 0.03); + + { + boost::archive::binary_oarchive oa(ss); + + oa << ex; + } + + ex = 0_dbl; + + { + boost::archive::binary_iarchive ia(ss); + + ia >> ex; + } + + REQUIRE(ex == relu(x + y, 0.03)); + } + { std::stringstream ss; @@ -181,6 +321,30 @@ TEST_CASE("s11n") REQUIRE(ex == relup(x + y)); } + + { + std::stringstream ss; + + auto [x, y] = make_vars("x", "y"); + + auto ex = relup(x + y, 0.01); + + { + boost::archive::binary_oarchive oa(ss); + + oa << ex; + } + + ex = 0_dbl; + + { + boost::archive::binary_iarchive ia(ss); + + ia >> ex; + } + + REQUIRE(ex == relup(x + y, 0.01)); + } } TEST_CASE("cfunc") @@ -250,6 +414,74 @@ TEST_CASE("cfunc") } } +TEST_CASE("cfunc leaky") +{ + auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) { + using fp_t = decltype(fp_x); + + auto [x] = make_vars("x"); + + std::uniform_real_distribution x_dist(-10, 10); + + std::vector outs, ins, pars; + + for (auto batch_size : {1u, 2u, 4u, 5u}) { + if (batch_size != 1u && std::is_same_v && skip_batch_ld) { + continue; + } + + outs.resize(batch_size * 4u); + ins.resize(batch_size); + pars.resize(batch_size * 2u); + + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", {relu(x, .01), relu(par[0], .02), relup(x, .03), relup(par[1], .04)}, + kw::batch_size = batch_size, kw::high_accuracy = high_accuracy, + kw::compact_mode = compact_mode); + + if (opt_level == 0u && compact_mode) { + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.relu_0x")); + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.relup_0x")); + } + + s.compile(); + + auto *cf_ptr + = reinterpret_cast(s.jit_lookup("cfunc")); + + for (auto niter = 0; niter < 100; ++niter) { + for (auto i = 0u; i < batch_size; ++i) { + // Generate the xs. + ins[i] = x_dist(rng); + + // Generate the pars. + pars[i] = x_dist(rng); + pars[i + batch_size] = x_dist(rng); + } + + cf_ptr(outs.data(), ins.data(), pars.data(), nullptr); + + for (auto i = 0u; i < batch_size; ++i) { + REQUIRE(outs[i] == cpp_relu(ins[i], fp_t(0.01))); + REQUIRE(outs[i + batch_size] == cpp_relu(pars[i], fp_t(0.02))); + REQUIRE(outs[i + 2u * batch_size] == cpp_relup(ins[i], fp_t(0.03))); + REQUIRE(outs[i + 3u * batch_size] == cpp_relup(pars[i + batch_size], fp_t(0.04))); + } + } + } + }; + + for (auto cm : {false, true}) { + for (auto f : {false, true}) { + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 0, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 1, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 2, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 3, f, cm); }); + } + } +} + #if defined(HEYOKA_HAVE_REAL) TEST_CASE("cfunc mp") @@ -300,4 +532,52 @@ TEST_CASE("cfunc mp") } } +TEST_CASE("cfunc mp leaky") +{ + using fp_t = mppp::real; + + const auto prec = 237; + + auto [x] = make_vars("x"); + + std::uniform_real_distribution x_dist(-10, 10); + + std::vector outs, ins, pars; + + outs.resize(4u, mppp::real{0, prec}); + ins.resize(1u); + pars.resize(2u); + + for (auto compact_mode : {false, true}) { + for (auto opt_level : {0u, 1u, 2u, 3u}) { + llvm_state s{kw::opt_level = opt_level}; + + add_cfunc(s, "cfunc", {relu(x, .01), relu(par[0], .02), relup(x, .03), relup(par[1], .04)}, + kw::compact_mode = compact_mode, kw::prec = prec); + + if (opt_level == 0u && compact_mode) { + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.relu_0x")); + REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.relup_0x")); + } + + s.compile(); + + auto *cf_ptr + = reinterpret_cast(s.jit_lookup("cfunc")); + + // Generate the x and pars. + ins[0] = mppp::real{x_dist(rng), prec}; + pars[0] = mppp::real{x_dist(rng), prec}; + pars[1] = mppp::real{x_dist(rng), prec}; + + cf_ptr(outs.data(), ins.data(), pars.data(), nullptr); + + REQUIRE(outs[0] == cpp_relu(ins[0], mppp::real(0.01))); + REQUIRE(outs[1] == cpp_relu(pars[0], mppp::real(0.02))); + REQUIRE(outs[2] == cpp_relup(ins[0], mppp::real(0.03))); + REQUIRE(outs[3] == cpp_relup(pars[1], mppp::real(0.04))); + } + } +} + #endif diff --git a/test/taylor_relu.cpp b/test/taylor_relu.cpp index c53090d4c..9a32085fc 100644 --- a/test/taylor_relu.cpp +++ b/test/taylor_relu.cpp @@ -42,15 +42,15 @@ const auto fp_types = std::tuple{}; template -T cpp_relu(T x) +T cpp_relu(T x, T slope = 0) { - return x > 0 ? x : T(0); + return x > 0 ? x : slope * x; } template -T cpp_relup(T x) +T cpp_relup(T x, T slope = 0) { - return x > 0 ? T(1) : T(0); + return x > 0 ? T(1) : slope; } TEST_CASE("taylor relu relup") @@ -162,3 +162,117 @@ TEST_CASE("taylor relu relup") } } } + +TEST_CASE("taylor relu relup leaky") +{ + auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) { + using fp_t = decltype(fp_x); + + auto x = "x"_var, y = "y"_var; + + // Number tests. + { + llvm_state s{kw::opt_level = opt_level}; + + taylor_add_jet(s, "jet2", {relu(par[0], 0.01) + relup(par[1], 0.02), x + y}, 3, 2, high_accuracy, + compact_mode); + taylor_add_jet(s, "jet", {relu(par[0], 0.01) + relup(par[1], 0.02), x + y}, 3, 2, high_accuracy, + compact_mode); + + s.compile(); + + if (opt_level == 0u && compact_mode) { + REQUIRE(boost::contains(s.get_ir(), "@heyoka.taylor_c_diff.relu_0x")); + REQUIRE(boost::contains(s.get_ir(), "@heyoka.taylor_c_diff.relup_0x")); + REQUIRE(boost::contains(s.get_ir(), ".par")); + } + + auto jptr = reinterpret_cast(s.jit_lookup("jet")); + + std::vector jet{fp_t{2}, fp_t{-1}, fp_t{-3}, fp_t{5}}, pars{fp_t{-1}, fp_t{2}, fp_t{4}, fp_t{-3}}; + jet.resize(16); + + jptr(jet.data(), pars.data(), nullptr); + + REQUIRE(jet[0] == 2); + REQUIRE(jet[1] == -1); + + REQUIRE(jet[2] == -3); + REQUIRE(jet[3] == 5); + + REQUIRE(jet[4] == approximately(cpp_relu(pars[0], fp_t(0.01)) + cpp_relup(pars[2], fp_t(0.02)))); + REQUIRE(jet[5] == approximately(cpp_relu(pars[1], fp_t(0.01)) + cpp_relup(pars[3], fp_t(0.02)))); + + REQUIRE(jet[6] == approximately(jet[0] + jet[2])); + REQUIRE(jet[7] == approximately(jet[1] + jet[3])); + + REQUIRE(jet[8] == fp_t(0)); + REQUIRE(jet[9] == fp_t(0)); + + REQUIRE(jet[10] == approximately((jet[4] + jet[6]) / 2)); + REQUIRE(jet[11] == approximately((jet[5] + jet[7]) / 2)); + + REQUIRE(jet[12] == fp_t(0)); + REQUIRE(jet[13] == fp_t(0)); + + REQUIRE(jet[14] == approximately((jet[10] + jet[8]) / 3)); + REQUIRE(jet[15] == approximately((jet[11] + jet[9]) / 3)); + } + + // Variable tests. + { + llvm_state s{kw::opt_level = opt_level}; + + taylor_add_jet(s, "jet2", {relu(x, 0.01) + relup(y, 0.02), x + y}, 3, 2, high_accuracy, compact_mode); + taylor_add_jet(s, "jet", {relu(x, 0.01) + relup(y, 0.02), x + y}, 3, 2, high_accuracy, compact_mode); + + s.compile(); + + if (opt_level == 0u && compact_mode) { + REQUIRE(boost::contains(s.get_ir(), "@heyoka.taylor_c_diff.relu_0x")); + REQUIRE(boost::contains(s.get_ir(), "@heyoka.taylor_c_diff.relup_0x")); + REQUIRE(boost::contains(s.get_ir(), ".var")); + } + + auto jptr = reinterpret_cast(s.jit_lookup("jet")); + + std::vector jet{fp_t{2}, fp_t{-1}, fp_t{-3}, fp_t{5}}; + jet.resize(16); + + jptr(jet.data(), nullptr, nullptr); + + REQUIRE(jet[0] == 2); + REQUIRE(jet[1] == -1); + + REQUIRE(jet[2] == -3); + REQUIRE(jet[3] == 5); + + REQUIRE(jet[4] == approximately(cpp_relu(jet[0], fp_t(0.01)) + cpp_relup(jet[2], fp_t(0.02)))); + REQUIRE(jet[5] == approximately(cpp_relu(jet[1], fp_t(0.01)) + cpp_relup(jet[3], fp_t(0.02)))); + + REQUIRE(jet[6] == approximately(jet[0] + jet[2])); + REQUIRE(jet[7] == approximately(jet[1] + jet[3])); + + REQUIRE(jet[8] == approximately((cpp_relup(jet[0], fp_t(0.01)) * jet[4]) / 2)); + REQUIRE(jet[9] == approximately((cpp_relup(jet[1], fp_t(0.01)) * jet[5]) / 2)); + + REQUIRE(jet[10] == approximately((jet[4] + jet[6]) / 2)); + REQUIRE(jet[11] == approximately((jet[5] + jet[7]) / 2)); + + REQUIRE(jet[12] == approximately((cpp_relup(jet[0], fp_t(0.01)) * 2 * jet[8]) / 6)); + REQUIRE(jet[13] == approximately((cpp_relup(jet[1], fp_t(0.01)) * 2 * jet[9]) / 6)); + + REQUIRE(jet[14] == approximately((jet[10] + jet[8]) / 3)); + REQUIRE(jet[15] == approximately((jet[11] + jet[9]) / 3)); + } + }; + + for (auto cm : {false, true}) { + for (auto f : {false, true}) { + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 0, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 1, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 2, f, cm); }); + tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 3, f, cm); }); + } + } +} diff --git a/test/taylor_relu_mp.cpp b/test/taylor_relu_mp.cpp index c174fb056..e97429cc4 100644 --- a/test/taylor_relu_mp.cpp +++ b/test/taylor_relu_mp.cpp @@ -24,15 +24,15 @@ using namespace heyoka; using namespace heyoka_test; template -T cpp_relu(T x) +T cpp_relu(T x, T slope = T(0)) { - return x > 0 ? x : T(0, x.get_prec()); + return x > 0 ? x : x * T(slope, x.get_prec()); } template -T cpp_relup(T x) +T cpp_relup(T x, T slope = T(0)) { - return x > 0 ? T(1, x.get_prec()) : T(0, x.get_prec()); + return x > 0 ? T(1, x.get_prec()) : T(slope, x.get_prec()); } TEST_CASE("relu") @@ -77,3 +77,50 @@ TEST_CASE("relu") } } } + +TEST_CASE("relu leaky") +{ + using fp_t = mppp::real; + + auto [x, y] = make_vars("x", "y"); + + for (auto prec : {30, 123}) { + for (auto cm : {false, true}) { + for (auto ha : {false, true}) { + for (auto opt_level : {0u, 3u}) { + // Test with num/param/var. + { + llvm_state s{kw::opt_level = opt_level}; + + taylor_add_jet(s, "jet", {relu(x, 0.01) + relup(y, 0.02), x + y}, 2, 1, ha, cm, {}, false, + prec); + + s.compile(); + + if (opt_level == 0u && cm) { + REQUIRE(boost::contains(s.get_ir(), "heyoka.taylor_c_diff.relu_0x")); + REQUIRE(boost::contains(s.get_ir(), "heyoka.taylor_c_diff.relup_0x")); + REQUIRE(boost::contains(s.get_ir(), ".var")); + } + + auto jptr = reinterpret_cast(s.jit_lookup("jet")); + + std::vector jet{fp_t{-2, prec}, fp_t{-1, prec}}; + jet.resize(6, fp_t(0, prec)); + + jptr(jet.data(), nullptr, nullptr); + + REQUIRE(jet[0] == -2); + REQUIRE(jet[1] == -1); + REQUIRE( + jet[2] + == approximately(cpp_relu(jet[0], mppp::real(0.01)) + cpp_relup(jet[1], mppp::real(0.02)))); + REQUIRE(jet[3] == jet[0] + jet[1]); + REQUIRE(jet[4] == approximately(cpp_relup(jet[0], mppp::real(0.01)) * jet[2] / fp_t(2, prec))); + REQUIRE(jet[5] == approximately((jet[2] + jet[3]) / fp_t(2, prec))); + } + } + } + } + } +}