Skip to content

Commit

Permalink
feat(_AxisMapping): create an ndim property
Browse files Browse the repository at this point in the history
Used in fixing bug: Handle removing negative axis indexes
  • Loading branch information
Jacob-Stevens-Haas committed Jan 12, 2024
1 parent 51bae67 commit 23817f0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
11 changes: 9 additions & 2 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def compat_axes(self):

def remove_axis(self, axis: Union[Collection[int], int, None] = None):
"""Create an axes dict from self with specified axis or axes
removed and all greater axes decremented.
removed and all greater axes decremented. This can be passed to
the constructor to create a new _AxisMapping
Arguments:
axis: the axis index or axes indexes to remove. By numpy
Expand All @@ -105,6 +106,7 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None):
in_ndim = len(self.reverse_map)
if not isinstance(axis, Collection):
axis = [axis]
axis = [ax_id if ax_id >= 0 else (self.ndim + ax_id) for ax_id in axis]
for cum_shift, orig_ax_remove in enumerate(sorted(axis)):
remove_ax_name = self.reverse_map[orig_ax_remove]
curr_ax_remove = orig_ax_remove - cum_shift
Expand Down Expand Up @@ -146,6 +148,10 @@ def insert_axis(self, axis: Union[Collection[int], int], new_name: str):
new_axes[ax_name].append(ax_id + 1)
return self._compat_axes(new_axes)

@property
def ndim(self):
return len(self.reverse_map)


class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray):
"""A numpy-like array that keeps track of the meaning of its axes.
Expand Down Expand Up @@ -481,7 +487,8 @@ def comprehend_axes(x):
axes = {}
axes["ax_coord"] = len(x.shape) - 1
axes["ax_time"] = len(x.shape) - 2
axes["ax_spatial"] = list(range(len(x.shape) - 2))
if x.ndim > 2:
axes["ax_spatial"] = list(range(len(x.shape) - 2))
return axes


Expand Down
17 changes: 4 additions & 13 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,22 +274,13 @@ def test_standardize_bool_indexer():

def test_reduce_AxisMapping():
ax_map = _AxisMapping(
{
"ax_a": [0, 1],
"ax_b": 2,
"ax_c": 3,
"ax_d": 4,
"ax_e": [5, 6],
},
{"ax_a": [0, 1], "ax_b": 2, "ax_c": 3, "ax_d": 4, "ax_e": [5, 6]},
7,
)
result = ax_map.remove_axis(3)
expected = {
"ax_a": [0, 1],
"ax_b": 2,
"ax_d": 3,
"ax_e": [4, 5],
}
expected = {"ax_a": [0, 1], "ax_b": 2, "ax_d": 3, "ax_e": [4, 5]}
assert result == expected
result = ax_map.remove_axis(-4)
assert result == expected


Expand Down

0 comments on commit 23817f0

Please sign in to comment.