diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py
index d6a8d7046..8a0f9dd96 100644
--- a/pysindy/utils/axes.py
+++ b/pysindy/utils/axes.py
@@ -1,7 +1,9 @@
 import copy
 import warnings
+from enum import Enum
 from typing import Collection
 from typing import List
+from typing import Literal
 from typing import NewType
 from typing import Optional
 from typing import Sequence
@@ -26,6 +28,14 @@
 ]
 
 
+class Sentinels(Enum):
+    ADV_NAME = object()
+    ADV_REMOVE = object()
+
+
+Literal[Sentinels.ADV_NAME]
+
+
 class _AxisMapping:
     """Convenience wrapper for a two-way map between axis names and indexes.
 
@@ -181,7 +191,10 @@ def shape(self):
     def __getattr__(self, name):
         parts = name.split("_", 1)
         if parts[0] == "ax":
-            return self.axes[name]
+            try:
+                return self.axes[name]
+            except KeyError:
+                raise AttributeError(f"AxesArray has no axis '{name}'") from None
         if parts[0] == "n":
             fwd_map = self._ax_map.fwd_map
             shape = tuple(self.shape[ax_id] for ax_id in fwd_map["ax_" + parts[1]])
@@ -193,104 +206,22 @@ def __getattr__(self, name):
     def __getitem__(self, key: Indexer | Sequence[Indexer], /):
         output = super().__getitem__(key)
         if not isinstance(output, AxesArray):
-            return output
+            return output  # e.g. a scalar element: no axes left to relabel
         in_dim = self.shape
         key, adv_inds = standardize_indexer(self, key)
-        adjacent, bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds)
-        remove_axes, new_axes = _apply_basic_indexing(key)
-
-        # Handle moving around non-adjacent advanced axes
-        old_index = OldIndex(0)
-        pindexers: list[PartialReIndexer | list[PartialReIndexer]] = []
-        for key_ind, indexer in enumerate(key):
-            if isinstance(indexer, int | slice | np.ndarray):
-                pindexers.append((key_ind, old_index, indexer))
-                old_index += 1
-            elif indexer is None:
-                pindexers.append((key_ind, [None], None))
-            else:
-                raise TypeError(
-                    f"AxesArray indexer of type {type(indexer)} not understood"
-                )
-        # Advanced indexing can move axes if they are not adjacent
-        if not adjacent:
-            _move_idxs_to_front(key, adv_inds)
-            adv_inds = range(len(adv_inds))
-        pindexers = _squeeze_to_sublist(pindexers, adv_inds)
-        cindexers: list[CompleteReIndexer] = []
-        curr_axis = 0
-        for pindexer in enumerate(pindexers):
-            if isinstance(pindexer, list):  # advanced indexing bundle
-                bcast_idxers = _adv_broadcast_magic(key, adv_inds, pindexer)
-                cindexers += bcast_idxers
-                curr_axis += bcast_nd
-            elif pindexer[-1] is None:
-                cindexers.append((*pindexer[:-1], curr_axis))
-                curr_axis += 1
-            elif isinstance(pindexer[-1], int):
-                cindexers.append((*pindexer[:-1], None))
-            elif isinstance(pindexer[-1], slice):
-                cindexers.append((*pindexer[:-1], curr_axis))
-                curr_axis += 1
-
+        bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds)
         if adv_inds:
-            adv_inds = sorted(adv_inds)
-            source_axis = [  # after basic indexing applied  # noqa
-                len([id for id in range(idx_id) if key[id] is not None])
-                for idx_id in adv_inds
-            ]
-            adv_indexers = [np.array(key[i]) for i in adv_inds]  # noqa
-            bcast_nd = np.broadcast(*adv_indexers).nd
-            adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
-            bcast_start_ax = 0 if not adjacent else min(adv_inds)
-            adv_map = {}
-
-            for idx_id, idxer in zip(adv_inds, adv_indexers):
-                base_idxer_ax_name = self._reverse_map[  # count non-None keys
-                    len([id for id in range(idx_id) if key[id] is not None])
-                ]
-                adv_map[base_idxer_ax_name] = [
-                    bcast_start_ax + shp
-                    for shp in _compare_bcast_shapes(bcast_nd, idxer.shape)
-                ]
-
-            conflicts = {}
-            for bcast_ax in range(bcast_nd):
-                ax_names = [name for name, axes in adv_map.items() if bcast_ax in axes]
-                if len(ax_names) > 1:
-                    conflicts[bcast_ax] = ax_names
-                    []
-                if len(ax_names) == 0:
-                    if "ax_unk" not in adv_map.keys():
-                        adv_map["ax_unk"] = [bcast_ax + bcast_start_ax]
-                    else:
-                        adv_map["ax_unk"].append(bcast_ax + bcast_start_ax)
-
-            for conflict_axis, conflict_names in conflicts.items():
-                new_name = "ax_"
-                for name in conflict_names:
-                    adv_map[name].remove(conflict_axis)
-                    if not adv_map[name]:
-                        adv_map.pop(name)
-                    new_name += name[3:]
-                adv_map[new_name] = [conflict_axis]
-
-            # check if integer or boolean indexing
-            # if integer, check which dimensions get broadcast where
-            # if multiple, axes are merged. If adjacent, merged inplace,
-            # otherwise moved to beginning
-            remove_axes.append(adv_map.keys())  # Error: remove_axis takes ints
-
-            out_obj = np.broadcast(np.array(key[i]) for i in adv_inds)  # noqa
-            pass
-        # mulligan structured arrays, etc.
+            key = replace_adv_indexers(key, adv_inds, bcast_start_ax, bcast_nd)
+        remove_axes, new_axes, adv_names = _apply_indexing(key, self._reverse_map)
+        new_axes = _rename_broadcast_axes(new_axes, adv_names)
         new_map = _AxisMapping(
             self._ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes)
         )
-        new_map = _AxisMapping(
-            new_map.insert_axis(new_axes),
-            len(in_dim) - len(remove_axes) + len(new_axes),
-        )
+        for new_ax_ind, new_ax_name in new_axes:
+            new_map = _AxisMapping(
+                new_map.insert_axis(new_ax_ind, new_ax_name),
+                len(in_dim) - len(remove_axes) + len(new_axes),
+            )
         output._ax_map = new_map
         return output
 
@@ -404,7 +335,7 @@ def concatenate(arrays, axis=0):
 
 def standardize_indexer(
     arr: np.ndarray, key: Indexer | Sequence[Indexer]
-) -> tuple[tuple[StandardIndexer], tuple[KeyIndex]]:
+) -> tuple[list[StandardIndexer], tuple[KeyIndex, ...]]:
     """Convert any legal numpy indexer to a "standard" form.
 
     Standard form involves creating an equivalent indexer that is a tuple with
@@ -432,7 +363,7 @@ def standardize_indexer(
             ax_key = np.array(ax_key)
             adv_inds.append(indexer_ind)
         new_key.append(ax_key)
-    return tuple(new_key), tuple(adv_inds)
+    return new_key, tuple(adv_inds)
 
 
 def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None:
@@ -449,61 +380,69 @@ def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None:
     indexers[ellind : ellind + 1] = n_ellipsis_dims * (slice(None),)
 
 
-def _adv_broadcast_magic(*args):
-    raise NotImplementedError
-
-
-def _compare_bcast_shapes(result_ndim: int, base_shape: tuple[int]) -> list[int]:
-    """Identify which broadcast shape axes are due to base_shape
-
-    Args:
-        result_ndim: number of dimensions broadcast shape has
-        base_shape: shape of one element of broadcasting
-
-    Result:
-        tuple of axes in broadcast result that come from base shape
-    """
-    return [
-        result_ndim - 1 - ax_id
-        for ax_id, length in enumerate(reversed(base_shape))
-        if length > 1
-    ]
-
-
-def _move_idxs_to_front(li: list, idxs: Sequence) -> None:
-    """Move all items at indexes specified to the front of a list"""
-    front = []
-    for idx in reversed(idxs):
-        obj = li.pop(idx)
-        front.insert(0, obj)
-    li = front + li
-
-
 def _determine_adv_broadcasting(
-    key: StandardIndexer | Sequence[StandardIndexer], adv_inds: Sequence[OldIndex]
-) -> tuple[bool, int, Optional[int]]:
+    key: Sequence[StandardIndexer], adv_inds: Sequence[OldIndex]
+) -> tuple[int, Optional[KeyIndex]]:
     """Calculate the shape and location for the result of advanced indexing."""
     adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
     adv_indexers = [np.array(key[i]) for i in adv_inds]
     bcast_nd = np.broadcast(*adv_indexers).nd
     bcast_start_axis = 0 if not adjacent else min(adv_inds) if adv_inds else None
-    return adjacent, bcast_nd, bcast_start_axis
-
-
-def _squeeze_to_sublist(li: list, idxs: Sequence[int]) -> list:
-    """Turn contiguous elements of a list into a sub-list in the same position
-
-    e.g. _squeeze_to_sublist(["a", "b", "c", "d"], [1,2]) = ["a", ["b", "c"], "d"]
-    """
-    for left, right in zip(idxs[:-1], idxs[1:]):
-        if left + 1 != right:
-            raise ValueError("Indexes to squeeze must be contiguous")
-    if not idxs:
-        return li
-    return li[: min(idxs)] + [[li[idx] for idx in idxs]] + li[max(idxs) + 1 :]
-
-
-def _apply_basic_indexing(key: tuple[StandardIndexer]) -> tuple[list[int], list[int]]:
+    return bcast_nd, KeyIndex(bcast_start_axis)
+
+
+def _rename_broadcast_axes(
+    new_axes: list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]],
+    adv_names: list[str],
+) -> list[tuple[int, str]]:
+    """Normalize sentinel and NoneType names"""
+
+    def _calc_bcast_name(*names: str) -> str:
+        if not names:
+            return ""
+        if all(a == b for a, b in zip(names[1:], names[:-1])):
+            return names[0]
+        names = [name[3:] for name in dict.fromkeys(names)]  # ordered deduplication
+        return "ax_" + "_".join(names)
+
+    bcast_name = _calc_bcast_name(*adv_names)
+    renamed_axes = []
+    for ax_ind, ax_name in new_axes:
+        if ax_name is None:
+            renamed_axes.append((ax_ind, "ax_unk"))
+        elif ax_name is Sentinels.ADV_NAME:
+            renamed_axes.append((ax_ind, bcast_name))
+        else:
+            renamed_axes.append((ax_ind, ax_name))
+    return renamed_axes
+
+
+def replace_adv_indexers(
+    key: Sequence[StandardIndexer],
+    adv_inds: Sequence[int],
+    bcast_start_ax: int,
+    bcast_nd: int,
+) -> list[
+    Union[None, str, int, Literal[Sentinels.ADV_NAME], Literal[Sentinels.ADV_REMOVE]]
+]:
+    """Swap advanced indexers for sentinels that _apply_indexing understands.
+
+    Each advanced indexer becomes ``Sentinels.ADV_REMOVE`` and ``bcast_nd``
+    copies of ``Sentinels.ADV_NAME`` are inserted at ``bcast_start_ax``.
+    Returns a new list; ``key`` itself is left unmodified.
+    """
+    new_key = list(key)
+    for adv_ind in adv_inds:
+        new_key[adv_ind] = Sentinels.ADV_REMOVE
+    bcast_axes = bcast_nd * [Sentinels.ADV_NAME]
+    return new_key[:bcast_start_ax] + bcast_axes + new_key[bcast_start_ax:]
+
+
+def _apply_indexing(
+    key: Sequence[StandardIndexer], reverse_map: dict[int, str]
+) -> tuple[
+    list[int], list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], list[str]
+]:
     """Determine where axes should be removed and added
 
     Only considers the basic indexers in key. Numpy arrays are treated as
@@ -511,18 +450,23 @@ def _apply_basic_indexing(key: tuple[StandardIndexer]) -> tuple[list[int], list[
     """
     remove_axes = []
     new_axes = []
+    adv_names = []
     deleted_to_left = 0
     added_to_left = 0
     for key_ind, indexer in enumerate(key):
-        if isinstance(indexer, int):
+        if isinstance(indexer, int) or indexer is Sentinels.ADV_REMOVE:
             orig_arr_axis = key_ind - added_to_left
+            if indexer is Sentinels.ADV_REMOVE:
+                adv_names.append(reverse_map[orig_arr_axis])
             remove_axes.append(orig_arr_axis)
             deleted_to_left += 1
-        elif indexer is None:
+        elif (
+            indexer is None or indexer is Sentinels.ADV_NAME or isinstance(indexer, str)
+        ):
             new_arr_axis = key_ind - deleted_to_left
-            new_axes.append(new_arr_axis)
+            new_axes.append((new_arr_axis, indexer))
             added_to_left += 1
-    return remove_axes, new_axes
+    return remove_axes, new_axes, adv_names
 
 
 def comprehend_axes(x):
diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py
index 8892837df..ebe703f47 100644
--- a/test/utils/test_axes.py
+++ b/test/utils/test_axes.py
@@ -22,9 +22,6 @@ def test_repr():
     assert result == expected
 
 
-@pytest.mark.skip(
-    "Not until fancy indexing (boolean) either short-circuited or implemented"
-)
 def test_ufunc_override():
     # This is largely a clone of test_ufunc_override_with_super() from
     # numpy/core/tests/test_umath.py
@@ -199,7 +196,7 @@ def test_basic_indexing_modifies_axes():
     assert set(almost_new.ax_unk) == {0, 1, 3, 4}
 
 
-def test_fancy_indexing_modifies_axes():
+def test_adv_indexing_modifies_axes():
     axes = {"ax_time": 0, "ax_coord": 1}
     arr = AxesArray(np.arange(4).reshape((2, 2)), axes)
     flat = arr[[0, 1], [0, 1]]
@@ -208,21 +205,23 @@ def test_fancy_indexing_modifies_axes():
 
     assert flat.shape == (2,)
     np.testing.assert_array_equal(np.asarray(flat), np.array([0, 3]))
-    assert flat.ax__timecoord == 0
+    assert flat.ax_time_coord == 0
     with pytest.raises(AttributeError):
         flat.ax_coord
     with pytest.raises(AttributeError):
         flat.ax_time
 
     assert same.shape == arr.shape
-    np.testing.assert_equal(same, arr)
-    assert same.ax_time == 0
-    assert same.ax_coord == 1
+    np.testing.assert_equal(np.asarray(same), np.asarray(arr))
+    assert same.ax_time_coord == [0, 1]
+    with pytest.raises(AttributeError):
+        same.ax_coord
 
     assert tpose.shape == arr.shape
-    np.testing.assert_equal(same, arr.T)
-    assert same.ax_time == 1
-    assert same.ax_coord == 0
+    np.testing.assert_equal(np.asarray(tpose), np.asarray(arr.T))
+    assert tpose.ax_time_coord == [0, 1]
+    with pytest.raises(AttributeError):
+        tpose.ax_coord
 
     fat = arr[[[0, 1], [0, 1]]]
     assert fat.shape == (2, 2, 2)
@@ -407,18 +406,6 @@ def test_insert_misordered_AxisMapping():
     assert result == expected
 
 
-def test_squeeze_to_sublist():
-    li = ["a", "b", "c", "d"]
-    result = axes._squeeze_to_sublist(li, [1, 2])
-    assert result == ["a", ["b", "c"], "d"]
-
-    result = axes._squeeze_to_sublist(li, [])
-    assert result == li
-
-    with pytest.raises(ValueError, match="Indexes to squeeze"):
-        axes._squeeze_to_sublist(li, [0, 2])
-
-
 def test_determine_adv_broadcasting():
     indexers = (1, np.ones(1), np.ones((4, 1)), np.ones(3))
     res_adj, res_nd, res_start = axes._determine_adv_broadcasting(indexers, [1, 2, 3])