From 693ec36e1d483156bfcffdc10295693a8466698c Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Sat, 4 Nov 2023 13:52:30 +0100 Subject: [PATCH] Add slope getters. --- include/heyoka/math/relu.hpp | 4 ++++ src/math/relu.cpp | 10 ++++++++++ test/relu.cpp | 4 ++++ 3 files changed, 18 insertions(+) diff --git a/include/heyoka/math/relu.hpp b/include/heyoka/math/relu.hpp index da2541b17..9512b46e8 100644 --- a/include/heyoka/math/relu.hpp +++ b/include/heyoka/math/relu.hpp @@ -41,6 +41,8 @@ class HEYOKA_DLL_PUBLIC relu_impl : public func_base relu_impl(); explicit relu_impl(expression, double); + [[nodiscard]] double get_slope() const noexcept; + void to_stream(std::ostringstream &) const; [[nodiscard]] expression normalise() const; @@ -75,6 +77,8 @@ class HEYOKA_DLL_PUBLIC relup_impl : public func_base relup_impl(); explicit relup_impl(expression, double); + [[nodiscard]] double get_slope() const noexcept; + void to_stream(std::ostringstream &) const; [[nodiscard]] expression normalise() const; diff --git a/src/math/relu.cpp b/src/math/relu.cpp index 9c57a84ef..9fcba480b 100644 --- a/src/math/relu.cpp +++ b/src/math/relu.cpp @@ -85,6 +85,11 @@ relu_impl::relu_impl(expression ex, double slope) relu_slope_check(slope); } +double relu_impl::get_slope() const noexcept +{ + return m_slope; +} + void relu_impl::to_stream(std::ostringstream &oss) const { assert(args().size() == 1u); @@ -354,6 +359,11 @@ relup_impl::relup_impl(expression ex, double slope) relu_slope_check(slope); } +double relup_impl::get_slope() const noexcept +{ + return m_slope; +} + void relup_impl::to_stream(std::ostringstream &oss) const { assert(args().size() == 1u); diff --git a/test/relu.cpp b/test/relu.cpp index 8c1c74073..5e57682ec 100644 --- a/test/relu.cpp +++ b/test/relu.cpp @@ -88,6 +88,7 @@ TEST_CASE("def ctor") REQUIRE(k.args().size() == 1u); REQUIRE(k.args()[0] == 0_dbl); REQUIRE(k.get_name() == "relu"); + REQUIRE(k.get_slope() == 0.); } { @@ -96,6 +97,7 @@ TEST_CASE("def ctor") REQUIRE(k.args().size() == 1u); REQUIRE(k.args()[0] == 0_dbl); REQUIRE(k.get_name() == "relup"); + REQUIRE(k.get_slope() == 0.); } } @@ -154,12 +156,14 @@ TEST_CASE("names") auto ex = relu("x"_var, 1.); REQUIRE(std::get(ex.value()).get_name() != "relu"); REQUIRE(boost::starts_with(std::get(ex.value()).get_name(), "relu_0x")); + REQUIRE(std::get(ex.value()).extract()->get_slope() == 1); } { auto ex = relup("x"_var, 1.); REQUIRE(std::get(ex.value()).get_name() != "relup"); REQUIRE(boost::starts_with(std::get(ex.value()).get_name(), "relup_0x")); + REQUIRE(std::get(ex.value()).extract()->get_slope() == 1); } }