Skip to content

Commit

Permalink
ENH: Enable boolean advanced indexing in AxesArray
Browse files Browse the repository at this point in the history
Modify the standardize_indexer() function
Parameterize StandardIndexer
  • Loading branch information
Jacob-Stevens-Haas committed Jan 5, 2024
1 parent 49ac5a0 commit 42088a1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 32 deletions.
59 changes: 35 additions & 24 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

AxesWarning = type("AxesWarning", (SyntaxWarning,), {})
BasicIndexer = Union[slice, int, type(Ellipsis), type(None)]
Indexer = BasicIndexer | NDArray
StandardIndexer = Union[slice, int, type(None), NDArray]
Indexer = BasicIndexer | NDArray | list
StandardIndexer = Union[slice, int, type(None), NDArray[np.dtype(int)]]
OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent
KeyIndex = NewType("KeyIndex", int)
NewIndex = NewType("NewIndex", int)
Expand Down Expand Up @@ -338,49 +338,60 @@ def concatenate(arrays, axis=0):

def standardize_indexer(
arr: np.ndarray, key: Indexer | Sequence[Indexer]
) -> tuple[list[StandardIndexer], tuple[KeyIndex]]:
) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]:
"""Convert any legal numpy indexer to a "standard" form.
Standard form involves creating an equivalent indexer that is a tuple with
one element per index of the original axis. All advanced indexer elements
are converted to numpy arrays
are converted to numpy arrays, and boolean arrays are converted to
integer arrays with obj.nonzero().
Returns:
A tuple of the normalized indexer as well as the indexes of
advanced indexers
"""
if isinstance(key, tuple):
key = list(key)
else:
key = [
key,
]
key = [key]

if not any(ax_key is Ellipsis for ax_key in key):
key = [*key, Ellipsis]

_expand_indexer_ellipsis(key, arr.ndim)

new_key: list[Indexer] = []
adv_inds: list[int] = []
for indexer_ind, ax_key in enumerate(key):
for ax_key in key:
if not isinstance(ax_key, BasicIndexer):
ax_key = np.array(ax_key)
adv_inds.append(indexer_ind)
if ax_key.dtype == np.dtype(np.bool_):
new_key += ax_key.nonzero()
continue
new_key.append(ax_key)
return new_key, tuple(adv_inds)

new_key = _expand_indexer_ellipsis(new_key, arr.ndim)
# Can't identify position of advanced indexers before expanding ellipses
adv_inds: list[KeyIndex] = []
for key_ind, ax_key in enumerate(new_key):
if isinstance(ax_key, np.ndarray):
adv_inds.append(KeyIndex(key_ind))

def _expand_indexer_ellipsis(indexers: list[Indexer], ndim: int) -> None:
"""Replace ellipsis in indexers with the appropriate amount of slice(None)
return new_key, tuple(adv_inds)

Mutates indexers
"""
try:
ellind = indexers.index(Ellipsis)
except ValueError:
return
n_new_dims = sum(k is None for k in indexers)
n_ellipsis_dims = ndim - (len(indexers) - n_new_dims - 1)
indexers[ellind : ellind + 1] = n_ellipsis_dims * (slice(None),)

def _expand_indexer_ellipsis(key: list[Indexer], ndim: int) -> list[Indexer]:
"""Replace ellipsis in indexers with the appropriate amount of slice(None)"""
# [...].index errors if list contains numpy array
ellind = [ind for ind, val in enumerate(key) if val is ...][0]
new_key = []
n_new_dims = sum(ax_key is None for ax_key in key)
n_ellipsis_dims = ndim - (len(key) - n_new_dims - 1)
new_key = (
key[:ellind]
+ n_ellipsis_dims
* [
slice(None),
]
+ key[ellind + 1 + n_ellipsis_dims :]
)
return new_key


def _determine_adv_broadcasting(
Expand Down
17 changes: 9 additions & 8 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
c = np.add.reduce(a, 1, None, b)
assert_equal(c, check)
assert_(c is b)


@pytest.mark.skip("Expected error")
def test_ufunc_override_accumulate():
d = np.array([[1, 2, 3], [1, 2, 3]])
a = AxesArray(d, {"ax_time": [0, 1]})
check = np.add.accumulate(d, axis=0)
c = np.add.accumulate(a, axis=0)
assert_equal(c, check)
# assert_equal(c, check)
b = np.zeros_like(c)
c = np.add.accumulate(a, 0, None, b)
assert_equal(c, check)
Expand Down Expand Up @@ -238,7 +232,7 @@ def test_standardize_basic_indexer():
assert result_fancy == ()


def test_standardize_fancy_indexer():
def test_standardize_advanced_indexer():
arr = np.arange(6).reshape(2, 3)
result_indexer, result_fancy = axes.standardize_indexer(arr, [1])
assert result_indexer == [np.ones(1), slice(None)]
Expand All @@ -251,6 +245,13 @@ def test_standardize_fancy_indexer():
assert result_fancy == (1,)


def test_standardize_bool_indexer():
arr = np.ones((1, 2))
result, result_adv = axes.standardize_indexer(arr, [[True, True]])
assert_equal(result, [[0, 0], [0, 1]])
assert result_adv == (0, 1)


def test_reduce_AxisMapping():
ax_map = _AxisMapping(
{
Expand Down

0 comments on commit 42088a1

Please sign in to comment.