Skip to content

Commit

Permalink
Simplification, small test additions.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Oct 30, 2023
1 parent 1d63ee4 commit 442ac24
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/expression_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1468,14 +1468,23 @@ std::vector<expression> dtens::get_jacobian() const

std::uint32_t dtens::get_nvars() const
{
if (p_impl->m_map.empty()) {
return 0;
}

// NOTE: we ensure in the diff_tensors() implementation
// that the number of diff variables is representable
// by std::uint32_t.
return static_cast<std::uint32_t>(begin()->first.size() - 1u);
auto ret = static_cast<std::uint32_t>(get_args().size());

#if !defined(NDEBUG)

if (p_impl->m_map.empty()) {
assert(ret == 0u);
} else {
assert(!begin()->first.empty());
assert(ret == begin()->first.size() - 1u);
}

#endif

return ret;
}

std::uint32_t dtens::get_nouts() const
Expand Down
8 changes: 8 additions & 0 deletions test/expression_diff_tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ TEST_CASE("diff_tensors basic")
};

REQUIRE(dt.size() == 2u);
REQUIRE(dt.get_nvars() == 1u);

assign_sr(dt.get_derivatives(0, 0));
REQUIRE(diff_vec.size() == 1u);
Expand All @@ -127,6 +128,7 @@ TEST_CASE("diff_tensors basic")
Message(fmt::format("Cannot locate the derivative corresponding to the indices vector {}", std::vector{0, 2})));

dt = diff_tensors({1_dbl}, kw::diff_order = 2, kw::diff_args = {par[0]});
REQUIRE(dt.get_nvars() == 1u);
REQUIRE(dt.size() == 3u);
assign_sr(dt.get_derivatives(0, 0));
REQUIRE(diff_vec.size() == 1u);
Expand All @@ -149,6 +151,7 @@ TEST_CASE("diff_tensors basic")
Message(fmt::format("Cannot locate the derivative corresponding to the indices vector {}", std::vector{0, 3})));

dt = diff_tensors({1_dbl}, kw::diff_order = 3, kw::diff_args = {par[0]});
REQUIRE(dt.get_nvars() == 1u);
REQUIRE(dt.size() == 4u);
assign_sr(dt.get_derivatives(0, 0));
REQUIRE(diff_vec.size() == 1u);
Expand Down Expand Up @@ -176,6 +179,7 @@ TEST_CASE("diff_tensors basic")

// Automatically deduced diff variables.
dt = diff_tensors({x + y, x * y * y}, kw::diff_order = 2);
REQUIRE(dt.get_nvars() == 2u);
REQUIRE(dt.size() == 12u);
assign_sr(dt.get_derivatives(0));
REQUIRE(diff_vec == std::vector{x + y, x * y * y});
Expand All @@ -186,6 +190,7 @@ TEST_CASE("diff_tensors basic")

// Diff wrt all variables.
dt = diff_tensors({x + y, x * y * y}, kw::diff_order = 2, kw::diff_args = diff_args::vars);
REQUIRE(dt.get_nvars() == 2u);
REQUIRE(dt.size() == 12u);
assign_sr(dt.get_derivatives(0));
REQUIRE(diff_vec == std::vector{x + y, x * y * y});
Expand All @@ -196,6 +201,7 @@ TEST_CASE("diff_tensors basic")

// Diff wrt some variables.
dt = diff_tensors({x + y, x * y * y}, kw::diff_order = 2, kw::diff_args = {x});
REQUIRE(dt.get_nvars() == 1u);
REQUIRE(dt.size() == 6u);
assign_sr(dt.get_derivatives(0));
REQUIRE(diff_vec == std::vector{x + y, x * y * y});
Expand All @@ -206,6 +212,7 @@ TEST_CASE("diff_tensors basic")

// Diff wrt all params.
dt = diff_tensors({par[0] + y, x * y * par[1]}, kw::diff_order = 2, kw::diff_args = diff_args::params);
REQUIRE(dt.get_nvars() == 2u);
REQUIRE(dt.size() == 12u);
assign_sr(dt.get_derivatives(0));
REQUIRE(diff_vec == std::vector{par[0] + y, x * y * par[1]});
Expand All @@ -216,6 +223,7 @@ TEST_CASE("diff_tensors basic")

// Diff wrt some param.
dt = diff_tensors({par[0] + y, x * y * par[1]}, kw::diff_order = 2, kw::diff_args = {par[1]});
REQUIRE(dt.get_nvars() == 1u);
REQUIRE(dt.size() == 6u);
assign_sr(dt.get_derivatives(0));
REQUIRE(diff_vec == std::vector{par[0] + y, x * y * par[1]});
Expand Down

0 comments on commit 442ac24

Please sign in to comment.