Skip to content

Commit

Permalink
Add slope getters.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Nov 4, 2023
1 parent b9f1038 commit 693ec36
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
4 changes: 4 additions & 0 deletions include/heyoka/math/relu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class HEYOKA_DLL_PUBLIC relu_impl : public func_base
relu_impl();
explicit relu_impl(expression, double);

[[nodiscard]] double get_slope() const noexcept;

void to_stream(std::ostringstream &) const;

[[nodiscard]] expression normalise() const;
Expand Down Expand Up @@ -75,6 +77,8 @@ class HEYOKA_DLL_PUBLIC relup_impl : public func_base
relup_impl();
explicit relup_impl(expression, double);

[[nodiscard]] double get_slope() const noexcept;

void to_stream(std::ostringstream &) const;

[[nodiscard]] expression normalise() const;
Expand Down
10 changes: 10 additions & 0 deletions src/math/relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ relu_impl::relu_impl(expression ex, double slope)
relu_slope_check(slope);
}

double relu_impl::get_slope() const noexcept
{
return m_slope;
}

void relu_impl::to_stream(std::ostringstream &oss) const
{
assert(args().size() == 1u);
Expand Down Expand Up @@ -354,6 +359,11 @@ relup_impl::relup_impl(expression ex, double slope)
relu_slope_check(slope);
}

double relup_impl::get_slope() const noexcept
{
return m_slope;
}

void relup_impl::to_stream(std::ostringstream &oss) const
{
assert(args().size() == 1u);
Expand Down
4 changes: 4 additions & 0 deletions test/relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ TEST_CASE("def ctor")
REQUIRE(k.args().size() == 1u);
REQUIRE(k.args()[0] == 0_dbl);
REQUIRE(k.get_name() == "relu");
REQUIRE(k.get_slope() == 0.);
}

{
Expand All @@ -96,6 +97,7 @@ TEST_CASE("def ctor")
REQUIRE(k.args().size() == 1u);
REQUIRE(k.args()[0] == 0_dbl);
REQUIRE(k.get_name() == "relup");
REQUIRE(k.get_slope() == 0.);
}
}

Expand Down Expand Up @@ -154,12 +156,14 @@ TEST_CASE("names")
auto ex = relu("x"_var, 1.);
REQUIRE(std::get<func>(ex.value()).get_name() != "relu");
REQUIRE(boost::starts_with(std::get<func>(ex.value()).get_name(), "relu_0x"));
REQUIRE(std::get<func>(ex.value()).extract<detail::relu_impl>()->get_slope() == 1);
}

{
auto ex = relup("x"_var, 1.);
REQUIRE(std::get<func>(ex.value()).get_name() != "relup");
REQUIRE(boost::starts_with(std::get<func>(ex.value()).get_name(), "relup_0x"));
REQUIRE(std::get<func>(ex.value()).extract<detail::relup_impl>()->get_slope() == 1);
}
}

Expand Down

0 comments on commit 693ec36

Please sign in to comment.