diff --git a/src/expression_diff.cpp b/src/expression_diff.cpp index f5f5f18e5..e96f519b6 100644 --- a/src/expression_diff.cpp +++ b/src/expression_diff.cpp @@ -1468,14 +1468,23 @@ std::vector dtens::get_jacobian() const std::uint32_t dtens::get_nvars() const { - if (p_impl->m_map.empty()) { - return 0; - } - // NOTE: we ensure in the diff_tensors() implementation // that the number of diff variables is representable // by std::uint32_t. - return static_cast(begin()->first.size() - 1u); + auto ret = static_cast(get_args().size()); + +#if !defined(NDEBUG) + + if (p_impl->m_map.empty()) { + assert(ret == 0u); + } else { + assert(!begin()->first.empty()); + assert(ret == begin()->first.size() - 1u); + } + +#endif + + return ret; } std::uint32_t dtens::get_nouts() const diff --git a/test/expression_diff_tensors.cpp b/test/expression_diff_tensors.cpp index b50592fa2..aa04923c4 100644 --- a/test/expression_diff_tensors.cpp +++ b/test/expression_diff_tensors.cpp @@ -109,6 +109,7 @@ TEST_CASE("diff_tensors basic") }; REQUIRE(dt.size() == 2u); + REQUIRE(dt.get_nvars() == 1u); assign_sr(dt.get_derivatives(0, 0)); REQUIRE(diff_vec.size() == 1u); @@ -127,6 +128,7 @@ TEST_CASE("diff_tensors basic") Message(fmt::format("Cannot locate the derivative corresponding to the indices vector {}", std::vector{0, 2}))); dt = diff_tensors({1_dbl}, kw::diff_order = 2, kw::diff_args = {par[0]}); + REQUIRE(dt.get_nvars() == 1u); REQUIRE(dt.size() == 3u); assign_sr(dt.get_derivatives(0, 0)); REQUIRE(diff_vec.size() == 1u); @@ -149,6 +151,7 @@ TEST_CASE("diff_tensors basic") Message(fmt::format("Cannot locate the derivative corresponding to the indices vector {}", std::vector{0, 3}))); dt = diff_tensors({1_dbl}, kw::diff_order = 3, kw::diff_args = {par[0]}); + REQUIRE(dt.get_nvars() == 1u); REQUIRE(dt.size() == 4u); assign_sr(dt.get_derivatives(0, 0)); REQUIRE(diff_vec.size() == 1u); @@ -176,6 +179,7 @@ TEST_CASE("diff_tensors basic") // Automatically deduced diff variables. dt = diff_tensors({x + y, x * y * y}, kw::diff_order = 2); + REQUIRE(dt.get_nvars() == 2u); REQUIRE(dt.size() == 12u); assign_sr(dt.get_derivatives(0)); REQUIRE(diff_vec == std::vector{x + y, x * y * y}); @@ -186,6 +190,7 @@ TEST_CASE("diff_tensors basic") // Diff wrt all variables. dt = diff_tensors({x + y, x * y * y}, kw::diff_order = 2, kw::diff_args = diff_args::vars); + REQUIRE(dt.get_nvars() == 2u); REQUIRE(dt.size() == 12u); assign_sr(dt.get_derivatives(0)); REQUIRE(diff_vec == std::vector{x + y, x * y * y}); @@ -196,6 +201,7 @@ TEST_CASE("diff_tensors basic") // Diff wrt some variables. dt = diff_tensors({x + y, x * y * y}, kw::diff_order = 2, kw::diff_args = {x}); + REQUIRE(dt.get_nvars() == 1u); REQUIRE(dt.size() == 6u); assign_sr(dt.get_derivatives(0)); REQUIRE(diff_vec == std::vector{x + y, x * y * y}); @@ -206,6 +212,7 @@ TEST_CASE("diff_tensors basic") // Diff wrt all params. dt = diff_tensors({par[0] + y, x * y * par[1]}, kw::diff_order = 2, kw::diff_args = diff_args::params); + REQUIRE(dt.get_nvars() == 2u); REQUIRE(dt.size() == 12u); assign_sr(dt.get_derivatives(0)); REQUIRE(diff_vec == std::vector{par[0] + y, x * y * par[1]}); @@ -216,6 +223,7 @@ TEST_CASE("diff_tensors basic") // Diff wrt some param. dt = diff_tensors({par[0] + y, x * y * par[1]}, kw::diff_order = 2, kw::diff_args = {par[1]}); + REQUIRE(dt.get_nvars() == 1u); REQUIRE(dt.size() == 6u); assign_sr(dt.get_derivatives(0)); REQUIRE(diff_vec == std::vector{par[0] + y, x * y * par[1]});