Skip to content

Commit

Permalink
ENH: Enable fancy indexing in AxesArray
Browse files Browse the repository at this point in the history
Involves processing the keys several times, with increasing standardization
  • Loading branch information
Jacob-Stevens-Haas committed Jan 4, 2024
1 parent 700521a commit 204223f
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 171 deletions.
234 changes: 86 additions & 148 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import copy
import warnings
from enum import Enum
from typing import Collection
from typing import List
from typing import Literal
from typing import NewType
from typing import Optional
from typing import Sequence
Expand All @@ -26,6 +28,14 @@
]


class Sentinels(Enum):
    """Placeholder markers substituted into an indexing key.

    Each member wraps its own ``object()``, so the two members never alias
    each other; compare them by identity (``is``).
    """

    # Stands in for a new axis whose name is filled in later from the names
    # of the axes that advanced indexing removed.
    ADV_NAME = object()
    # Stands in for an advanced indexer whose axis is removed.
    ADV_REMOVE = object()


Literal[Sentinels.ADV_NAME]


class _AxisMapping:
"""Convenience wrapper for a two-way map between axis names and
indexes.
Expand Down Expand Up @@ -181,7 +191,10 @@ def shape(self):
def __getattr__(self, name):
parts = name.split("_", 1)
if parts[0] == "ax":
return self.axes[name]
try:
return self.axes[name]
except KeyError:
raise AttributeError(f"AxesArray has no axis '{name}'")
if parts[0] == "n":
fwd_map = self._ax_map.fwd_map
shape = tuple(self.shape[ax_id] for ax_id in fwd_map["ax_" + parts[1]])
Expand All @@ -193,104 +206,22 @@ def __getattr__(self, name):
def __getitem__(self, key: Indexer | Sequence[Indexer], /):
# NOTE(review): this span appears to overlay the removed and added lines of a
# single diff hunk (the preceding hunk header reads "-193,104 +206,22"), with
# the +/- markers and the indentation lost. It is not runnable as written:
#   * _determine_adv_broadcasting is unpacked into three names in one call
#     and into two names in another;
#   * remove_axes/new_axes are assigned from both _apply_basic_indexing and
#     _apply_indexing;
#   * new_map.insert_axis is called once with one argument and once with two.
# TODO confirm against the repository which statements are current before
# changing any code here.
output = super().__getitem__(key)
if not isinstance(output, AxesArray):
return output
return output # why?
in_dim = self.shape
key, adv_inds = standardize_indexer(self, key)
adjacent, bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds)
remove_axes, new_axes = _apply_basic_indexing(key)

# Handle moving around non-adjacent advanced axes
old_index = OldIndex(0)
pindexers: list[PartialReIndexer | list[PartialReIndexer]] = []
for key_ind, indexer in enumerate(key):
if isinstance(indexer, int | slice | np.ndarray):
pindexers.append((key_ind, old_index, indexer))
old_index += 1
elif indexer is None:
pindexers.append((key_ind, [None], None))
else:
raise TypeError(
f"AxesArray indexer of type {type(indexer)} not understood"
)
# Advanced indexing can move axes if they are not adjacent
if not adjacent:
_move_idxs_to_front(key, adv_inds)
adv_inds = range(len(adv_inds))
pindexers = _squeeze_to_sublist(pindexers, adv_inds)
cindexers: list[CompleteReIndexer] = []
curr_axis = 0
for pindexer in enumerate(pindexers):
if isinstance(pindexer, list): # advanced indexing bundle
bcast_idxers = _adv_broadcast_magic(key, adv_inds, pindexer)
cindexers += bcast_idxers
curr_axis += bcast_nd
elif pindexer[-1] is None:
cindexers.append((*pindexer[:-1], curr_axis))
curr_axis += 1
elif isinstance(pindexer[-1], int):
cindexers.append((*pindexer[:-1], None))
elif isinstance(pindexer[-1], slice):
cindexers.append((*pindexer[:-1], curr_axis))
curr_axis += 1

bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds)
if adv_inds:
adv_inds = sorted(adv_inds)
source_axis = [ # after basic indexing applied # noqa
len([id for id in range(idx_id) if key[id] is not None])
for idx_id in adv_inds
]
adv_indexers = [np.array(key[i]) for i in adv_inds] # noqa
bcast_nd = np.broadcast(*adv_indexers).nd
adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
bcast_start_ax = 0 if not adjacent else min(adv_inds)
adv_map = {}

for idx_id, idxer in zip(adv_inds, adv_indexers):
base_idxer_ax_name = self._reverse_map[ # count non-None keys
len([id for id in range(idx_id) if key[id] is not None])
]
adv_map[base_idxer_ax_name] = [
bcast_start_ax + shp
for shp in _compare_bcast_shapes(bcast_nd, idxer.shape)
]

conflicts = {}
for bcast_ax in range(bcast_nd):
ax_names = [name for name, axes in adv_map.items() if bcast_ax in axes]
if len(ax_names) > 1:
conflicts[bcast_ax] = ax_names
[]
if len(ax_names) == 0:
if "ax_unk" not in adv_map.keys():
adv_map["ax_unk"] = [bcast_ax + bcast_start_ax]
else:
adv_map["ax_unk"].append(bcast_ax + bcast_start_ax)

for conflict_axis, conflict_names in conflicts.items():
new_name = "ax_"
for name in conflict_names:
adv_map[name].remove(conflict_axis)
if not adv_map[name]:
adv_map.pop(name)
new_name += name[3:]
adv_map[new_name] = [conflict_axis]

# check if integer or boolean indexing
# if integer, check which dimensions get broadcast where
# if multiple, axes are merged. If adjacent, merged inplace,
# otherwise moved to beginning
remove_axes.append(adv_map.keys()) # Error: remove_axis takes ints

out_obj = np.broadcast(np.array(key[i]) for i in adv_inds) # noqa
pass
# mulligan structured arrays, etc.
key = replace_adv_indexers(key, adv_inds, bcast_start_ax, bcast_nd)
remove_axes, new_axes, adv_names = _apply_indexing(key, self._reverse_map)
new_axes = _rename_broadcast_axes(new_axes, adv_names)
new_map = _AxisMapping(
self._ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes)
)
new_map = _AxisMapping(
new_map.insert_axis(new_axes),
len(in_dim) - len(remove_axes) + len(new_axes),
)
for new_ax_ind, new_ax_name in new_axes:
new_map = _AxisMapping(
new_map.insert_axis(new_ax_ind, new_ax_name),
len(in_dim) - len(remove_axes) + len(new_axes),
)
output._ax_map = new_map
return output

Expand Down Expand Up @@ -404,7 +335,7 @@ def concatenate(arrays, axis=0):

def standardize_indexer(
arr: np.ndarray, key: Indexer | Sequence[Indexer]
) -> tuple[tuple[StandardIndexer], tuple[KeyIndex]]:
) -> tuple[list[StandardIndexer], tuple[KeyIndex]]:
"""Convert any legal numpy indexer to a "standard" form.
Standard form involves creating an equivalent indexer that is a tuple with
Expand Down Expand Up @@ -432,7 +363,7 @@ def standardize_indexer(
ax_key = np.array(ax_key)
adv_inds.append(indexer_ind)
new_key.append(ax_key)
return tuple(new_key), tuple(adv_inds)
return new_key, tuple(adv_inds)


def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None:
Expand All @@ -449,80 +380,87 @@ def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None:
indexers[ellind : ellind + 1] = n_ellipsis_dims * (slice(None),)


def _adv_broadcast_magic(*args):
raise NotImplementedError


def _compare_bcast_shapes(result_ndim: int, base_shape: tuple[int]) -> list[int]:
"""Identify which broadcast shape axes are due to base_shape
Args:
result_ndim: number of dimensions broadcast shape has
base_shape: shape of one element of broadcasting
Result:
tuple of axes in broadcast result that come from base shape
"""
return [
result_ndim - 1 - ax_id
for ax_id, length in enumerate(reversed(base_shape))
if length > 1
]


def _move_idxs_to_front(li: list, idxs: Sequence) -> None:
"""Move all items at indexes specified to the front of a list"""
front = []
for idx in reversed(idxs):
obj = li.pop(idx)
front.insert(0, obj)
li = front + li


def _determine_adv_broadcasting(
# NOTE(review): two parameter/return-annotation variants follow: one
# annotated tuple[bool, int, Optional[int]] and one annotated
# tuple[int, Optional[KeyIndex]]. They look like the old and new sides of a
# diff shown together without +/- markers. The return statement in this span
# yields three values. TODO confirm which signature is current.
key: StandardIndexer | Sequence[StandardIndexer], adv_inds: Sequence[OldIndex]
) -> tuple[bool, int, Optional[int]]:
key: Sequence[StandardIndexer], adv_inds: Sequence[OldIndex]
) -> tuple[int, Optional[KeyIndex]]:
"""Calculate the shape and location for the result of advanced indexing."""
# True when every advanced position directly follows the previous one
# (also True for zero or one advanced positions).
adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
adv_indexers = [np.array(key[i]) for i in adv_inds]
# Number of dimensions of the advanced indexers broadcast together.
bcast_nd = np.broadcast(*adv_indexers).nd
# 0 when the advanced positions are not adjacent; otherwise the smallest
# advanced position, or None when there are no advanced positions.
bcast_start_axis = 0 if not adjacent else min(adv_inds) if adv_inds else None
return adjacent, bcast_nd, bcast_start_axis


def _squeeze_to_sublist(li: list, idxs: Sequence[int]) -> list:
"""Turn contiguous elements of a list into a sub-list in the same position
e.g. _squeeze_to_sublist(["a", "b", "c", "d"], [1,2]) = ["a", ["b", "c"], "d"]
"""
for left, right in zip(idxs[:-1], idxs[1:]):
if left + 1 != right:
raise ValueError("Indexes to squeeze must be contiguous")
if not idxs:
return li
return li[: min(idxs)] + [[li[idx] for idx in idxs]] + li[max(idxs) + 1 :]


def _apply_basic_indexing(key: tuple[StandardIndexer]) -> tuple[list[int], list[int]]:
return bcast_nd, KeyIndex(bcast_start_axis)


def _rename_broadcast_axes(
    new_axes: list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]],
    adv_names: list[str],
) -> list[tuple[int, str]]:
    """Normalize sentinel and NoneType names

    Args:
        new_axes: (axis index, name) pairs.  A name of ``None`` becomes
            ``"ax_unk"``; ``Sentinels.ADV_NAME`` becomes the combined name
            built from ``adv_names``; any other name is kept as is.
        adv_names: axis names to combine into the name used for
            ``Sentinels.ADV_NAME`` entries.

    Returns:
        (axis index, name) pairs in the same order as ``new_axes``.
    """

    def _calc_bcast_name(*names: str) -> str:
        # One distinct name is used unchanged; several are joined as
        # "ax_<a>_<b>..." in first-seen order.  No names gives "".
        if not names:
            return ""
        unique_names = list(dict.fromkeys(names))  # ordered deduplication
        if len(unique_names) == 1:
            return unique_names[0]
        # Strip the "ax_" prefix by name rather than by fixed width, so a
        # name lacking the prefix is not truncated.
        return "ax_" + "_".join(name.removeprefix("ax_") for name in unique_names)

    bcast_name = _calc_bcast_name(*adv_names)
    renamed_axes = []
    for ax_ind, ax_name in new_axes:
        if ax_name is None:
            renamed_axes.append((ax_ind, "ax_unk"))
        elif ax_name is Sentinels.ADV_NAME:
            renamed_axes.append((ax_ind, bcast_name))
        else:
            renamed_axes.append((ax_ind, ax_name))
    return renamed_axes


def replace_adv_indexers(
    key: Sequence[StandardIndexer],
    adv_inds: list[int],
    bcast_start_ax: int,
    bcast_nd: int,
) -> list[
    Union[
        StandardIndexer,
        Literal[Sentinels.ADV_NAME],
        Literal[Sentinels.ADV_REMOVE],
    ]
]:
    """Swap advanced indexers in a key for sentinel markers.

    Every position listed in ``adv_inds`` is replaced by
    ``Sentinels.ADV_REMOVE``, then ``bcast_nd`` copies of
    ``Sentinels.ADV_NAME`` are inserted at position ``bcast_start_ax``.

    Args:
        key: indexers, one per position.  It is not modified.
        adv_inds: positions in ``key`` holding advanced indexers.
        bcast_start_ax: position at which the ``ADV_NAME`` markers go.
        bcast_nd: number of ``ADV_NAME`` markers to insert.

    Returns:
        A new list.  With no advanced indexers it is a copy of ``key``.
    """
    # Work on a copy: the caller's key is left intact, and tuples are accepted.
    new_key = list(key)
    if len(adv_inds) == 0:
        # Nothing to replace.  Without this guard a ``bcast_start_ax`` of
        # None would make both slices below span the whole key and
        # duplicate every entry.
        return new_key
    for adv_ind in adv_inds:
        new_key[adv_ind] = Sentinels.ADV_REMOVE
    return (
        new_key[:bcast_start_ax]
        + bcast_nd * [Sentinels.ADV_NAME]
        + new_key[bcast_start_ax:]
    )


def _apply_indexing(
    key: tuple[StandardIndexer], reverse_map: dict[int, str]
) -> tuple[
    list[int], list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], list[str]
]:
    """Determine where axes should be removed and added

    Integers and ``Sentinels.ADV_REMOVE`` entries remove an axis.  ``None``,
    ``Sentinels.ADV_NAME`` and string entries add one.  Any other indexer
    changes neither the removed nor the added axes.

    Args:
        key: indexers, one per position.
        reverse_map: looked up by original-array axis index to get the name
            of each axis removed by ``Sentinels.ADV_REMOVE``.

    Returns:
        Three lists: the original-array axes to remove; (new-array axis,
        indexer) pairs for the axes to add; and the names of the axes removed
        by ``Sentinels.ADV_REMOVE``.
    """
    remove_axes = []
    new_axes = []
    adv_names = []
    # Running counts used to convert a position in ``key`` into an axis
    # index of the original array or of the new array.
    deleted_to_left = 0
    added_to_left = 0
    for key_ind, indexer in enumerate(key):
        if isinstance(indexer, int) or indexer is Sentinels.ADV_REMOVE:
            # Added entries have no counterpart in the original array.
            orig_arr_axis = key_ind - added_to_left
            if indexer is Sentinels.ADV_REMOVE:
                adv_names.append(reverse_map[orig_arr_axis])
            remove_axes.append(orig_arr_axis)
            deleted_to_left += 1
        elif (
            indexer is None or indexer is Sentinels.ADV_NAME or isinstance(indexer, str)
        ):
            # Removed entries have no counterpart in the new array.
            new_arr_axis = key_ind - deleted_to_left
            new_axes.append((new_arr_axis, indexer))
            added_to_left += 1
    return remove_axes, new_axes, adv_names


def comprehend_axes(x):
Expand Down
33 changes: 10 additions & 23 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ def test_repr():
assert result == expected


@pytest.mark.skip(
"Not until fancy indexing (boolean) either short-circuited or implemented"
)
def test_ufunc_override():
# This is largely a clone of test_ufunc_override_with_super() from
# numpy/core/tests/test_umath.py
Expand Down Expand Up @@ -199,7 +196,7 @@ def test_basic_indexing_modifies_axes():
assert set(almost_new.ax_unk) == {0, 1, 3, 4}


def test_fancy_indexing_modifies_axes():
def test_adv_indexing_modifies_axes():
axes = {"ax_time": 0, "ax_coord": 1}
arr = AxesArray(np.arange(4).reshape((2, 2)), axes)
flat = arr[[0, 1], [0, 1]]
Expand All @@ -208,21 +205,23 @@ def test_fancy_indexing_modifies_axes():
assert flat.shape == (2,)
np.testing.assert_array_equal(np.asarray(flat), np.array([0, 3]))

assert flat.ax__timecoord == 0
assert flat.ax_time_coord == 0
with pytest.raises(AttributeError):
flat.ax_coord
with pytest.raises(AttributeError):
flat.ax_time

assert same.shape == arr.shape
np.testing.assert_equal(same, arr)
assert same.ax_time == 0
assert same.ax_coord == 1
np.testing.assert_equal(np.asarray(same), np.asarray(arr))
assert same.ax_time_coord == [0, 1]
with pytest.raises(AttributeError):
same.ax_coord

assert tpose.shape == arr.shape
np.testing.assert_equal(same, arr.T)
assert same.ax_time == 1
assert same.ax_coord == 0
np.testing.assert_equal(np.asarray(tpose), np.asarray(arr.T))
assert tpose.ax_time_coord == [0, 1]
with pytest.raises(AttributeError):
tpose.ax_coord

fat = arr[[[0, 1], [0, 1]]]
assert fat.shape == (2, 2, 2)
Expand Down Expand Up @@ -407,18 +406,6 @@ def test_insert_misordered_AxisMapping():
assert result == expected


def test_squeeze_to_sublist():
    """Check sublist bundling, the empty-index case, and the contiguity error."""
    li = ["a", "b", "c", "d"]
    result = axes._squeeze_to_sublist(li, [1, 2])
    assert result == ["a", ["b", "c"], "d"]

    # No indexes: the result equals the input.
    result = axes._squeeze_to_sublist(li, [])
    assert result == li

    # Non-contiguous indexes are rejected.
    with pytest.raises(ValueError, match="Indexes to squeeze"):
        axes._squeeze_to_sublist(li, [0, 2])


def test_determine_adv_broadcasting():
indexers = (1, np.ones(1), np.ones((4, 1)), np.ones(3))
res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3])
Expand Down

0 comments on commit 204223f

Please sign in to comment.