diff --git a/CMakeLists.txt b/CMakeLists.txt index 56e649501..b835ec346 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -270,6 +270,7 @@ set(HEYOKA_SRC_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/math/log.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/pow.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/sigmoid.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/math/relu.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/sin.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/sqrt.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/math/tan.cpp" diff --git a/include/heyoka/math.hpp b/include/heyoka/math.hpp index f4ca63cb8..f4f43125f 100644 --- a/include/heyoka/math.hpp +++ b/include/heyoka/math.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include diff --git a/include/heyoka/math/relu.hpp b/include/heyoka/math/relu.hpp new file mode 100644 index 000000000..ac4c06c0b --- /dev/null +++ b/include/heyoka/math/relu.hpp @@ -0,0 +1,99 @@ +// Copyright 2020, 2021, 2022, 2023 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com) +// +// This file is part of the heyoka library. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+
+#ifndef HEYOKA_MATH_RELU_HPP
+#define HEYOKA_MATH_RELU_HPP
+
+#include // NOTE(review): the header name of this and of every bare #include below is missing in this copy of the patch - restore from the upstream commit.
+// #include
+#include
+
+#include
+// #include
+#include
+#include
+#include
+#include
+#include
+
+HEYOKA_BEGIN_NAMESPACE
+
+namespace detail
+{
+
+class HEYOKA_DLL_PUBLIC relu_impl : public func_base // Function node for relu(); registered under the name "relu" with a single argument.
+{
+    friend class boost::serialization::access;
+    template // NOTE(review): the template parameter list (declaring Archive) is missing in this copy of the patch.
+    void serialize(Archive &ar, unsigned)
+    {
+        ar &boost::serialization::base_object(*this); // Only the base sub-object is archived; no members are declared here.
+    }
+
+public:
+    relu_impl(); // Delegates to relu_impl(0_dbl).
+    explicit relu_impl(expression);
+
+    [[nodiscard]] expression normalise() const;
+
+    [[nodiscard]] std::vector gradient() const; // Returns the single entry relup(arg).
+
+    [[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector &, llvm::Value *,
+                                         llvm::Value *, llvm::Value *, std::uint32_t, bool) const;
+
+    [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const;
+
+    llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector &,
+                             const std::vector &, llvm::Value *, llvm::Value *, std::uint32_t,
+                             std::uint32_t, std::uint32_t, std::uint32_t, bool) const;
+
+    llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, bool) const;
+};
+
+class HEYOKA_DLL_PUBLIC relup_impl : public func_base // Function node for relup(), which relu_impl::gradient() returns; registered under the name "relup".
+{
+    friend class boost::serialization::access;
+    template // NOTE(review): the template parameter list (declaring Archive) is missing in this copy of the patch.
+    void serialize(Archive &ar, unsigned)
+    {
+        ar &boost::serialization::base_object(*this); // Only the base sub-object is archived; no members are declared here.
+    }
+
+public:
+    relup_impl(); // Delegates to relup_impl(0_dbl).
+    explicit relup_impl(expression);
+
+    [[nodiscard]] expression normalise() const;
+
+    [[nodiscard]] std::vector gradient() const; // Returns the single entry 0_dbl.
+
+    [[nodiscard]] llvm::Value *llvm_eval(llvm_state &, llvm::Type *, const std::vector &, llvm::Value *,
+                                         llvm::Value *, llvm::Value *, std::uint32_t, bool) const;
+
+    [[nodiscard]] llvm::Function *llvm_c_eval_func(llvm_state &, llvm::Type *, std::uint32_t, bool) const;
+
+    llvm::Value *taylor_diff(llvm_state &, llvm::Type *, const std::vector &,
+                             const std::vector &, llvm::Value *, llvm::Value *, std::uint32_t,
+                             std::uint32_t, std::uint32_t, std::uint32_t, bool) const;
+
+    llvm::Function *taylor_c_diff_func(llvm_state &, llvm::Type *, std::uint32_t, std::uint32_t, bool) const;
+};
+
+} // namespace detail
+
+HEYOKA_DLL_PUBLIC expression relu(expression); // Folds numeric arguments to a constant, otherwise wraps the argument in a relu_impl node.
+
+HEYOKA_DLL_PUBLIC expression relup(expression); // Folds numeric arguments to 1. or 0., otherwise wraps the argument in a relup_impl node.
+
+HEYOKA_END_NAMESPACE
+
+HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::relu_impl)
+
+HEYOKA_S11N_FUNC_EXPORT_KEY(heyoka::detail::relup_impl)
+
+#endif
diff --git a/src/math/relu.cpp b/src/math/relu.cpp
new file mode 100644
index 000000000..1b73ab7cb
--- /dev/null
+++ b/src/math/relu.cpp
@@ -0,0 +1,568 @@
+// Copyright 2020, 2021, 2022, 2023 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
+//
+// This file is part of the heyoka library.
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#include // NOTE(review): the header name of this and of every bare #include below is missing in this copy of the patch - restore from the upstream commit.
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+HEYOKA_BEGIN_NAMESPACE
+
+namespace detail
+{
+
+relu_impl::relu_impl() : relu_impl(0_dbl) {}
+
+relu_impl::relu_impl(expression ex) : func_base("relu", std::vector{std::move(ex)}) {} // One-element argument list.
+
+[[nodiscard]] expression relu_impl::normalise() const
+{
+    assert(args().size() == 1u);
+    return relu(args()[0]); // Goes through the public relu(), so a numeric argument is folded to a constant.
+}
+
+[[nodiscard]] std::vector relu_impl::gradient() const // NOTE(review): the vector element type is missing in this copy of the patch.
+{
+    assert(args().size() == 1u);
+    return {relup(args()[0])};
+}
+
+namespace
+{
+
+// LLVM implementation of relu.
+llvm::Value *llvm_relu(llvm_state &s, llvm::Value *x)
+{
+    auto *zero_c = llvm_constantfp(s, x->getType(), 0.); // Zero constant of the same type as x.
+    return s.builder().CreateSelect(llvm_fcmp_ogt(s, x, zero_c), x, zero_c); // x where x > 0 holds, zero elsewhere.
+}
+
+} // namespace
+
+[[nodiscard]] llvm::Value *relu_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t,
+                                                const std::vector<llvm::Value *> &eval_arr, llvm::Value *par_ptr,
+                                                llvm::Value *, llvm::Value *stride, std::uint32_t batch_size,
+                                                bool high_accuracy) const
+{
+    return llvm_eval_helper(
+        [&s](const std::vector<llvm::Value *> &args, bool) {
+            assert(args.size() == 1u);
+            return llvm_relu(s, args[0]);
+        },
+        *this, s, fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy);
+}
+
+llvm::Function *relu_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t batch_size,
+                                            bool high_accuracy) const
+{
+    return llvm_c_eval_func_helper(
+        "relu",
+        [&s](const std::vector<llvm::Value *> &args, bool) {
+            assert(args.size() == 1u);
+            return llvm_relu(s, args[0]);
+        },
+        *this, s, fp_t, batch_size, high_accuracy);
+}
+
+namespace
+{
+
+// Derivative of relu(number): order 0 is relu of the constant, every higher order is zero.
+template <typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
+llvm::Value *taylor_diff_relu_impl(llvm_state &s, llvm::Type *fp_t, const relu_impl &,
+                                   const std::vector<std::uint32_t> &, const U &num, const std::vector<llvm::Value *> &,
+                                   llvm::Value *par_ptr, std::uint32_t, std::uint32_t order, std::uint32_t,
+                                   std::uint32_t batch_size)
+{
+    if (order == 0u) {
+        return llvm_relu(s, taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size));
+    } else {
+        return vector_splat(s.builder(), llvm_constantfp(s, fp_t, 0.), batch_size);
+    }
+}
+
+// Derivative of relu(variable).
+llvm::Value *taylor_diff_relu_impl(llvm_state &s, llvm::Type *, const relu_impl &, const std::vector<std::uint32_t> &,
+                                   const variable &var, const std::vector<llvm::Value *> &arr, llvm::Value *,
+                                   std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, std::uint32_t)
+{
+    const auto u_idx = uname_to_index(var.name());
+
+    auto *u_zero = taylor_fetch_diff(arr, u_idx, 0, n_uvars);
+    auto *u_order = taylor_fetch_diff(arr, u_idx, order, n_uvars);
+
+    auto *zero_c = llvm_constantfp(s, u_zero->getType(), 0.);
+    return s.builder().CreateSelect(llvm_fcmp_ogt(s, u_zero, zero_c), u_order, zero_c); // The sign test is on the order-0 value.
+}
+
+// LCOV_EXCL_START
+
+// All the other cases.
+template <typename U, std::enable_if_t<!is_num_param_v<U>, int> = 0>
+llvm::Value *taylor_diff_relu_impl(llvm_state &, llvm::Type *, const relu_impl &, const std::vector<std::uint32_t> &,
+                                   const U &, const std::vector<llvm::Value *> &, llvm::Value *, std::uint32_t,
+                                   std::uint32_t, std::uint32_t, std::uint32_t)
+{
+    throw std::invalid_argument(
+        "An invalid argument type was encountered while trying to build the Taylor derivative of a relu");
+}
+
+// LCOV_EXCL_STOP
+
+llvm::Value *taylor_diff_relu(llvm_state &s, llvm::Type *fp_t, const relu_impl &f,
+                              const std::vector<std::uint32_t> &deps, const std::vector<llvm::Value *> &arr,
+                              llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
+                              std::uint32_t batch_size)
+{
+    assert(f.args().size() == 1u);
+
+    // LCOV_EXCL_START
+    if (!deps.empty()) {
+        throw std::invalid_argument(
+            fmt::format("An empty hidden dependency vector is expected in order to compute the Taylor "
+                        "derivative of the relu, but a vector of size {} was passed instead",
+                        deps.size()));
+    }
+    // LCOV_EXCL_STOP
+
+    return std::visit(
+        [&](const auto &v) {
+            return taylor_diff_relu_impl(s, fp_t, f, deps, v, arr, par_ptr, n_uvars, order, idx, batch_size);
+        },
+        f.args()[0].value());
+}
+
+} // namespace
+
+llvm::Value *relu_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector<std::uint32_t> &deps,
+                                    const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
+                                    std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
+                                    std::uint32_t batch_size, bool) const
+{
+    return taylor_diff_relu(s, fp_t, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
+}
+
+namespace
+{
+
+// Derivative of relu(number).
+template <typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
+llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &s, llvm::Type *fp_t, const relu_impl &, const U &num,
+                                             std::uint32_t n_uvars, std::uint32_t batch_size)
+{
+    return taylor_c_diff_func_numpar(
+        s, fp_t, n_uvars, batch_size, "relu", 0,
+        [&s](const auto &args) {
+            // LCOV_EXCL_START
+            assert(args.size() == 1u);
+            assert(args[0] != nullptr);
+            // LCOV_EXCL_STOP
+
+            return llvm_relu(s, args[0]);
+        },
+        num);
+}
+
+// Derivative of relu(variable).
+llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &s, llvm::Type *fp_t, const relu_impl &, const variable &var,
+                                             std::uint32_t n_uvars, std::uint32_t batch_size)
+{
+    auto &module = s.module();
+    auto &builder = s.builder();
+    auto &context = s.context();
+
+    // Fetch the vector floating-point type.
+    auto *val_t = make_vector_type(fp_t, batch_size);
+
+    // Fetch the function name and arguments.
+    const auto na_pair = taylor_c_diff_func_name_args(context, fp_t, "relu", n_uvars, batch_size, {var});
+    const auto &fname = na_pair.first;
+    const auto &fargs = na_pair.second;
+
+    // Try to see if we already created the function.
+    auto *f = module.getFunction(fname);
+
+    if (f != nullptr) {
+        // The function was created before, return it.
+        return f;
+    }
+
+    // The function was not created before, do it now.
+
+    // Fetch the current insertion block.
+    auto *orig_bb = builder.GetInsertBlock();
+
+    // The return type is val_t.
+    auto *ft = llvm::FunctionType::get(val_t, fargs, false);
+    // Create the function.
+    f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
+    assert(f != nullptr);
+
+    // Fetch the necessary function arguments.
+    auto *ord = f->args().begin();
+    auto *diff_ptr = f->args().begin() + 2; // NOTE(review): positions 0/2/5 presumably follow the layout built by taylor_c_diff_func_name_args - confirm.
+    auto *var_idx = f->args().begin() + 5;
+
+    // Create a new basic block to start insertion into.
+    builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
+
+    // Load the orders 0 and ord of var_idx.
+    auto *u_zero = taylor_c_load_diff(s, val_t, diff_ptr, n_uvars, builder.getInt32(0), var_idx);
+    auto *u_ord = taylor_c_load_diff(s, val_t, diff_ptr, n_uvars, ord, var_idx);
+
+    auto *zero_c = llvm_constantfp(s, u_zero->getType(), 0.);
+
+    builder.CreateRet(builder.CreateSelect(llvm_fcmp_ogt(s, u_zero, zero_c), u_ord, zero_c));
+
+    // Verify.
+    s.verify_function(f);
+
+    // Restore the original insertion block.
+    builder.SetInsertPoint(orig_bb);
+
+    return f;
+}
+
+// LCOV_EXCL_START
+
+// All the other cases.
+template <typename U, std::enable_if_t<!is_num_param_v<U>, int> = 0>
+llvm::Function *taylor_c_diff_func_relu_impl(llvm_state &, llvm::Type *, const relu_impl &, const U &, std::uint32_t,
+                                             std::uint32_t)
+{
+    throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative "
+                                "of a relu in compact mode");
+}
+
+// LCOV_EXCL_STOP
+
+llvm::Function *taylor_c_diff_func_relu(llvm_state &s, llvm::Type *fp_t, const relu_impl &fn, std::uint32_t n_uvars,
+                                        std::uint32_t batch_size)
+{
+    assert(fn.args().size() == 1u);
+
+    return std::visit([&](const auto &v) { return taylor_c_diff_func_relu_impl(s, fp_t, fn, v, n_uvars, batch_size); },
+                      fn.args()[0].value());
+}
+
+} // namespace
+
+llvm::Function *relu_impl::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t n_uvars,
+                                              std::uint32_t batch_size, bool) const
+{
+    return taylor_c_diff_func_relu(s, fp_t, *this, n_uvars, batch_size);
+}
+
+relup_impl::relup_impl() : relup_impl(0_dbl) {}
+
+relup_impl::relup_impl(expression ex) : func_base("relup", std::vector<expression>{std::move(ex)}) {}
+
+[[nodiscard]] expression relup_impl::normalise() const
+{
+    assert(args().size() == 1u);
+    return relup(args()[0]);
+}
+
+[[nodiscard]] std::vector<expression> relup_impl::gradient() const
+{
+    assert(args().size() == 1u);
+    return {0_dbl}; // relup is piecewise constant, so its gradient is reported as zero.
+}
+
+namespace
+{
+
+// LLVM implementation of relup: one where x > 0 holds, zero elsewhere.
+llvm::Value *llvm_relup(llvm_state &s, llvm::Value *x)
+{
+    auto *zero_c = llvm_constantfp(s, x->getType(), 0.);
+    auto *one_c = llvm_constantfp(s, x->getType(), 1.);
+    return s.builder().CreateSelect(llvm_fcmp_ogt(s, x, zero_c), one_c, zero_c);
+}
+
+} // namespace
+
+[[nodiscard]] llvm::Value *relup_impl::llvm_eval(llvm_state &s, llvm::Type *fp_t,
+                                                 const std::vector<llvm::Value *> &eval_arr, llvm::Value *par_ptr,
+                                                 llvm::Value *, llvm::Value *stride, std::uint32_t batch_size,
+                                                 bool high_accuracy) const
+{
+    return llvm_eval_helper(
+        [&s](const std::vector<llvm::Value *> &args, bool) {
+            assert(args.size() == 1u);
+            return llvm_relup(s, args[0]);
+        },
+        *this, s, fp_t, eval_arr, par_ptr, stride, batch_size, high_accuracy);
+}
+
+llvm::Function *relup_impl::llvm_c_eval_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t batch_size,
+                                             bool high_accuracy) const
+{
+    return llvm_c_eval_func_helper(
+        "relup",
+        [&s](const std::vector<llvm::Value *> &args, bool) {
+            assert(args.size() == 1u);
+            return llvm_relup(s, args[0]);
+        },
+        *this, s, fp_t, batch_size, high_accuracy);
+}
+
+namespace
+{
+
+// Derivative of relup(number): order 0 is relup of the constant, every higher order is zero.
+template <typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
+llvm::Value *taylor_diff_relup_impl(llvm_state &s, llvm::Type *fp_t, const relup_impl &,
+                                    const std::vector<std::uint32_t> &, const U &num,
+                                    const std::vector<llvm::Value *> &, llvm::Value *par_ptr, std::uint32_t,
+                                    std::uint32_t order, std::uint32_t, std::uint32_t batch_size)
+{
+    if (order == 0u) {
+        return llvm_relup(s, taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size));
+    } else {
+        return vector_splat(s.builder(), llvm_constantfp(s, fp_t, 0.), batch_size);
+    }
+}
+
+// Derivative of relup(variable).
+llvm::Value *taylor_diff_relup_impl(llvm_state &s, llvm::Type *, const relup_impl &, const std::vector<std::uint32_t> &,
+                                    const variable &var, const std::vector<llvm::Value *> &arr, llvm::Value *,
+                                    // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
+                                    std::uint32_t n_uvars, std::uint32_t order, std::uint32_t, std::uint32_t)
+{
+    const auto u_idx = uname_to_index(var.name());
+    auto *u_zero = taylor_fetch_diff(arr, u_idx, 0, n_uvars);
+
+    if (order == 0u) {
+        return llvm_relup(s, u_zero);
+    } else {
+        return llvm_constantfp(s, u_zero->getType(), 0.); // Every order above 0 is zero.
+    }
+}
+
+// LCOV_EXCL_START
+
+// All the other cases.
+template <typename U, std::enable_if_t<!is_num_param_v<U>, int> = 0>
+llvm::Value *taylor_diff_relup_impl(llvm_state &, llvm::Type *, const relup_impl &, const std::vector<std::uint32_t> &,
+                                    const U &, const std::vector<llvm::Value *> &, llvm::Value *, std::uint32_t,
+                                    std::uint32_t, std::uint32_t, std::uint32_t)
+{
+    throw std::invalid_argument(
+        "An invalid argument type was encountered while trying to build the Taylor derivative of a relup");
+}
+
+// LCOV_EXCL_STOP
+
+llvm::Value *taylor_diff_relup(llvm_state &s, llvm::Type *fp_t, const relup_impl &f,
+                               const std::vector<std::uint32_t> &deps, const std::vector<llvm::Value *> &arr,
+                               llvm::Value *par_ptr, std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
+                               std::uint32_t batch_size)
+{
+    assert(f.args().size() == 1u);
+
+    // LCOV_EXCL_START
+    if (!deps.empty()) {
+        throw std::invalid_argument(
+            fmt::format("An empty hidden dependency vector is expected in order to compute the Taylor "
+                        "derivative of the relup, but a vector of size {} was passed instead",
+                        deps.size()));
+    }
+    // LCOV_EXCL_STOP
+
+    return std::visit(
+        [&](const auto &v) {
+            return taylor_diff_relup_impl(s, fp_t, f, deps, v, arr, par_ptr, n_uvars, order, idx, batch_size);
+        },
+        f.args()[0].value());
+}
+
+} // namespace
+
+llvm::Value *relup_impl::taylor_diff(llvm_state &s, llvm::Type *fp_t, const std::vector<std::uint32_t> &deps,
+                                     const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, llvm::Value *,
+                                     std::uint32_t n_uvars, std::uint32_t order, std::uint32_t idx,
+                                     std::uint32_t batch_size, bool) const
+{
+    return taylor_diff_relup(s, fp_t, *this, deps, arr, par_ptr, n_uvars, order, idx, batch_size);
+}
+
+namespace
+{
+
+// Derivative of relup(number).
+template <typename U, std::enable_if_t<is_num_param_v<U>, int> = 0>
+llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &s, llvm::Type *fp_t, const relup_impl &, const U &num,
+                                              std::uint32_t n_uvars, std::uint32_t batch_size)
+{
+    return taylor_c_diff_func_numpar(
+        s, fp_t, n_uvars, batch_size, "relup", 0,
+        [&s](const auto &args) {
+            // LCOV_EXCL_START
+            assert(args.size() == 1u);
+            assert(args[0] != nullptr);
+            // LCOV_EXCL_STOP
+
+            return llvm_relup(s, args[0]);
+        },
+        num);
+}
+
+// Derivative of relup(variable).
+llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &s, llvm::Type *fp_t, const relup_impl &, const variable &var,
+                                              std::uint32_t n_uvars, std::uint32_t batch_size)
+{
+    auto &module = s.module();
+    auto &builder = s.builder();
+    auto &context = s.context();
+
+    // Fetch the vector floating-point type.
+    auto *val_t = make_vector_type(fp_t, batch_size);
+
+    // Fetch the function name and arguments.
+    const auto na_pair = taylor_c_diff_func_name_args(context, fp_t, "relup", n_uvars, batch_size, {var});
+    const auto &fname = na_pair.first;
+    const auto &fargs = na_pair.second;
+
+    // Try to see if we already created the function.
+    auto *f = module.getFunction(fname);
+
+    if (f != nullptr) {
+        // The function was created before, return it.
+        return f;
+    }
+
+    // The function was not created before, do it now.
+
+    // Fetch the current insertion block.
+    auto *orig_bb = builder.GetInsertBlock();
+
+    // The return type is val_t.
+    auto *ft = llvm::FunctionType::get(val_t, fargs, false);
+    // Create the function.
+    f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &module);
+    assert(f != nullptr);
+
+    // Fetch the necessary function arguments.
+    auto *ord = f->args().begin();
+    auto *diff_ptr = f->args().begin() + 2; // NOTE(review): positions 0/2/5 presumably follow the layout built by taylor_c_diff_func_name_args - confirm.
+    auto *var_idx = f->args().begin() + 5;
+
+    // Create a new basic block to start insertion into.
+    builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
+
+    // Create the return value.
+    auto *retval = builder.CreateAlloca(val_t);
+
+    llvm_if_then_else(
+        s, builder.CreateICmpEQ(ord, builder.getInt32(0)),
+        [&]() {
+            // For order 0, invoke the function on the order 0 of var_idx.
+            auto *ret = llvm_relup(s, taylor_c_load_diff(s, val_t, diff_ptr, n_uvars, builder.getInt32(0), var_idx));
+            // NOLINTNEXTLINE(readability-suspicious-call-argument)
+            builder.CreateStore(ret, retval);
+        },
+        [&]() {
+            // For all the other orders, the result is zero.
+            builder.CreateStore(llvm_constantfp(s, val_t, 0.), retval);
+        });
+
+    // Return the result.
+    builder.CreateRet(builder.CreateLoad(val_t, retval));
+
+    // Verify.
+    s.verify_function(f);
+
+    // Restore the original insertion block.
+    builder.SetInsertPoint(orig_bb);
+
+    return f;
+}
+
+// LCOV_EXCL_START
+
+// All the other cases.
+template <typename U, std::enable_if_t<!is_num_param_v<U>, int> = 0>
+llvm::Function *taylor_c_diff_func_relup_impl(llvm_state &, llvm::Type *, const relup_impl &, const U &, std::uint32_t,
+                                              std::uint32_t)
+{
+    throw std::invalid_argument("An invalid argument type was encountered while trying to build the Taylor derivative "
+                                "of a relup in compact mode");
+}
+
+// LCOV_EXCL_STOP
+
+llvm::Function *taylor_c_diff_func_relup(llvm_state &s, llvm::Type *fp_t, const relup_impl &fn, std::uint32_t n_uvars,
+                                         std::uint32_t batch_size)
+{
+    assert(fn.args().size() == 1u);
+
+    return std::visit([&](const auto &v) { return taylor_c_diff_func_relup_impl(s, fp_t, fn, v, n_uvars, batch_size); },
+                      fn.args()[0].value());
+}
+
+} // namespace
+
+llvm::Function *relup_impl::taylor_c_diff_func(llvm_state &s, llvm::Type *fp_t, std::uint32_t n_uvars,
+                                               std::uint32_t batch_size, bool) const
+{
+    return taylor_c_diff_func_relup(s, fp_t, *this, n_uvars, batch_size);
+}
+
+} // namespace detail
+
+expression relu(expression x)
+{
+    // Fold relu(number) to its value (NOTE(review): the 0. fallback is a double literal, so the folded constant may not keep the input's floating-point type).
+    if (const auto *num_ptr = std::get_if<number>(&x.value())) {
+        return std::visit([](const auto &v) { return expression{v > 0 ? v : 0.}; }, num_ptr->value());
+    } else {
+        return expression{func{detail::relu_impl{std::move(x)}}};
+    }
+}
+
+expression relup(expression x)
+{
+    // Fold relup(number) to its value (the result is always the double 1. or 0.).
+    if (const auto *num_ptr = std::get_if<number>(&x.value())) {
+        return std::visit([](const auto &v) { return expression{v > 0 ? 1. : 0.}; }, num_ptr->value());
+    } else {
+        return expression{func{detail::relup_impl{std::move(x)}}};
+    }
+}
+
+HEYOKA_END_NAMESPACE
+
+HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::relu_impl)
+
+HEYOKA_S11N_FUNC_EXPORT_IMPLEMENT(heyoka::detail::relup_impl)
diff --git a/src/math/sin.cpp b/src/math/sin.cpp
index 42b54c2cb..9ea6a25fa 100644
--- a/src/math/sin.cpp
+++ b/src/math/sin.cpp
@@ -223,6 +223,8 @@ llvm::Value *taylor_diff_sin_impl(llvm_state &s, llvm::Type *fp_t, const sin_imp
     return llvm_fdiv(s, ret_acc, div);
 }
 
+// LCOV_EXCL_START
+
 // All the other cases.
 template , int> = 0>
 llvm::Value *taylor_diff_sin_impl(llvm_state &, llvm::Type *, const sin_impl &, const std::vector &,
@@ -233,18 +235,22 @@ llvm::Value *taylor_diff_sin_impl(llvm_state &, llvm::Type *, const sin_impl &,
         "An invalid argument type was encountered while trying to build the Taylor derivative of a sine");
 }
 
+// LCOV_EXCL_STOP
+
 llvm::Value *taylor_diff_sin(llvm_state &s, llvm::Type *fp_t, const sin_impl &f, const std::vector &deps,
                              const std::vector &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
                              std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
 {
     assert(f.args().size() == 1u);
 
+    // LCOV_EXCL_START
     if (deps.size() != 1u) {
         throw std::invalid_argument(
             fmt::format("A hidden dependency vector of size 1 is expected in order to compute the Taylor "
                         "derivative of the sine, but a vector of size {} was passed instead",
                         deps.size()));
     }
+    // LCOV_EXCL_STOP
 
     return std::visit(
         [&](const auto &v) {
@@ -370,6 +376,8 @@ llvm::Function *taylor_c_diff_func_sin_impl(llvm_state &s, llvm::Type *fp_t, con
     return f;
 }
 
+// LCOV_EXCL_START
+
 // All the other cases.
template , int> = 0>
 llvm::Function *taylor_c_diff_func_sin_impl(llvm_state &, llvm::Type *, const sin_impl &, const U &, std::uint32_t,
@@ -379,6 +387,8 @@ llvm::Function *taylor_c_diff_func_sin_impl(llvm_state &, llvm::Type *, const si
                                 "of a sine in compact mode");
 }
 
+// LCOV_EXCL_STOP
+
 llvm::Function *taylor_c_diff_func_sin(llvm_state &s, llvm::Type *fp_t, const sin_impl &fn, std::uint32_t n_uvars,
                                        std::uint32_t batch_size)
 {
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 7463f3ab1..893e8df41 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -93,6 +93,7 @@ ADD_HEYOKA_TESTCASE(erf)
 ADD_HEYOKA_TESTCASE(exp)
 ADD_HEYOKA_TESTCASE(log)
 ADD_HEYOKA_TESTCASE(sigmoid)
+ADD_HEYOKA_TESTCASE(relu)
 ADD_HEYOKA_TESTCASE(pow)
 ADD_HEYOKA_TESTCASE(neg)
 ADD_HEYOKA_TESTCASE(cos)
diff --git a/test/relu.cpp b/test/relu.cpp
new file mode 100644
index 000000000..a32a380c3
--- /dev/null
+++ b/test/relu.cpp
@@ -0,0 +1,288 @@
+// Copyright 2020, 2021, 2022, 2023 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
+//
+// This file is part of the heyoka library.
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#include // NOTE(review): the header name of this and of every bare #include below is missing in this copy of the patch - restore from the upstream commit.
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+
+#if defined(HEYOKA_HAVE_REAL128)
+
+#include
+
+#endif
+
+#if defined(HEYOKA_HAVE_REAL)
+
+#include
+
+#endif
+
+#include
+#include
+#include
+#include
+
+#include "catch.hpp"
+#include "test_utils.hpp"
+
+static std::mt19937 rng; // Default-constructed, i.e. seeded with the engine's default seed on every run.
+
+using namespace heyoka;
+using namespace heyoka_test;
+
+const auto fp_types = std::tuple{}; // NOTE(review): the tuple's type arguments (the floating-point types under test) are missing in this copy of the patch.
+
+constexpr bool skip_batch_ld =
+#if LLVM_VERSION_MAJOR >= 13 && LLVM_VERSION_MAJOR <= 17
+    std::numeric_limits::digits == 64 // NOTE(review): type argument missing here; presumably long double, given the variable name - confirm.
+#else
+    false
+#endif
+    ;
+
+template // NOTE(review): the template parameter list (declaring T) is missing in this copy of the patch.
+T cpp_relu(T x)
+{
+    return x > 0 ? x : T(0); // Reference implementation the JIT-compiled relu is compared against.
+}
+
+template // NOTE(review): the template parameter list (declaring T) is missing in this copy of the patch.
+T cpp_relup(T x)
+{
+    return x > 0 ? T(1) : T(0); // Reference implementation the JIT-compiled relup is compared against.
+}
+
+TEST_CASE("def ctor")
+{
+    {
+        detail::relu_impl k;
+
+        REQUIRE(k.args().size() == 1u);
+        REQUIRE(k.args()[0] == 0_dbl);
+    }
+
+    {
+        detail::relup_impl k;
+
+        REQUIRE(k.args().size() == 1u);
+        REQUIRE(k.args()[0] == 0_dbl);
+    }
+}
+
+TEST_CASE("diff")
+{
+    auto [x, y] = make_vars("x", "y");
+
+    REQUIRE(diff(relu(x), x) == relup(x));
+    REQUIRE(diff(relup(x), x) == 0_dbl);
+
+    REQUIRE(diff(relu(x * y), x) == y * relup(x * y));
+    REQUIRE(diff(relu(x * y), par[0]) == 0_dbl);
+    REQUIRE(diff(relu(x * par[0]), par[0]) == x * relup(x * par[0]));
+
+    REQUIRE(diff(relup(x * y), x) == 0_dbl);
+    REQUIRE(diff(relup(x * y), par[0]) == 0_dbl);
+    REQUIRE(diff(relup(x * par[0]), par[0]) == 0_dbl);
+}
+
+TEST_CASE("constant fold")
+{
+    REQUIRE(relu(1.1_dbl) == 1.1_dbl);
+    REQUIRE(relu(-1.1_dbl) == 0_dbl);
+
+    REQUIRE(relup(1.1_dbl) == 1_dbl);
+    REQUIRE(relup(-1.1_dbl) == 0_dbl);
+}
+
+TEST_CASE("s11n")
+{
+    {
+        std::stringstream ss;
+
+        auto [x, y] = make_vars("x", "y");
+
+        auto ex = relu(x + y);
+
+        {
+            boost::archive::binary_oarchive oa(ss);
+
+            oa << ex;
+        }
+
+        ex = 0_dbl; // Overwrite before loading so the check below cannot pass on the stale value.
+
+        {
+            boost::archive::binary_iarchive ia(ss);
+
+            ia >> ex;
+        }
+
+        REQUIRE(ex == relu(x + y));
+    }
+
+    {
+        std::stringstream ss;
+
+        auto [x, y] = make_vars("x", "y");
+
+        auto ex = relup(x + y);
+
+        {
+            boost::archive::binary_oarchive oa(ss);
+
+            oa << ex;
+        }
+
+        ex = 0_dbl; // Overwrite before loading so the check below cannot pass on the stale value.
+
+        {
+            boost::archive::binary_iarchive ia(ss);
+
+            ia >> ex;
+        }
+
+        REQUIRE(ex == relup(x + y));
+    }
+}
+
+TEST_CASE("cfunc")
+{
+    auto tester = [](auto fp_x, unsigned opt_level, bool high_accuracy, bool compact_mode) {
+        using fp_t = decltype(fp_x);
+
+        auto [x] = make_vars("x");
+
+        std::uniform_real_distribution x_dist(-10, 10); // NOTE(review): the distribution's type argument is missing in this copy of the patch.
+
+        std::vector outs, ins, pars; // NOTE(review): element type missing in this copy; presumably fp_t - confirm.
+
+        for (auto batch_size : {1u, 2u, 4u, 5u}) {
+            if (batch_size != 1u && std::is_same_v && skip_batch_ld) { // NOTE(review): is_same_v arguments missing in this copy of the patch.
+                continue;
+            }
+
+            outs.resize(batch_size * 4u); // Four outputs per batch lane.
+            ins.resize(batch_size); // One input per batch lane.
+            pars.resize(batch_size * 2u); // Two parameters per batch lane.
+
+            llvm_state s{kw::opt_level = opt_level};
+
+            add_cfunc(s, "cfunc", {relu(x), relu(par[0]), relup(x), relup(par[1])}, kw::batch_size = batch_size,
+                      kw::high_accuracy = high_accuracy, kw::compact_mode = compact_mode);
+
+            if (opt_level == 0u && compact_mode) {
+                REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.relu."));
+                REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.relup."));
+            }
+
+            s.compile();
+
+            auto *cf_ptr
+                = reinterpret_cast(s.jit_lookup("cfunc")); // NOTE(review): the cast's target function-pointer type is missing in this copy of the patch.
+
+            for (auto niter = 0; niter < 100; ++niter) {
+                for (auto i = 0u; i < batch_size; ++i) {
+                    // Generate the xs.
+                    ins[i] = x_dist(rng);
+
+                    // Generate the pars.
+                    pars[i] = x_dist(rng);
+                    pars[i + batch_size] = x_dist(rng);
+                }
+
+                cf_ptr(outs.data(), ins.data(), pars.data(), nullptr);
+
+                for (auto i = 0u; i < batch_size; ++i) {
+                    REQUIRE(outs[i] == cpp_relu(ins[i]));
+                    REQUIRE(outs[i + batch_size] == cpp_relu(pars[i]));
+                    REQUIRE(outs[i + 2u * batch_size] == cpp_relup(ins[i]));
+                    REQUIRE(outs[i + 3u * batch_size] == cpp_relup(pars[i + batch_size]));
+                }
+            }
+        }
+    };
+
+    for (auto cm : {false, true}) {
+        for (auto f : {false, true}) {
+            tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 0, f, cm); });
+            tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 1, f, cm); });
+            tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 2, f, cm); });
+            tuple_for_each(fp_types, [&tester, f, cm](auto x) { tester(x, 3, f, cm); });
+        }
+    }
+}
+
+#if defined(HEYOKA_HAVE_REAL)
+
+TEST_CASE("cfunc mp")
+{
+    using fp_t = mppp::real;
+
+    const auto prec = 237;
+
+    auto [x] = make_vars("x");
+
+    std::uniform_real_distribution x_dist(-10, 10); // NOTE(review): the distribution's type argument is missing in this copy of the patch.
+
+    std::vector outs, ins, pars; // NOTE(review): element type missing in this copy; presumably fp_t - confirm.
+
+    outs.resize(4u, mppp::real{0, prec});
+    ins.resize(1u);
+    pars.resize(2u);
+
+    for (auto compact_mode : {false, true}) {
+        for (auto opt_level : {0u, 1u, 2u, 3u}) {
+            llvm_state s{kw::opt_level = opt_level};
+
+            add_cfunc(s, "cfunc", {relu(x), relu(par[0]), relup(x), relup(par[1])},
+                      kw::compact_mode = compact_mode, kw::prec = prec);
+
+            if (opt_level == 0u && compact_mode) {
+                REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.relu."));
+                REQUIRE(boost::contains(s.get_ir(), "heyoka.llvm_c_eval.relup."));
+            }
+
+            s.compile();
+
+            auto *cf_ptr
+                = reinterpret_cast(s.jit_lookup("cfunc")); // NOTE(review): the cast's target function-pointer type is missing in this copy of the patch.
+
+            // Generate the x and pars.
+            ins[0] = mppp::real{x_dist(rng), prec};
+            pars[0] = mppp::real{x_dist(rng), prec};
+            pars[1] = mppp::real{x_dist(rng), prec};
+
+            cf_ptr(outs.data(), ins.data(), pars.data(), nullptr);
+
+            REQUIRE(outs[0] == cpp_relu(ins[0]));
+            REQUIRE(outs[1] == cpp_relu(pars[0]));
+            REQUIRE(outs[2] == cpp_relup(ins[0]));
+            REQUIRE(outs[3] == cpp_relup(pars[1]));
+        }
+    }
+}
+
+#endif