diff --git a/pyproject.toml b/pyproject.toml index 60028b62a..65fc0b263 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ docs = [ "ipython", "pandoc", "sphinx-rtd-theme", - "sphinx==5.3.0", + "sphinx==7.1.2", "sphinxcontrib-apidoc", "nbsphinx" ] diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index fa51e4863..d90c6d959 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -2,7 +2,9 @@ A module that defines one external class, AxesArray, to act like a numpy array but keep track of axis definitions. -TODO: Add developer documentation here. +.. todo:: + + Add developer documentation here. The recommended way to refactor existing code to use AxesArrays is to add them at the lowest level possible. Enter debug mode and see how long the expected @@ -43,6 +45,9 @@ OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent KeyIndex = NewType("KeyIndex", int) NewIndex = NewType("NewIndex", int) +T = TypeVar("T", bound=int) # TODO: Bind to a non-sequence after type-negation PEP +ItemOrList = Union[T, List[T]] +CompatDict = Dict[str, ItemOrList[T]] class Sentinels(Enum): @@ -178,27 +183,31 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): """A numpy-like array that keeps track of the meaning of its axes. Limitations: - * Not all numpy functions, such as ``np.flatten()``, does not have an - implementation for AxesArray, a regular numpy array is returned. - * For functions that are implemented for `AxesArray`, such as - ``np.reshape()``, use the numpy function rather than the bound - method (e.g. arr.reshape) - * Such such functions may raise ValueErrors where numpy would not, when - it is impossible to determine the output axis labels. - Bound methods, such as arr.reshape, are not implemented. Use the functions. - While the functions in the numpy namespace will work on ``AxesArray`` - objects, the documentation must be found in their equivalent names here. + * Not all numpy functions, such as ``np.flatten()``, have an + implementation for ``AxesArray``. In such cases a regular numpy array + is returned. + * For functions that are implemented for `AxesArray`, such as + ``np.reshape()``, use the numpy function rather than the bound + method (e.g. ``arr.reshape``) + * Such functions may raise ``ValueError`` where numpy would not, when + it is impossible to determine the output axis labels. Current array function implementations: + * ``np.concatenate`` * ``np.reshape`` * ``np.transpose`` + * ``np.linalg.solve`` + * ``np.einsum`` + * ``np.tensordot`` Indexing: AxesArray supports all of the basic and advanced indexing of numpy arrays, with the addition that new axes can be inserted with a string - name for the axis. If ``None`` or ``np.newaxis`` are passed, the + name for the axis. E.g. ``arr = arr[..., "lineno"]`` will add a + length-one axis at the end, along with the properties ``arr.ax_lineno`` + and ``arr.n_lineno``. If ``None`` or ``np.newaxis`` are passed, the axis is named "unk". Parameters: @@ -215,7 +224,7 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): Raises: * AxesWarning if axes does not match shape of input_array. * ValueError if assigning the same axis index to multiple meanings or - assigning an axis beyond ndim. + assigning an axis beyond ndim. """ @@ -239,6 +248,7 @@ def _reverse_map(self): @property def shape(self): + """Shape of array. Unlike numpy ndarray, this is not assignable.""" return super().shape def insert_axis( @@ -279,10 +289,10 @@ def __getitem__(self, key: Union[Indexer, Sequence[Indexer]], /): if not isinstance(output, AxesArray): return output # return an element from the array in_dim = self.shape - key, adv_inds = standardize_indexer(self, key) + key, adv_inds = _standardize_indexer(self, key) bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds) if adv_inds: - key = replace_adv_indexers(key, adv_inds, bcast_start_ax, bcast_nd) + key = _replace_adv_indexers(key, adv_inds, bcast_start_ax, bcast_nd) remove_axes, new_axes, adv_names = _apply_indexing(key, self._reverse_map) new_axes = _rename_broadcast_axes(new_axes, adv_names) new_map = _AxisMapping( @@ -384,8 +394,8 @@ def __array_function__(self, func, types, args, kwargs): return HANDLED_FUNCTIONS[func](*args, **kwargs) -def implements(numpy_function): - """Register an __array_function__ implementation for MyArray objects.""" +def _implements(numpy_function): + """Register an __array_function__ implementation for AxesArray objects.""" def decorator(func): HANDLED_FUNCTIONS[numpy_function] = func @@ -394,7 +404,7 @@ def decorator(func): return decorator -@implements(np.ix_) +@_implements(np.ix_) def ix_(*args: AxesArray): calc = np.ix_(*(np.asarray(arg) for arg in args)) ax_names = [list(arr.axes)[0] for arr in args] @@ -402,7 +412,7 @@ def ix_(*args: AxesArray): return tuple(AxesArray(arr, axes) for arr in calc) -@implements(np.concatenate) +@_implements(np.concatenate) def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"): parents = [np.asarray(obj) for obj in arrays] ax_list = [obj.axes for obj in arrays if isinstance(obj, AxesArray)] @@ -415,7 +425,7 @@ def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"): return AxesArray(result, axes=ax_list[0]) -@implements(np.reshape) +@_implements(np.reshape) def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): """Gives a new shape to an array without changing its data. @@ -481,7 +491,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): return AxesArray(out, axes=new_axes) -@implements(np.transpose) +@_implements(np.transpose) def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None): """Returns an array with axes transposed. @@ -500,7 +510,7 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None) return AxesArray(out, new_axes) -@implements(np.einsum) +@_implements(np.einsum) def einsum( subscripts: str, *operands: AxesArray, out: Optional[NDArray] = None, **kwargs ) -> AxesArray: @@ -580,7 +590,7 @@ def _label_einsum_scripts( return allscript_names -@implements(np.linalg.solve) +@_implements(np.linalg.solve) def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: result = np.linalg.solve(np.asarray(a), np.asarray(b)) a_rev = a._ax_map.reverse_map @@ -601,7 +611,7 @@ def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: return AxesArray(result, axes) -@implements(np.tensordot) +@_implements(np.tensordot) def tensordot( a: AxesArray, b: AxesArray, axes: Union[int, Sequence[Sequence[int]]] = 2 ) -> AxesArray: @@ -626,7 +636,7 @@ def _tensordot_to_einsum( return sub -def standardize_indexer( +def _standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: """Convert any legal numpy indexer to a "standard" form. @@ -635,6 +645,7 @@ def standardize_indexer( one element per index of the original axis. All advanced indexer elements are converted to numpy arrays, and boolean arrays are converted to integer arrays with obj.nonzero(). + Returns: A tuple of the normalized indexer as well as the indexes of advanced indexers @@ -714,7 +725,7 @@ def _calc_bcast_name(*names: str) -> str: return renamed_axes -def replace_adv_indexers( +def _replace_adv_indexers( key: Sequence[StandardIndexer], adv_inds: List[int], bcast_start_ax: int, @@ -812,11 +823,6 @@ def wrap_axes(axes: dict, obj): return obj -T = TypeVar("T", bound=int) # TODO: Bind to a non-sequence after type-negation PEP -ItemOrList = Union[T, List[T]] -CompatDict = Dict[str, ItemOrList[T]] - - def _compat_dict_append( compat_dict: CompatDict[T], key: str, diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index b4f8fb3d4..c7327f240 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -292,11 +292,11 @@ def test_adv_indexing_adds_axes(): def test_standardize_basic_indexer(): arr = np.arange(6).reshape(2, 3) - result_indexer, result_fancy = axes.standardize_indexer(arr, Ellipsis) + result_indexer, result_fancy = axes._standardize_indexer(arr, Ellipsis) assert result_indexer == [slice(None), slice(None)] assert result_fancy == () - result_indexer, result_fancy = axes.standardize_indexer( + result_indexer, result_fancy = axes._standardize_indexer( arr, (np.newaxis, 1, 1, Ellipsis) ) assert result_indexer == [None, 1, 1] @@ -305,11 +305,11 @@ def test_standardize_basic_indexer(): def test_standardize_advanced_indexer(): arr = np.arange(6).reshape(2, 3) - result_indexer, result_fancy = axes.standardize_indexer(arr, [1]) + result_indexer, result_fancy = axes._standardize_indexer(arr, [1]) assert result_indexer == [np.ones(1), slice(None)] assert result_fancy == (0,) - result_indexer, result_fancy = axes.standardize_indexer( + result_indexer, result_fancy = axes._standardize_indexer( arr, (np.newaxis, [1], 1, Ellipsis) ) assert result_indexer == [None, np.ones(1), 1] @@ -318,7 +318,7 @@ def test_standardize_advanced_indexer(): def test_standardize_bool_indexer(): arr = np.ones((1, 2)) - result, result_adv = axes.standardize_indexer(arr, [[True, True]]) + result, result_adv = axes._standardize_indexer(arr, [[True, True]]) assert_equal(result, [[0, 0], [0, 1]]) assert result_adv == (0, 1)