Skip to content

Commit

Permalink
CLN: Extract function on the basic indexing
Browse files Browse the repository at this point in the history
TBD: should this fully return a new _AxisMapping and maybe a
new indexer?
  • Loading branch information
Jacob-Stevens-Haas committed Jan 4, 2024
1 parent 48634df commit 91064dd
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
35 changes: 24 additions & 11 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /):
in_dim = self.shape
key, adv_inds = standardize_indexer(self, key)
adjacent, bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds)
remove_axes, new_axes = _apply_basic_indexing(key)

# Handle moving around non-adjacent advanced axes
old_index = OldIndex(0)
pindexers: list[PartialReIndexer | list[PartialReIndexer]] = []
Expand Down Expand Up @@ -231,17 +233,6 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /):
cindexers.append((*pindexer[:-1], curr_axis))
curr_axis += 1

remove_axes = []
new_axes = []
leftshift = 0
rightshift = 0
for key_ind, indexer in enumerate(key):
if indexer is None:
new_axes.append(key_ind - leftshift)
rightshift += 1
elif isinstance(indexer, int):
remove_axes.append(key_ind - rightshift)
leftshift += 1
if adv_inds:
adv_inds = sorted(adv_inds)
source_axis = [ # after basic indexing applied # noqa
Expand Down Expand Up @@ -512,6 +503,28 @@ def _squeeze_to_sublist(li: list, idxs: Sequence[int]) -> list:
return li[: min(idxs)] + [[li[idx] for idx in idxs]] + li[max(idxs) + 1 :]


def _apply_basic_indexing(key: tuple[StandardIndexer]) -> tuple[list[int], list[int]]:
"""Determine where axes should be removed and added
Only considers the basic indexers in key. Numpy arrays are treated as
slices, in that they don't affect the final dimensions of the output
"""
remove_axes = []
new_axes = []
deleted_to_left = 0
added_to_left = 0
for key_ind, indexer in enumerate(key):
if isinstance(indexer, int):
orig_arr_axis = key_ind - added_to_left
remove_axes.append(orig_arr_axis)
deleted_to_left += 1
elif indexer is None:
new_arr_axis = key_ind - deleted_to_left
new_axes.append(new_arr_axis)
added_to_left += 1
return remove_axes, new_axes


def comprehend_axes(x):
axes = {}
axes["ax_coord"] = len(x.shape) - 1
Expand Down
6 changes: 3 additions & 3 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,11 @@ def test_basic_indexing_modifies_axes():
reverse_slim.ax_coord
assert reverse_slim.ax_unk == 0
assert reverse_slim.ax_time == 1
almost_new = arr[None, None, 1, 1, None, None]
almost_new = arr[None, None, 1, :, None, None]
with pytest.raises(KeyError):
almost_new.ax_time
almost_new.ax_coord
assert set(almost_new.ax_unk) == {0, 1, 2, 3}
assert almost_new.ax_coord == 2
assert set(almost_new.ax_unk) == {0, 1, 3, 4}


def test_fancy_indexing_modifies_axes():
Expand Down

0 comments on commit 91064dd

Please sign in to comment.