diff --git a/engarde/checks.py b/engarde/checks.py index 67c7a85..2ec86f5 100644 --- a/engarde/checks.py +++ b/engarde/checks.py @@ -98,7 +98,8 @@ def is_shape(df, shape): df : DataFrame """ try: - assert df.shape == shape + check = np.all(np.equal(df.shape, shape) | np.equal(shape, [-1, -1])) + assert check except AssertionError as e: msg = ("Expected shape: {}\n" "\t\tActual shape: {}".format(shape, df.shape)) diff --git a/tests/test_checks.py b/tests/test_checks.py index 6e5ec7f..6d57b1e 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -138,10 +138,15 @@ def test_monotonic_items(): def test_is_shape(): shape = 10, 2 + ig_0 = -1, 2 + ig_1 = 10, -1 + shapes = [shape, ig_0, ig_1] df = pd.DataFrame(np.random.randn(*shape)) - tm.assert_frame_equal(df, ck.is_shape(df, shape)) - result = dc.is_shape(shape=shape)(_add_n)(df) - tm.assert_frame_equal(result, df + 1) + for shp in shapes: + tm.assert_frame_equal(df, ck.is_shape(df, shp)) + for shp in shapes: + result = dc.is_shape(shape=shp)(_add_n)(df) + tm.assert_frame_equal(result, df + 1) with pytest.raises(AssertionError): ck.is_shape(df, (9, 2))