Skip to content

Commit

Permalink
TST: Update tests for new helper function values
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 4, 2024
1 parent 204223f commit 49ac5a0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 28 deletions.
7 changes: 5 additions & 2 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,11 @@ def __getattr__(self, name):
except KeyError:
raise AttributeError(f"AxesArray has no axis '{name}'")
if parts[0] == "n":
fwd_map = self._ax_map.fwd_map
shape = tuple(self.shape[ax_id] for ax_id in fwd_map["ax_" + parts[1]])
try:
ax_ids = self._ax_map.fwd_map["ax_" + parts[1]]
except KeyError:
raise AttributeError(f"AxesArray has no axis '{name}'")
shape = tuple(self.shape[ax_id] for ax_id in ax_ids)
if len(shape) == 1:
return shape[0]
return shape
Expand Down
45 changes: 19 additions & 26 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,19 +145,6 @@ def test_n_elements():
assert arr2.n_coord == 4


@pytest.mark.skip("Expected error")
def test_limited_slice():
arr = np.empty(np.arange(1, 5))
arr = AxesArray(arr, {"ax_spatial": [0, 1], "ax_time": 2, "ax_coord": 3})
arr3 = arr[..., :2, 0]
assert arr3.n_spatial == (1, 2)
assert arr3.n_time == 2
# No way to intercept slicing and remove ax_coord
with pytest.raises(IndexError):
assert arr3.n_coord == 1
assert arr3.n_sample == 1


def test_warn_toofew_axes():
axes = {"ax_time": 0, "ax_coord": 1}
with pytest.warns(AxesWarning):
Expand All @@ -176,21 +163,30 @@ def test_conflicting_axes_defn():
AxesArray(np.ones(4), axes)


def test_missing_axis_errors():
axes = {"ax_time": 0}
arr = AxesArray(np.arange(3), axes)
with pytest.raises(AttributeError):
arr.ax_spatial
with pytest.raises(AttributeError):
arr.n_spatial


def test_basic_indexing_modifies_axes():
axes = {"ax_time": 0, "ax_coord": 1}
arr = AxesArray(np.ones(4).reshape((2, 2)), axes)
slim = arr[1, :, None]
with pytest.raises(KeyError):
with pytest.raises(AttributeError):
slim.ax_time
assert slim.ax_unk == 1
assert slim.ax_coord == 0
reverse_slim = arr[None, :, 1]
with pytest.raises(KeyError):
with pytest.raises(AttributeError):
reverse_slim.ax_coord
assert reverse_slim.ax_unk == 0
assert reverse_slim.ax_time == 1
almost_new = arr[None, None, 1, :, None, None]
with pytest.raises(KeyError):
with pytest.raises(AttributeError):
almost_new.ax_time
assert almost_new.ax_coord == 2
assert set(almost_new.ax_unk) == {0, 1, 3, 4}
Expand Down Expand Up @@ -232,26 +228,26 @@ def test_adv_indexing_modifies_axes():
def test_standardize_basic_indexer():
arr = np.arange(6).reshape(2, 3)
result_indexer, result_fancy = axes.standardize_indexer(arr, Ellipsis)
assert result_indexer == (slice(None), slice(None))
assert result_indexer == [slice(None), slice(None)]
assert result_fancy == ()

result_indexer, result_fancy = axes.standardize_indexer(
arr, (np.newaxis, 1, 1, Ellipsis)
)
assert result_indexer == (None, 1, 1)
assert result_indexer == [None, 1, 1]
assert result_fancy == ()


def test_standardize_fancy_indexer():
arr = np.arange(6).reshape(2, 3)
result_indexer, result_fancy = axes.standardize_indexer(arr, [1])
assert result_indexer == (np.ones(1), slice(None))
assert result_indexer == [np.ones(1), slice(None)]
assert result_fancy == (0,)

result_indexer, result_fancy = axes.standardize_indexer(
arr, (np.newaxis, [1], 1, Ellipsis)
)
assert result_indexer == (None, np.ones(1), 1)
assert result_indexer == [None, np.ones(1), 1]
assert result_fancy == (1,)


Expand Down Expand Up @@ -408,18 +404,15 @@ def test_insert_misordered_AxisMapping():

def test_determine_adv_broadcasting():
indexers = (1, np.ones(1), np.ones((4, 1)), np.ones(3))
res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3])
assert res_adj is True
res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3])
assert res_nd == 2
assert res_start == 1

indexers = (None, np.ones(1), 2, np.ones(3))
res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3])
assert res_adj is False
res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 3])
assert res_nd == 1
assert res_start == 0

res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [])
assert res_adj is True
res_nd, res_start = axes._determine_adv_broadcasting(indexers, [])
assert res_nd == 0
assert res_start is None

0 comments on commit 49ac5a0

Please sign in to comment.