Skip to content

Commit

Permalink
feat(axes): Make np.transpose work on AxesArray
Browse files Browse the repository at this point in the history
Finally a simple one
  • Loading branch information
Jacob-Stevens-Haas committed Jan 13, 2024
1 parent f13d593 commit 996d555
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
25 changes: 25 additions & 0 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import NewType
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np
Expand Down Expand Up @@ -167,6 +168,11 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray):
While the functions in the numpy namespace will work on ``AxesArray``
objects, the documentation must be found in their equivalent names here.
Current array function implementations:
* ``np.concatenate``
* ``np.reshape``
* ``np.transpose``
Parameters:
input_array: the data to create the array.
axes: A dictionary of axis labels to shape indices. Axes labels must
Expand Down Expand Up @@ -421,6 +427,25 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"):
return AxesArray(out, axes=new_axes)


@implements(np.transpose)
def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None):
"""Returns an array with axes transposed.
Args:
a: input array
axes: As the numpy function
"""
out = np.transpose(np.asarray(a), axes)
if axes is None:
axes = range(a.ndim)[::-1]
new_axes = {}
old_reverse = a._ax_map.reverse_map
for new_ind, old_ind in enumerate(axes):
_compat_axes_append(new_axes, old_reverse[old_ind], new_ind)

return AxesArray(out, new_axes)


def standardize_indexer(
arr: np.ndarray, key: Indexer | Sequence[Indexer]
) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]:
Expand Down
14 changes: 14 additions & 0 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,3 +491,17 @@ def test_strip_ellipsis():
result = axes._expand_indexer_ellipsis(key, 1)
expected = [1]
assert result == expected


def test_transpose():
axes = {"ax_a": 0, "ax_b": [1, 2]}
arr = AxesArray(np.arange(8).reshape(2, 2, 2), axes)
tp = np.transpose(arr, [2, 0, 1])
result = tp.axes
expected = {"ax_a": 1, "ax_b": [0, 2]}
assert result == expected
assert_array_equal(tp, np.transpose(np.asarray(arr), [2, 0, 1]))
arr = arr[..., 0]
tp = arr.T
expected = {"ax_a": 1, "ax_b": 0}
assert_array_equal(tp, np.asarray(arr).T)

0 comments on commit 996d555

Please sign in to comment.