diff --git a/benchmark/n_body_creation.cpp b/benchmark/n_body_creation.cpp index 1c2de46ae..b2d55192d 100644 --- a/benchmark/n_body_creation.cpp +++ b/benchmark/n_body_creation.cpp @@ -66,7 +66,7 @@ int main(int argc, char *argv[]) auto counter = 0u; for (const auto &ex : ta.get_decomposition()) { - std::cout << "u_" << counter++ << " = " << ex << '\n'; + std::cout << "u_" << counter++ << " = " << ex.first << '\n'; } std::cout << "Construction time: " << elapsed << "ms\n"; diff --git a/src/expression.cpp b/src/expression.cpp index 3d769ad36..deb30acce 100644 --- a/src/expression.cpp +++ b/src/expression.cpp @@ -224,6 +224,10 @@ expression operator+(expression e1, expression e2) // e1 + 0 = e1. return expression{std::forward(v1)}; } + if (std::visit([](const auto &x) { return x < 0; }, v2.value())) { + // e1 + -x = e1 - x. + return expression{std::forward(v1)} - expression{-std::forward(v2)}; + } // NOTE: fall through the standard case if e2 is not zero. } @@ -257,6 +261,10 @@ expression operator-(expression e1, expression e2) // e1 - 0 = e1. return expression{std::forward(v1)}; } + if (std::visit([](const auto &x) { return x < 0; }, v2.value())) { + // e1 - -x = e1 + x. + return expression{std::forward(v1)} + expression{-std::forward(v2)}; + } // NOTE: fall through the standard case if e2 is not zero. } diff --git a/test/expression.cpp b/test/expression.cpp index 1b8afc8b2..639bdc51c 100644 --- a/test/expression.cpp +++ b/test/expression.cpp @@ -396,3 +396,11 @@ TEST_CASE("get_param_size") REQUIRE(get_param_size(par[122] + par[123]) == 124u); REQUIRE(get_param_size(par[500] - sin(cos(par[1] + "y"_var) + par[4])) == 501u); } + +TEST_CASE("binary simpls") +{ + auto [x, y] = make_vars("x", "y"); + + REQUIRE(x + -1. == x - 1.); + REQUIRE(y - -1. == y + 1.); +}