Skip to content

Commit

Permalink
BUG: Sort axis argument when inserting or removing axes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 4, 2024
1 parent 6ccba03 commit b8c8739
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None):
in_ndim = len(self.reverse_map)
if not isinstance(axis, Collection):
axis = [axis]
for cum_shift, orig_ax_remove in enumerate(axis):
for cum_shift, orig_ax_remove in enumerate(sorted(axis)):
remove_ax_name = self.reverse_map[orig_ax_remove]
curr_ax_remove = orig_ax_remove - cum_shift
if len(new_axes[remove_ax_name]) == 1:
Expand Down Expand Up @@ -129,7 +129,7 @@ def insert_axis(self, axis: Union[Collection[int], int], new_name: str = "ax_unk
in_ndim = len(self.reverse_map)
if not isinstance(axis, Collection):
axis = [axis]
for cum_shift, ax in enumerate(axis):
for cum_shift, ax in enumerate(sorted(axis)):
if new_name in new_axes.keys():
new_axes[new_name].append(ax)
else:
Expand Down
48 changes: 48 additions & 0 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,13 @@ def test_reduce_twisted_AxisMapping():
assert result == expected


def test_reduce_misordered_AxisMapping():
ax_map = _AxisMapping({"ax_a": [0, 1], "ax_b": 2, "ax_c": 3}, 7)
result = ax_map.remove_axis([2, 1])
expected = {"ax_a": 0, "ax_c": 1}
assert result == expected


def test_insert_AxisMapping():
ax_map = _AxisMapping(
{
Expand All @@ -338,6 +345,26 @@ def test_insert_AxisMapping():
assert result == expected


def test_insert_existing_AxisMapping():
ax_map = _AxisMapping(
{
"ax_a": [0, 1],
"ax_b": 2,
"ax_c": 3,
"ax_d": [4, 5],
},
6,
)
result = ax_map.insert_axis(3, "ax_b")
expected = {
"ax_a": [0, 1],
"ax_b": [2, 3],
"ax_c": 4,
"ax_d": [5, 6],
}
assert result == expected


def test_insert_multiple_AxisMapping():
ax_map = _AxisMapping(
{
Expand All @@ -357,3 +384,24 @@ def test_insert_multiple_AxisMapping():
"ax_d": [6, 7],
}
assert result == expected


def test_insert_misordered_AxisMapping():
ax_map = _AxisMapping(
{
"ax_a": [0, 1],
"ax_b": 2,
"ax_c": 3,
"ax_d": [4, 5],
},
6,
)
result = ax_map.insert_axis([4, 1])
expected = {
"ax_a": [0, 2],
"ax_unk": [1, 4],
"ax_b": 3,
"ax_c": 5,
"ax_d": [6, 7],
}
assert result == expected

0 comments on commit b8c8739

Please sign in to comment.