diff --git a/quantus/helpers/asserts.py b/quantus/helpers/asserts.py index ad2d95ee..401f5f56 100644 --- a/quantus/helpers/asserts.py +++ b/quantus/helpers/asserts.py @@ -286,6 +286,12 @@ def assert_value_smaller_than_input_size( ------- None """ + if len(x.shape) == 2: + if value >= np.prod(x.shape[1:]): + raise ValueError( + f"'{value_name}' must be smaller than input size." + f" [{value} >= {np.prod(x.shape[1:])}]" + ) if value >= np.prod(x.shape[2:]): raise ValueError( f"'{value_name}' must be smaller than input size."