From abd3c0c85a0ac49c6ccdd1d6ea684a139a4c264a Mon Sep 17 00:00:00 2001 From: Francesco Biscani Date: Sun, 5 Nov 2023 12:08:12 +0100 Subject: [PATCH] Additional testing for relu/relup. --- test/relu.cpp | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/test/relu.cpp b/test/relu.cpp index 5e57682ec..51e72c134 100644 --- a/test/relu.cpp +++ b/test/relu.cpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -167,6 +168,32 @@ TEST_CASE("names") } } +// Test to check that equality, hashing and less-than, which take into account +// the function name, behave correctly when changing slope. +TEST_CASE("hash eq lt") +{ + auto [x, y] = make_vars("x", "y"); + + REQUIRE(relu(x + y) != relu(x + y, 0.01)); + REQUIRE(relup(x + y) != relup(x + y, 0.01)); + REQUIRE(relu(x + y, 0.02) != relu(x + y, 0.01)); + REQUIRE(relup(x + y, 0.02) != relup(x + y, 0.01)); + REQUIRE((std::get(relu(x + y).value()) < std::get(relu(x + y, 0.01).value()) + || std::get(relu(x + y, 0.01).value()) < std::get(relu(x + y).value()))); + REQUIRE((std::get(relup(x + y).value()) < std::get(relup(x + y, 0.01).value()) + || std::get(relup(x + y, 0.01).value()) < std::get(relup(x + y).value()))); + REQUIRE((std::get(relu(x + y, 0.02).value()) < std::get(relu(x + y, 0.01).value()) + || std::get(relu(x + y, 0.01).value()) < std::get(relu(x + y, 0.02).value()))); + REQUIRE((std::get(relup(x + y, 0.02).value()) < std::get(relup(x + y, 0.01).value()) + || std::get(relup(x + y, 0.01).value()) < std::get(relup(x + y, 0.02).value()))); + + // Of course, not 100% guaranteed but hopefully very likely. + REQUIRE(std::hash{}(relu(x + y)) != std::hash{}(relu(x + y, 0.01))); + REQUIRE(std::hash{}(relup(x + y)) != std::hash{}(relup(x + y, 0.01))); + REQUIRE(std::hash{}(relu(x + y, 0.02)) != std::hash{}(relu(x + y, 0.01))); + REQUIRE(std::hash{}(relup(x + y, 0.02)) != std::hash{}(relup(x + y, 0.01))); +} + TEST_CASE("invalid slopes") { using Catch::Matchers::Message;