From 49ac5a0a0af5e931b9f7987a3fa3fae86b76e810 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Thu, 4 Jan 2024 17:42:16 +0000 Subject: [PATCH] TST: Update tests for new helper function values --- pysindy/utils/axes.py | 7 +++++-- test/utils/test_axes.py | 45 +++++++++++++++++------------------------ 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 8a0f9dd96..1e803b1ba 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -196,8 +196,11 @@ def __getattr__(self, name): except KeyError: raise AttributeError(f"AxesArray has no axis '{name}'") if parts[0] == "n": - fwd_map = self._ax_map.fwd_map - shape = tuple(self.shape[ax_id] for ax_id in fwd_map["ax_" + parts[1]]) + try: + ax_ids = self._ax_map.fwd_map["ax_" + parts[1]] + except KeyError: + raise AttributeError(f"AxesArray has no axis '{name}'") + shape = tuple(self.shape[ax_id] for ax_id in ax_ids) if len(shape) == 1: return shape[0] return shape diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index ebe703f47..f33b94750 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -145,19 +145,6 @@ def test_n_elements(): assert arr2.n_coord == 4 -@pytest.mark.skip("Expected error") -def test_limited_slice(): - arr = np.empty(np.arange(1, 5)) - arr = AxesArray(arr, {"ax_spatial": [0, 1], "ax_time": 2, "ax_coord": 3}) - arr3 = arr[..., :2, 0] - assert arr3.n_spatial == (1, 2) - assert arr3.n_time == 2 - # No way to intercept slicing and remove ax_coord - with pytest.raises(IndexError): - assert arr3.n_coord == 1 - assert arr3.n_sample == 1 - - def test_warn_toofew_axes(): axes = {"ax_time": 0, "ax_coord": 1} with pytest.warns(AxesWarning): @@ -176,21 +163,30 @@ def test_conflicting_axes_defn(): AxesArray(np.ones(4), axes) +def test_missing_axis_errors(): + axes = {"ax_time": 0} + arr = AxesArray(np.arange(3), axes) + with pytest.raises(AttributeError): + arr.ax_spatial + with pytest.raises(AttributeError): + arr.n_spatial + + def test_basic_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) slim = arr[1, :, None] - with pytest.raises(KeyError): + with pytest.raises(AttributeError): slim.ax_time assert slim.ax_unk == 1 assert slim.ax_coord == 0 reverse_slim = arr[None, :, 1] - with pytest.raises(KeyError): + with pytest.raises(AttributeError): reverse_slim.ax_coord assert reverse_slim.ax_unk == 0 assert reverse_slim.ax_time == 1 almost_new = arr[None, None, 1, :, None, None] - with pytest.raises(KeyError): + with pytest.raises(AttributeError): almost_new.ax_time assert almost_new.ax_coord == 2 assert set(almost_new.ax_unk) == {0, 1, 3, 4} @@ -232,26 +228,26 @@ def test_adv_indexing_modifies_axes(): def test_standardize_basic_indexer(): arr = np.arange(6).reshape(2, 3) result_indexer, result_fancy = axes.standardize_indexer(arr, Ellipsis) - assert result_indexer == (slice(None), slice(None)) + assert result_indexer == [slice(None), slice(None)] assert result_fancy == () result_indexer, result_fancy = axes.standardize_indexer( arr, (np.newaxis, 1, 1, Ellipsis) ) - assert result_indexer == (None, 1, 1) + assert result_indexer == [None, 1, 1] assert result_fancy == () def test_standardize_fancy_indexer(): arr = np.arange(6).reshape(2, 3) result_indexer, result_fancy = axes.standardize_indexer(arr, [1]) - assert result_indexer == (np.ones(1), slice(None)) + assert result_indexer == [np.ones(1), slice(None)] assert result_fancy == (0,) result_indexer, result_fancy = axes.standardize_indexer( arr, (np.newaxis, [1], 1, Ellipsis) ) - assert result_indexer == (None, np.ones(1), 1) + assert result_indexer == [None, np.ones(1), 1] assert result_fancy == (1,) @@ -408,18 +404,15 @@ def test_insert_misordered_AxisMapping(): def test_determine_adv_broadcasting(): indexers = (1, np.ones(1), np.ones((4, 1)), np.ones(3)) - res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3]) - assert res_adj is True + res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3]) assert res_nd == 2 assert res_start == 1 indexers = (None, np.ones(1), 2, np.ones(3)) - res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3]) - assert res_adj is False + res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3]) assert res_nd == 1 assert res_start == 0 - res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, []) - assert res_adj is True + res_nd, res_start = axes._determine_adv_broadcasting(indexers, []) assert res_nd == 0 assert res_start is None