From 2dd6324ae91ddc948d0d9c1862cf4cb96e1f82ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anna=20Hedstr=C3=B6m?= Date: Tue, 23 Jan 2024 11:32:23 +0100 Subject: [PATCH] Update asserts.py --- quantus/helpers/asserts.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/quantus/helpers/asserts.py b/quantus/helpers/asserts.py index ad2d95ee3..401f5f568 100644 --- a/quantus/helpers/asserts.py +++ b/quantus/helpers/asserts.py @@ -286,6 +286,12 @@ def assert_value_smaller_than_input_size( ------- None """ + if len(x.shape) == 2: + if value >= np.prod(x.shape[1:]): + raise ValueError( + f"'{value_name}' must be smaller than input size." + f" [{value} >= {np.prod(x.shape[1:])}]" + ) if value >= np.prod(x.shape[2:]): raise ValueError( f"'{value_name}' must be smaller than input size."