From 42088a172fe91d0d607d14edfded5812736b844a Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Fri, 5 Jan 2024 17:34:15 +0000 Subject: [PATCH] ENH: Enable boolean advanced indexing in AxesArray Modify the standardize_indexer() function Parameterize StandardIndexer --- pysindy/utils/axes.py | 59 ++++++++++++++++++++++++----------------- test/utils/test_axes.py | 17 ++++++------ 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 1e803b1ba..2235f4a53 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -17,8 +17,8 @@ AxesWarning = type("AxesWarning", (SyntaxWarning,), {}) BasicIndexer = Union[slice, int, type(Ellipsis), type(None)] -Indexer = BasicIndexer | NDArray -StandardIndexer = Union[slice, int, type(None), NDArray] +Indexer = BasicIndexer | NDArray | list +StandardIndexer = Union[slice, int, type(None), NDArray[np.dtype(int)]] OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent KeyIndex = NewType("KeyIndex", int) NewIndex = NewType("NewIndex", int) @@ -338,12 +338,13 @@ def concatenate(arrays, axis=0): def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] -) -> tuple[list[StandardIndexer], tuple[KeyIndex]]: +) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: """Convert any legal numpy indexer to a "standard" form. Standard form involves creating an equivalent indexer that is a tuple with one element per index of the original axis. All advanced indexer elements - are converted to numpy arrays + are converted to numpy arrays, and boolean arrays are converted to + integer arrays with obj.nonzero(). Returns: A tuple of the normalized indexer as well as the indexes of advanced indexers @@ -351,36 +352,46 @@ def standardize_indexer( if isinstance(key, tuple): key = list(key) else: - key = [ - key, - ] + key = [key] + if not any(ax_key is Ellipsis for ax_key in key): key = [*key, Ellipsis] - _expand_indexer_ellipsis(key, arr.ndim) - new_key: list[Indexer] = [] - adv_inds: list[int] = [] - for indexer_ind, ax_key in enumerate(key): + for ax_key in key: if not isinstance(ax_key, BasicIndexer): ax_key = np.array(ax_key) - adv_inds.append(indexer_ind) + if ax_key.dtype == np.dtype(np.bool_): + new_key += ax_key.nonzero() + continue new_key.append(ax_key) - return new_key, tuple(adv_inds) + new_key = _expand_indexer_ellipsis(new_key, arr.ndim) + # Can't identify position of advanced indexers before expanding ellipses + adv_inds: list[KeyIndex] = [] + for key_ind, ax_key in enumerate(new_key): + if isinstance(ax_key, np.ndarray): + adv_inds.append(KeyIndex(key_ind)) -def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None: - """Replace ellipsis in indexers with the appropriate amount of slice(None) + return new_key, tuple(adv_inds) - Mutates indexers - """ - try: - ellind = indexers.index(Ellipsis) - except ValueError: - return - n_new_dims = sum(k is None for k in indexers) - n_ellipsis_dims = ndim - (len(indexers) - n_new_dims - 1) - indexers[ellind : ellind + 1] = n_ellipsis_dims * (slice(None),) + +def _expand_indexer_ellipsis(key: list[Indexer], ndim: int) -> list[Indexer]: + """Replace ellipsis in indexers with the appropriate amount of slice(None)""" + # [...].index errors if list contains numpy array + ellind = [ind for ind, val in enumerate(key) if val is ...][0] + new_key = [] + n_new_dims = sum(ax_key is None for ax_key in key) + n_ellipsis_dims = ndim - (len(key) - n_new_dims - 1) + new_key = ( + key[:ellind] + + n_ellipsis_dims + * [ + slice(None), + ] + + key[ellind + 1 + n_ellipsis_dims :] + ) + return new_key def _determine_adv_broadcasting( diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index f33b94750..e396fad05 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -98,15 +98,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): c = np.add.reduce(a, 1, None, b) assert_equal(c, check) assert_(c is b) - - -@pytest.mark.skip("Expected error") -def test_ufunc_override_accumulate(): - d = np.array([[1, 2, 3], [1, 2, 3]]) - a = AxesArray(d, {"ax_time": [0, 1]}) check = np.add.accumulate(d, axis=0) c = np.add.accumulate(a, axis=0) - assert_equal(c, check) + # assert_equal(c, check) b = np.zeros_like(c) c = np.add.accumulate(a, 0, None, b) assert_equal(c, check) @@ -238,7 +232,7 @@ def test_standardize_basic_indexer(): assert result_fancy == () -def test_standardize_fancy_indexer(): +def test_standardize_advanced_indexer(): arr = np.arange(6).reshape(2, 3) result_indexer, result_fancy = axes.standardize_indexer(arr, [1]) assert result_indexer == [np.ones(1), slice(None)] @@ -251,6 +245,13 @@ def test_standardize_fancy_indexer(): assert result_fancy == (1,) +def test_standardize_bool_indexer(): + arr = np.ones((1, 2)) + result, result_adv = axes.standardize_indexer(arr, [[True, True]]) + assert_equal(result, [[0, 0], [0, 1]]) + assert result_adv == (0, 1) + + def test_reduce_AxisMapping(): ax_map = _AxisMapping( {