Skip to content

Commit

Permalink
Merge pull request #72 from bluescarni/pr/new_prime
Browse files Browse the repository at this point in the history
Test fix, simplification
  • Loading branch information
bluescarni authored Jan 6, 2021
2 parents 62d956c + f0e9d98 commit 92c3c2c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion benchmark/n_body_creation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ int main(int argc, char *argv[])

auto counter = 0u;
for (const auto &ex : ta.get_decomposition()) {
std::cout << "u_" << counter++ << " = " << ex << '\n';
std::cout << "u_" << counter++ << " = " << ex.first << '\n';
}

std::cout << "Construction time: " << elapsed << "ms\n";
Expand Down
8 changes: 8 additions & 0 deletions src/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ expression operator+(expression e1, expression e2)
// e1 + 0 = e1.
return expression{std::forward<decltype(v1)>(v1)};
}
if (std::visit([](const auto &x) { return x < 0; }, v2.value())) {
// e1 + -x = e1 - x.
return expression{std::forward<decltype(v1)>(v1)} - expression{-std::forward<decltype(v2)>(v2)};
}
// NOTE: fall through the standard case if e2 is not zero.
}

Expand Down Expand Up @@ -257,6 +261,10 @@ expression operator-(expression e1, expression e2)
// e1 - 0 = e1.
return expression{std::forward<decltype(v1)>(v1)};
}
if (std::visit([](const auto &x) { return x < 0; }, v2.value())) {
// e1 - -x = e1 + x.
return expression{std::forward<decltype(v1)>(v1)} + expression{-std::forward<decltype(v2)>(v2)};
}
// NOTE: fall through the standard case if e2 is not zero.
}

Expand Down
8 changes: 8 additions & 0 deletions test/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,11 @@ TEST_CASE("get_param_size")
REQUIRE(get_param_size(par[122] + par[123]) == 124u);
REQUIRE(get_param_size(par[500] - sin(cos(par[1] + "y"_var) + par[4])) == 501u);
}

TEST_CASE("binary simpls")
{
auto [x, y] = make_vars("x", "y");

REQUIRE(x + -1. == x - 1.);
REQUIRE(y - -1. == y + 1.);
}

0 comments on commit 92c3c2c

Please sign in to comment.