diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index d90c6d959..4224cf550 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -1,6 +1,41 @@ """ A module that defines one external class, AxesArray, to act like a numpy array -but keep track of axis definitions. +but keep track of axis definitions. It aims to allow meaningful replacement +of magic numbers for axis conventions in code. E.g:: + + import numpy as np + + arr = AxesArray(np.ones((2,3,4)), {"ax_time": 0, "ax_spatial": [1, 2]}) + print(arr.axes) + print(arr.ax_time) + print(arr.n_time) + print(arr.ax_spatial) + print(arr.n_spatial) + +Would show:: + + {"ax_time": 0, "ax_spatial": [1, 2]} + 0 + 2 + [1, 2] + [3, 4] + +It is up to the user to handle the ``list[int] | int`` return values, but this +module has several functions to deal with the axes dictionary, internally +referred to as type ``CompatDict[T]``: + +Appending an item to a ``CompatDict[T]`` + :py:func:`compat_dict_append` + +Generating a ``CompatDict[int]`` of axes from list of axes names: + :py:func:`fwd_from_names` + +Create new ``CompatDict[int]`` from this ``AxesArray`` with new axis/axes added: + :py:meth:`AxesArray.insert_axis` + +Create new ``CompatDict[int]`` from this ``AxesArray`` with axis/axes removed: + :py:meth:`AxesArray.remove_axis` + .. todo:: @@ -50,7 +85,7 @@ CompatDict = Dict[str, ItemOrList[T]] -class Sentinels(Enum): +class _Sentinels(Enum): ADV_NAME = object() ADV_REMOVE = object() @@ -95,13 +130,6 @@ def coerce_sequence(obj): AxesWarning, ) - @staticmethod - def fwd_from_names(names: List[str]) -> Dict[str, Sequence[int]]: - fwd_map: Dict[str, Sequence[int]] = {} - for ax_ind, name in enumerate(names): - _compat_dict_append(fwd_map, name, [ax_ind]) - return fwd_map - @staticmethod def _compat_axes(in_dict: Dict[str, List[int]]) -> Dict[str, Union[list[int], int]]: """Like fwd_map, but unpack single-element axis lists""" @@ -222,9 +250,9 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): n_: lookup shape of subarray defined by ax_name Raises: - * AxesWarning if axes does not match shape of input_array. - * ValueError if assigning the same axis index to multiple meanings or - assigning an axis beyond ndim. + AxesWarning if axes does not match shape of input_array. + ValueError if assigning the same axis index to multiple meanings or + assigning an axis beyond ndim. """ @@ -310,7 +338,7 @@ def __array_wrap__(self, out_arr, context=None): return super().__array_wrap__(self, out_arr, context) def __array_finalize__(self, obj) -> None: - if obj is None: # explicit construction via super().__new__().. not called? + if obj is None: # explicit construction via super().__new__() return # view from numpy array, called in constructor but also tests if all( @@ -408,7 +436,7 @@ def decorator(func): def ix_(*args: AxesArray): calc = np.ix_(*(np.asarray(arg) for arg in args)) ax_names = [list(arr.axes)[0] for arr in args] - axes = _AxisMapping.fwd_from_names(ax_names) + axes = fwd_from_names(ax_names) return tuple(AxesArray(arr, axes) for arr in calc) @@ -461,7 +489,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): ) base_name = a._ax_map.reverse_map[curr_base] if a.shape[curr_base] == newshape[curr_new]: - _compat_dict_append(new_axes, base_name, curr_new) + compat_dict_append(new_axes, base_name, curr_new) curr_base += 1 elif newshape[curr_new] == 1: raise ValueError( @@ -486,7 +514,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): ) curr_base += 1 - _compat_dict_append(new_axes, base_name, curr_new) + compat_dict_append(new_axes, base_name, curr_new) return AxesArray(out, axes=new_axes) @@ -505,7 +533,7 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None) new_axes = {} old_reverse = a._ax_map.reverse_map for new_ind, old_ind in enumerate(axes): - _compat_dict_append(new_axes, old_reverse[old_ind], new_ind) + compat_dict_append(new_axes, old_reverse[old_ind], new_ind) return AxesArray(out, new_axes) @@ -545,7 +573,7 @@ def einsum( ax_name = "ax_" + _join_unique_names(ax_names) out_names.append(ax_name) - out_axes = _AxisMapping.fwd_from_names(out_names) + out_axes = fwd_from_names(out_names) if isinstance(out, AxesArray): out._ax_map = _AxisMapping(out_axes, calc.ndim) return AxesArray(calc, axes=out_axes) @@ -586,7 +614,7 @@ def _label_einsum_scripts( scr_name = op._ax_map.reverse_map[ax_ind] else: scr_name = op._ax_map.reverse_map[ax_ind - 3 + ell_width] - _compat_dict_append(script_names, char, [scr_name]) + compat_dict_append(script_names, char, [scr_name]) return allscript_names @@ -607,7 +635,7 @@ def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: all_names = ( b_names[: align.stop - 1] + [contracted_axis_name] + b_names[align.stop :] ) - axes = _AxisMapping.fwd_from_names(all_names) + axes = fwd_from_names(all_names) return AxesArray(result, axes) @@ -700,7 +728,7 @@ def _determine_adv_broadcasting( def _rename_broadcast_axes( - new_axes: List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], + new_axes: List[tuple[int, None | str | Literal[_Sentinels.ADV_NAME]]], adv_names: List[str], ) -> List[tuple[int, str]]: """Normalize sentinel and NoneType names""" @@ -718,7 +746,7 @@ def _calc_bcast_name(*names: str) -> str: for ax_ind, ax_name in new_axes: if ax_name is None: renamed_axes.append((ax_ind, "ax_unk")) - elif ax_name is Sentinels.ADV_NAME: + elif ax_name is _Sentinels.ADV_NAME: renamed_axes.append((ax_ind, bcast_name)) else: renamed_axes.append((ax_ind, "ax_" + ax_name)) @@ -731,19 +759,19 @@ def _replace_adv_indexers( bcast_start_ax: int, bcast_nd: int, ) -> tuple[ - Union[None, str, int, Literal[Sentinels.ADV_NAME], Literal[Sentinels.ADV_REMOVE]], + Union[None, str, int, Literal[_Sentinels.ADV_NAME], Literal[_Sentinels.ADV_REMOVE]], ..., ]: for adv_ind in adv_inds: - key[adv_ind] = Sentinels.ADV_REMOVE - key = key[:bcast_start_ax] + bcast_nd * [Sentinels.ADV_NAME] + key[bcast_start_ax:] + key[adv_ind] = _Sentinels.ADV_REMOVE + key = key[:bcast_start_ax] + bcast_nd * [_Sentinels.ADV_NAME] + key[bcast_start_ax:] return key def _apply_indexing( key: tuple[StandardIndexer], reverse_map: Dict[int, str] ) -> tuple[ - List[int], List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], List[str] + List[int], List[tuple[int, None | str | Literal[_Sentinels.ADV_NAME]]], List[str] ]: """Determine where axes should be removed and added @@ -756,14 +784,16 @@ def _apply_indexing( deleted_to_left = 0 added_to_left = 0 for key_ind, indexer in enumerate(key): - if isinstance(indexer, int) or indexer is Sentinels.ADV_REMOVE: + if isinstance(indexer, int) or indexer is _Sentinels.ADV_REMOVE: orig_arr_axis = key_ind - added_to_left - if indexer is Sentinels.ADV_REMOVE: + if indexer is _Sentinels.ADV_REMOVE: adv_names.append(reverse_map[orig_arr_axis]) remove_axes.append(orig_arr_axis) deleted_to_left += 1 elif ( - indexer is None or indexer is Sentinels.ADV_NAME or isinstance(indexer, str) + indexer is None + or indexer is _Sentinels.ADV_NAME + or isinstance(indexer, str) ): new_arr_axis = key_ind - deleted_to_left new_axes.append((new_arr_axis, indexer)) @@ -823,7 +853,7 @@ def wrap_axes(axes: dict, obj): return obj -def _compat_dict_append( +def compat_dict_append( compat_dict: CompatDict[T], key: str, item_or_list: ItemOrList[T], @@ -839,3 +869,11 @@ def _compat_dict_append( if not isinstance(prev_val, list): prev_val = [prev_val] compat_dict[key] = prev_val + item_or_list + + +def fwd_from_names(names: List[str]) -> CompatDict[int]: + """Create mapping of name: axis or name: [ax_1, ax_2, ...]""" + fwd_map: Dict[str, Sequence[int]] = {} + for ax_ind, name in enumerate(names): + compat_dict_append(fwd_map, name, [ax_ind]) + return fwd_map