Skip to content

Commit

Permalink
Simplify several checks around mdspans.
Browse files Browse the repository at this point in the history
It is really not up to us to check for nonempty nullptr mdspans, as the creation of a nonempty
nullptr mdspan is itself UB. Thus, the only checks we need are on the shape of the mdspan.
  • Loading branch information
bluescarni committed Jul 18, 2024
1 parent dae7cd0 commit 86c25bb
Show file tree
Hide file tree
Showing 6 changed files with 1 addition and 130 deletions.
3 changes: 0 additions & 3 deletions include/heyoka/model/sgp4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ class HEYOKA_DLL_PUBLIC_INLINE_CLASS sgp4_propagator
template <typename Input, typename... KwArgs>
static auto parse_ctor_args(const Input &in, const KwArgs &...kw_args)
{
if (in.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument("Cannot initialise an sgp4_propagator with a null list of satellites");
}
if (in.extent(1) == 0u) [[unlikely]] {
throw std::invalid_argument("Cannot initialise an sgp4_propagator with an empty list of satellites");
}
Expand Down
34 changes: 0 additions & 34 deletions src/cfunc_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,21 +452,12 @@ void cfunc<T>::single_eval(out_1d outputs, in_1d inputs, std::optional<in_1d> pa
m_impl->m_nouts, outputs.size()));
}

if (outputs.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument("The outputs array passed to a cfunc cannot be null");
}

if (inputs.size() != m_impl->m_nvars) [[unlikely]] {
throw std::invalid_argument(fmt::format("Invalid inputs array passed to a cfunc: the number of function "
"inputs is {}, but the inputs array has a size of {}",
m_impl->m_nvars, inputs.size()));
}

if (inputs.data_handle() == nullptr && !inputs.empty()) [[unlikely]] {
throw std::invalid_argument(
"The inputs array passed to a cfunc can be null only if the number of input arguments is zero");
}

if (m_impl->m_nparams != 0u && !pars) [[unlikely]] {
throw std::invalid_argument(
"An array of parameter values must be passed in order to evaluate a function with parameters");
Expand All @@ -479,11 +470,6 @@ void cfunc<T>::single_eval(out_1d outputs, in_1d inputs, std::optional<in_1d> pa
"but the number of parameters in the function is {}",
pars->size(), m_impl->m_nparams));
}

if (pars->data_handle() == nullptr && !pars->empty()) [[unlikely]] {
throw std::invalid_argument(
"The array of parameter values passed to a cfunc can be null only if the number of parameters is zero");
}
}

if (m_impl->m_is_time_dependent && !time) [[unlikely]] {
Expand Down Expand Up @@ -719,11 +705,6 @@ void cfunc<T>::multi_eval(out_2d outputs, in_2d inputs, std::optional<in_2d> par
m_impl->m_nouts, outputs.extent(0)));
}

if (outputs.data_handle() == nullptr && !outputs.empty()) [[unlikely]] {
throw std::invalid_argument(
"The outputs array passed to a cfunc can be null only if the number of evaluations is zero");
}

// Fetch the number of columns from outputs.
const auto ncols = outputs.extent(1);

Expand All @@ -740,11 +721,6 @@ void cfunc<T>::multi_eval(out_2d outputs, in_2d inputs, std::optional<in_2d> par
ncols, inputs.extent(1)));
}

if (inputs.data_handle() == nullptr && !inputs.empty()) [[unlikely]] {
throw std::invalid_argument("The inputs array passed to a cfunc can be null only if the number of input "
"arguments or the number of evaluations is zero");
}

if (m_impl->m_nparams != 0u && !pars) [[unlikely]] {
throw std::invalid_argument(
"An array of parameter values must be passed in order to evaluate a function with parameters");
Expand All @@ -765,11 +741,6 @@ void cfunc<T>::multi_eval(out_2d outputs, in_2d inputs, std::optional<in_2d> par
"outputs array is {}",
pars->extent(1), ncols));
}

if (pars->data_handle() == nullptr && !pars->empty()) [[unlikely]] {
throw std::invalid_argument("The array of parameter values passed to a cfunc can be null only if the "
"number of parameters or the number of evaluations is zero");
}
}

if (m_impl->m_is_time_dependent && !times) [[unlikely]] {
Expand All @@ -785,11 +756,6 @@ void cfunc<T>::multi_eval(out_2d outputs, in_2d inputs, std::optional<in_2d> par
"outputs array is {}",
times->size(), ncols));
}

if (times->data_handle() == nullptr && !times->empty()) [[unlikely]] {
throw std::invalid_argument("The array of time values passed to a cfunc can be null only if the "
"number of evaluations is zero");
}
}

#if defined(HEYOKA_HAVE_REAL)
Expand Down
28 changes: 0 additions & 28 deletions src/model/sgp4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,10 +701,6 @@ void sgp4_propagator<T>::replace_sat_data(mdspan<const T, extents<std::size_t, 9
// Cache nsats.
const auto nsats = get_nsats();

if (new_data.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument("Cannot replace the satellite data with a null array");
}

if (new_data.extent(1) != nsats) [[unlikely]] {
throw std::invalid_argument(fmt::format("Invalid array provided to replace_sat_data(): the number of "
"columns ({}) does not match the number of satellites ({})",
Expand Down Expand Up @@ -900,9 +896,6 @@ template <typename T>
void sgp4_propagator<T>::operator()(out_2d out, in_1d<date> dates)
{
// Check the dates array.
if (dates.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument("A null array of dates was passed to the call operator of an sgp4_propagator");
}
const auto n_sats = get_nsats();
if (dates.extent(0) != n_sats) [[unlikely]] {
throw std::invalid_argument(
Expand Down Expand Up @@ -948,23 +941,6 @@ template <typename T>
requires std::same_as<T, double> || std::same_as<T, float>
void sgp4_propagator<T>::operator()(out_3d out, in_2d<T> tms)
{
// NOTE: need to check for nullptr input spans. In the non-batch overload
// we do not need the explicit check because we don't do anything with 'out'
// and 'tms' apart from forwarding them to the cfunc, which does the nullptr check.
// Here however we need to take subspans of 'out' and 'tms' and thus we need to
// pre-check for nullptr in order to avoid undefined behaviour - see the docs for
// the def ctor of mdspan:
//
// https://en.cppreference.com/w/cpp/container/mdspan/mdspan
if (out.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument(
"A null output array was passed to the batch-mode call operator of an sgp4_propagator");
}
if (tms.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument(
"A null times array was passed to the batch-mode call operator of an sgp4_propagator");
}

// Check the dimensionalities of out and tms.
const auto n_evals = out.extent(0);
if (n_evals != tms.extent(0)) [[unlikely]] {
Expand Down Expand Up @@ -1039,10 +1015,6 @@ template <typename T>
void sgp4_propagator<T>::operator()(out_3d out, in_2d<date> dates)
{
// Check the dates array.
if (dates.data_handle() == nullptr) [[unlikely]] {
throw std::invalid_argument(
"A null array of dates was passed to the batch-mode call operator of an sgp4_propagator");
}
const auto n_sats = get_nsats();
if (dates.extent(1) != n_sats) [[unlikely]] {
throw std::invalid_argument(fmt::format(
Expand Down
22 changes: 0 additions & 22 deletions test/cfunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,19 +350,6 @@ TEST_CASE("single call operator")
REQUIRE(output2[1] == -8);
std::ranges::fill(output2, fp_t(0));

// Null output span.
REQUIRE_THROWS_MATCHES(
cf0(typename cfunc<fp_t>::out_1d{nullptr, 2u}, std::array<fp_t, 0>{}, kw::pars = par1, kw::time = fp_t(10)),
std::invalid_argument, Message("The outputs array passed to a cfunc cannot be null"));

// Null input span with inputs.
cf0 = cfunc<fp_t>({x + y - heyoka::time, x - y + par[0]}, {x, y}, kw::opt_level = opt_level,
kw::high_accuracy = high_accuracy, kw::compact_mode = compact_mode);
REQUIRE_THROWS_MATCHES(
cf0(output2, typename cfunc<fp_t>::in_1d{nullptr, 2}, kw::pars = par1, kw::time = fp_t(10)),
std::invalid_argument,
Message("The inputs array passed to a cfunc can be null only if the number of input arguments is zero"));

// Null par span with no pars.
cf0 = cfunc<fp_t>({x + y - heyoka::time, x - y}, {x, y}, kw::opt_level = opt_level,
kw::high_accuracy = high_accuracy, kw::compact_mode = compact_mode);
Expand All @@ -373,15 +360,6 @@ TEST_CASE("single call operator")
REQUIRE(output2[0] == -7);
REQUIRE(output2[1] == -1);
std::ranges::fill(output2, fp_t(0));

// Null par span with pars.
cf0 = cfunc<fp_t>({x + y - heyoka::time, x - y + par[0]}, {x, y}, kw::opt_level = opt_level,
kw::high_accuracy = high_accuracy, kw::compact_mode = compact_mode);
REQUIRE_THROWS_MATCHES(cf0(output2, std::array<fp_t, 2>{1, 2},
kw::pars = typename cfunc<fp_t>::in_1d{nullptr, 1}, kw::time = fp_t(10)),
std::invalid_argument,
Message("The array of parameter values passed to a cfunc can be null only if the number "
"of parameters is zero"));
};

for (auto cm : {false, true}) {
Expand Down
24 changes: 0 additions & 24 deletions test/cfunc_multieval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,6 @@ TEST_CASE("multieval st")
// Check no error on zero nevals with null outputs span.
REQUIRE_NOTHROW(cf0(out_2d{nullptr, 2, 0}, in_2d{ibuf.data(), 2, 0}));

// Check error with null outputs span and nonzero evals.
REQUIRE_THROWS_MATCHES(
cf0(out_2d{nullptr, 2, 10}, in_2d{ibuf.data(), 0, 0}), std::invalid_argument,
Message("The outputs array passed to a cfunc can be null only if the number of evaluations is zero"));

obuf.resize(20u);

REQUIRE_THROWS_MATCHES(cf0(out_2d{obuf.data(), 2, 10}, in_2d{ibuf.data(), 0, 0}), std::invalid_argument,
Expand All @@ -113,11 +108,6 @@ TEST_CASE("multieval st")
Message("Invalid inputs array passed to a cfunc: the expected number of columns deduced from the "
"outputs array is 10, but the number of columns in the inputs array is 5"));

// Null input span.
REQUIRE_THROWS_MATCHES(cf0(out_2d{obuf.data(), 2, 10}, in_2d{nullptr, 2, 10}), std::invalid_argument,
Message("The inputs array passed to a cfunc can be null only if the number of input "
"arguments or the number of evaluations is zero"));

cf0 = cfunc<fp_t>{{x + y + par[0], x - y + heyoka::time},
{x, y},
kw::opt_level = opt_level,
Expand Down Expand Up @@ -167,20 +157,6 @@ TEST_CASE("multieval st")
"but the expected size deduced from the "
"outputs array is 10"));

// Null par span.
REQUIRE_THROWS_MATCHES(cf0(out_2d{obuf.data(), 2, 10}, in_2d{ibuf.data(), 2, 10},
kw::pars = in_2d{nullptr, 1, 10}, kw::time = in_1d{tbuf.data(), 5}),
std::invalid_argument,
Message("The array of parameter values passed to a cfunc can be null only if the "
"number of parameters or the number of evaluations is zero"));

// Null time span.
REQUIRE_THROWS_MATCHES(cf0(out_2d{obuf.data(), 2, 10}, in_2d{ibuf.data(), 2, 10},
kw::pars = in_2d{pbuf.data(), 1, 10}, kw::time = in_1d{nullptr, 10}),
std::invalid_argument,
Message("The array of time values passed to a cfunc can be null only if the "
"number of evaluations is zero"));

// Functional testing.
cf0 = cfunc<fp_t>{{x + y, x - y},
{x, y},
Expand Down
20 changes: 1 addition & 19 deletions test/model_sgp4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,7 @@ TEST_CASE("propagator single")
REQUIRE(out(6, 1) == 0.);

// Try with several bogus input spans.
REQUIRE_THROWS_AS(prop(prop_t::out_2d{nullptr, 7, 2}, date_in), std::invalid_argument);
REQUIRE_THROWS_AS(prop(prop_t::out_2d{outs.data(), 5, 2}, date_in), std::invalid_argument);
REQUIRE_THROWS_AS(prop(out, prop_t::in_1d<double>{nullptr, 2}), std::invalid_argument);
REQUIRE_THROWS_AS(prop(out, prop_t::in_1d<double>{ins.data(), 1}), std::invalid_argument);
}
}
Expand Down Expand Up @@ -380,14 +378,8 @@ TEST_CASE("propagator batch")
prop(prop_t::out_3d{outs.data(), 0, 7, 2}, prop_t::in_2d<double>{tm.data(), 0, 2});

// Try with several bogus input spans.
REQUIRE_THROWS_MATCHES(
prop(prop_t::out_3d{nullptr, 2, 7, 2}, date_in), std::invalid_argument,
Message("A null output array was passed to the batch-mode call operator of an sgp4_propagator"));
REQUIRE_THROWS_AS(prop(prop_t::out_3d{outs.data(), 2, 5, 2}, date_in), std::invalid_argument);
REQUIRE_THROWS_AS(prop(prop_t::out_3d{outs.data(), 2, 4, 1}, date_in), std::invalid_argument);
REQUIRE_THROWS_MATCHES(
prop(out, prop_t::in_2d<double>{nullptr, 2, 2}), std::invalid_argument,
Message("A null times array was passed to the batch-mode call operator of an sgp4_propagator"));
REQUIRE_THROWS_AS(prop(out, prop_t::in_2d<double>{ins.data(), 2, 1}), std::invalid_argument);
REQUIRE_THROWS_AS(prop(out, prop_t::in_2d<double>{ins.data(), 2, 0}), std::invalid_argument);
}
Expand All @@ -403,7 +395,7 @@ TEST_CASE("error handling")

// Propagator with null list or zero satellites.
REQUIRE_THROWS_MATCHES((prop_t{md_input_t{nullptr, 0}}), std::invalid_argument,
Message("Cannot initialise an sgp4_propagator with a null list of satellites"));
Message("Cannot initialise an sgp4_propagator with an empty list of satellites"));

std::vector<double> input(9u);

Expand Down Expand Up @@ -467,9 +459,6 @@ TEST_CASE("error handling")
Message("Invalid propagation date detected for the satellite at index 1: the magnitude of the Julian "
"date (0) is less than the magnitude of the fractional correction (1)"));

REQUIRE_THROWS_MATCHES(prop(out, prop_t::in_1d<prop_t::date>{nullptr, 2}), std::invalid_argument,
Message("A null array of dates was passed to the call operator of an sgp4_propagator"));

prop_t::in_1d<prop_t::date> date_in2{dates.data(), 1};

REQUIRE_THROWS_MATCHES(
Expand All @@ -492,10 +481,6 @@ TEST_CASE("error handling")
"inferred from the output array is 2, which is not consistent with the number of evaluations "
"inferred from the times array (1)"));

REQUIRE_THROWS_MATCHES(
prop(out_batch, prop_t::in_2d<prop_t::date>{nullptr, 1, 2}), std::invalid_argument,
Message("A null array of dates was passed to the batch-mode call operator of an sgp4_propagator"));

date_b = prop_t::in_2d<prop_t::date>{dates_batch.data(), 1, 1};

REQUIRE_THROWS_MATCHES(
Expand Down Expand Up @@ -840,9 +825,6 @@ TEST_CASE("replace_sat_data")
REQUIRE(orig_out == outs);

// Error throwing.
REQUIRE_THROWS_MATCHES((prop.replace_sat_data(md_input_t{nullptr, 2})), std::invalid_argument,
Message("Cannot replace the satellite data with a null array"));

REQUIRE_THROWS_MATCHES((prop.replace_sat_data(md_input_t{ins2.data(), 1})), std::invalid_argument,
Message("Invalid array provided to replace_sat_data(): the number of "
"columns (1) does not match the number of satellites (2)"));
Expand Down

0 comments on commit 86c25bb

Please sign in to comment.