diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index f5347146e..bc16e8a7f 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -197,6 +197,8 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): in_dim = self.shape key, adv_inds = standardize_indexer(self, key) adjacent, bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) + remove_axes, new_axes = _apply_basic_indexing(key) + # Handle moving around non-adjacent advanced axes old_index = OldIndex(0) pindexers: list[PartialReIndexer | list[PartialReIndexer]] = [] @@ -231,17 +233,6 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): cindexers.append((*pindexer[:-1], curr_axis)) curr_axis += 1 - remove_axes = [] - new_axes = [] - leftshift = 0 - rightshift = 0 - for key_ind, indexer in enumerate(key): - if indexer is None: - new_axes.append(key_ind - leftshift) - rightshift += 1 - elif isinstance(indexer, int): - remove_axes.append(key_ind - rightshift) - leftshift += 1 if adv_inds: adv_inds = sorted(adv_inds) source_axis = [ # after basic indexing applied # noqa @@ -512,6 +503,28 @@ def _squeeze_to_sublist(li: list, idxs: Sequence[int]) -> list: return li[: min(idxs)] + [[li[idx] for idx in idxs]] + li[max(idxs) + 1 :] +def _apply_basic_indexing(key: tuple[StandardIndexer]) -> tuple[list[int], list[int]]: + """Determine where axes should be removed and added + + Only considers the basic indexers in key. Numpy arrays are treated as + slices, in that they don't affect the final dimensions of the output + """ + remove_axes = [] + new_axes = [] + deleted_to_left = 0 + added_to_left = 0 + for key_ind, indexer in enumerate(key): + if isinstance(indexer, int): + orig_arr_axis = key_ind - added_to_left + remove_axes.append(orig_arr_axis) + deleted_to_left += 1 + elif indexer is None: + new_arr_axis = key_ind - deleted_to_left + new_axes.append(new_arr_axis) + added_to_left += 1 + return remove_axes, new_axes + + def comprehend_axes(x): axes = {} axes["ax_coord"] = len(x.shape) - 1 diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index b37b20b83..8892837df 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -192,11 +192,11 @@ def test_basic_indexing_modifies_axes(): reverse_slim.ax_coord assert reverse_slim.ax_unk == 0 assert reverse_slim.ax_time == 1 - almost_new = arr[None, None, 1, 1, None, None] + almost_new = arr[None, None, 1, :, None, None] with pytest.raises(KeyError): almost_new.ax_time - almost_new.ax_coord - assert set(almost_new.ax_unk) == {0, 1, 2, 3} + assert almost_new.ax_coord == 2 + assert set(almost_new.ax_unk) == {0, 1, 3, 4} def test_fancy_indexing_modifies_axes():