diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 411fb4f76..3b752011b 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -98,7 +98,7 @@ def remove_axis(self, axis: Union[Collection[int], int, None] = None): in_ndim = len(self.reverse_map) if not isinstance(axis, Collection): axis = [axis] - for cum_shift, orig_ax_remove in enumerate(axis): + for cum_shift, orig_ax_remove in enumerate(sorted(axis)): remove_ax_name = self.reverse_map[orig_ax_remove] curr_ax_remove = orig_ax_remove - cum_shift if len(new_axes[remove_ax_name]) == 1: @@ -129,7 +129,7 @@ def insert_axis(self, axis: Union[Collection[int], int], new_name: str = "ax_unk in_ndim = len(self.reverse_map) if not isinstance(axis, Collection): axis = [axis] - for cum_shift, ax in enumerate(axis): + for cum_shift, ax in enumerate(sorted(axis)): if new_name in new_axes.keys(): new_axes[new_name].append(ax) else: diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index 715fcc318..20e6b0b06 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -317,6 +317,13 @@ def test_reduce_twisted_AxisMapping(): assert result == expected +def test_reduce_misordered_AxisMapping(): + ax_map = _AxisMapping({"ax_a": [0, 1], "ax_b": 2, "ax_c": 3}, 7) + result = ax_map.remove_axis([2, 1]) + expected = {"ax_a": 0, "ax_c": 1} + assert result == expected + + def test_insert_AxisMapping(): ax_map = _AxisMapping( { @@ -338,6 +345,26 @@ def test_insert_AxisMapping(): assert result == expected +def test_insert_existing_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 1], + "ax_b": 2, + "ax_c": 3, + "ax_d": [4, 5], + }, + 6, + ) + result = ax_map.insert_axis(3, "ax_b") + expected = { + "ax_a": [0, 1], + "ax_b": [2, 3], + "ax_c": 4, + "ax_d": [5, 6], + } + assert result == expected + + def test_insert_multiple_AxisMapping(): ax_map = _AxisMapping( { @@ -357,3 +384,24 @@ def test_insert_multiple_AxisMapping(): "ax_d": [6, 7], } assert result == expected + + +def test_insert_misordered_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 1], + "ax_b": 2, + "ax_c": 3, + "ax_d": [4, 5], + }, + 6, + ) + result = ax_map.insert_axis([4, 1]) + expected = { + "ax_a": [0, 2], + "ax_unk": [1, 4], + "ax_b": 3, + "ax_c": 5, + "ax_d": [6, 7], + } + assert result == expected