Skip to content

Commit

Permalink
feat/doc(axes): Make helpers public so docs pick them up
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 17, 2024
1 parent 2c56053 commit 3ede6d0
Showing 1 changed file with 68 additions and 30 deletions.
98 changes: 68 additions & 30 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,41 @@
"""
A module that defines one external class, AxesArray, to act like a numpy array
but keep track of axis definitions.
but keep track of axis definitions. It aims to allow meaningful replacement
of magic numbers for axis conventions in code. E.g::
import numpy as np
arr = AxesArray(np.ones((2,3,4)), {"ax_time": 0, "ax_spatial": [1, 2]})
print(arr.axes)
print(arr.ax_time)
print(arr.n_time)
print(arr.ax_spatial)
print(arr.n_spatial)
Would show::
{"ax_time": 0, "ax_spatial": [1, 2]}
0
2
[1, 2]
[3, 4]
It is up to the user to handle the ``list[int] | int`` return values, but this
module has several functions to deal with the axes dictionary, internally
referred to as type ``CompatDict[T]``:
Appending an item to a ``CompatDict[T]``
:py:func:`compat_dict_append`
Generating a ``CompatDict[int]`` of axes from list of axes names:
:py:func:`fwd_from_names`
Create new ``CompatDict[int]`` from this ``AxesArray`` with new axis/axes added:
:py:meth:`AxesArray.insert_axis`
Create new ``CompatDict[int]`` from this ``AxesArray`` with axis/axes removed:
:py:meth:`AxesArray.remove_axis`
.. todo::
Expand Down Expand Up @@ -50,7 +85,7 @@
CompatDict = Dict[str, ItemOrList[T]]


class Sentinels(Enum):
class _Sentinels(Enum):
ADV_NAME = object()
ADV_REMOVE = object()

Expand Down Expand Up @@ -95,13 +130,6 @@ def coerce_sequence(obj):
AxesWarning,
)

@staticmethod
def fwd_from_names(names: List[str]) -> Dict[str, Sequence[int]]:
fwd_map: Dict[str, Sequence[int]] = {}
for ax_ind, name in enumerate(names):
_compat_dict_append(fwd_map, name, [ax_ind])
return fwd_map

@staticmethod
def _compat_axes(in_dict: Dict[str, List[int]]) -> Dict[str, Union[list[int], int]]:
"""Like fwd_map, but unpack single-element axis lists"""
Expand Down Expand Up @@ -222,9 +250,9 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray):
n_<ax_name>: lookup shape of subarray defined by ax_name
Raises:
* AxesWarning if axes does not match shape of input_array.
* ValueError if assigning the same axis index to multiple meanings or
assigning an axis beyond ndim.
AxesWarning if axes does not match shape of input_array.
ValueError if assigning the same axis index to multiple meanings or
assigning an axis beyond ndim.
"""

Expand Down Expand Up @@ -310,7 +338,7 @@ def __array_wrap__(self, out_arr, context=None):
return super().__array_wrap__(self, out_arr, context)

Check warning on line 338 in pysindy/utils/axes.py

View check run for this annotation

Codecov / codecov/patch

pysindy/utils/axes.py#L338

Added line #L338 was not covered by tests

def __array_finalize__(self, obj) -> None:
if obj is None: # explicit construction via super().__new__().. not called?
if obj is None: # explicit construction via super().__new__()
return

Check warning on line 342 in pysindy/utils/axes.py

View check run for this annotation

Codecov / codecov/patch

pysindy/utils/axes.py#L342

Added line #L342 was not covered by tests
# view from numpy array, called in constructor but also tests
if all(
Expand Down Expand Up @@ -408,7 +436,7 @@ def decorator(func):
def ix_(*args: AxesArray):
calc = np.ix_(*(np.asarray(arg) for arg in args))
ax_names = [list(arr.axes)[0] for arr in args]
axes = _AxisMapping.fwd_from_names(ax_names)
axes = fwd_from_names(ax_names)
return tuple(AxesArray(arr, axes) for arr in calc)


Expand Down Expand Up @@ -461,7 +489,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"):
)
base_name = a._ax_map.reverse_map[curr_base]
if a.shape[curr_base] == newshape[curr_new]:
_compat_dict_append(new_axes, base_name, curr_new)
compat_dict_append(new_axes, base_name, curr_new)
curr_base += 1
elif newshape[curr_new] == 1:
raise ValueError(
Expand All @@ -486,7 +514,7 @@ def reshape(a: AxesArray, newshape: int | tuple[int], order="C"):
)
curr_base += 1

_compat_dict_append(new_axes, base_name, curr_new)
compat_dict_append(new_axes, base_name, curr_new)

return AxesArray(out, axes=new_axes)

Expand All @@ -505,7 +533,7 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None)
new_axes = {}
old_reverse = a._ax_map.reverse_map
for new_ind, old_ind in enumerate(axes):
_compat_dict_append(new_axes, old_reverse[old_ind], new_ind)
compat_dict_append(new_axes, old_reverse[old_ind], new_ind)

return AxesArray(out, new_axes)

Expand Down Expand Up @@ -545,7 +573,7 @@ def einsum(
ax_name = "ax_" + _join_unique_names(ax_names)
out_names.append(ax_name)

out_axes = _AxisMapping.fwd_from_names(out_names)
out_axes = fwd_from_names(out_names)
if isinstance(out, AxesArray):
out._ax_map = _AxisMapping(out_axes, calc.ndim)

Check warning on line 578 in pysindy/utils/axes.py

View check run for this annotation

Codecov / codecov/patch

pysindy/utils/axes.py#L578

Added line #L578 was not covered by tests
return AxesArray(calc, axes=out_axes)
Expand Down Expand Up @@ -586,7 +614,7 @@ def _label_einsum_scripts(
scr_name = op._ax_map.reverse_map[ax_ind]
else:
scr_name = op._ax_map.reverse_map[ax_ind - 3 + ell_width]

Check warning on line 616 in pysindy/utils/axes.py

View check run for this annotation

Codecov / codecov/patch

pysindy/utils/axes.py#L616

Added line #L616 was not covered by tests
_compat_dict_append(script_names, char, [scr_name])
compat_dict_append(script_names, char, [scr_name])
return allscript_names


Expand All @@ -607,7 +635,7 @@ def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray:
all_names = (
b_names[: align.stop - 1] + [contracted_axis_name] + b_names[align.stop :]
)
axes = _AxisMapping.fwd_from_names(all_names)
axes = fwd_from_names(all_names)
return AxesArray(result, axes)


Expand Down Expand Up @@ -700,7 +728,7 @@ def _determine_adv_broadcasting(


def _rename_broadcast_axes(
new_axes: List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]],
new_axes: List[tuple[int, None | str | Literal[_Sentinels.ADV_NAME]]],
adv_names: List[str],
) -> List[tuple[int, str]]:
"""Normalize sentinel and NoneType names"""
Expand All @@ -718,7 +746,7 @@ def _calc_bcast_name(*names: str) -> str:
for ax_ind, ax_name in new_axes:
if ax_name is None:
renamed_axes.append((ax_ind, "ax_unk"))
elif ax_name is Sentinels.ADV_NAME:
elif ax_name is _Sentinels.ADV_NAME:
renamed_axes.append((ax_ind, bcast_name))
else:
renamed_axes.append((ax_ind, "ax_" + ax_name))
Expand All @@ -731,19 +759,19 @@ def _replace_adv_indexers(
bcast_start_ax: int,
bcast_nd: int,
) -> tuple[
Union[None, str, int, Literal[Sentinels.ADV_NAME], Literal[Sentinels.ADV_REMOVE]],
Union[None, str, int, Literal[_Sentinels.ADV_NAME], Literal[_Sentinels.ADV_REMOVE]],
...,
]:
for adv_ind in adv_inds:
key[adv_ind] = Sentinels.ADV_REMOVE
key = key[:bcast_start_ax] + bcast_nd * [Sentinels.ADV_NAME] + key[bcast_start_ax:]
key[adv_ind] = _Sentinels.ADV_REMOVE
key = key[:bcast_start_ax] + bcast_nd * [_Sentinels.ADV_NAME] + key[bcast_start_ax:]
return key


def _apply_indexing(
key: tuple[StandardIndexer], reverse_map: Dict[int, str]
) -> tuple[
List[int], List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], List[str]
List[int], List[tuple[int, None | str | Literal[_Sentinels.ADV_NAME]]], List[str]
]:
"""Determine where axes should be removed and added
Expand All @@ -756,14 +784,16 @@ def _apply_indexing(
deleted_to_left = 0
added_to_left = 0
for key_ind, indexer in enumerate(key):
if isinstance(indexer, int) or indexer is Sentinels.ADV_REMOVE:
if isinstance(indexer, int) or indexer is _Sentinels.ADV_REMOVE:
orig_arr_axis = key_ind - added_to_left
if indexer is Sentinels.ADV_REMOVE:
if indexer is _Sentinels.ADV_REMOVE:
adv_names.append(reverse_map[orig_arr_axis])
remove_axes.append(orig_arr_axis)
deleted_to_left += 1
elif (
indexer is None or indexer is Sentinels.ADV_NAME or isinstance(indexer, str)
indexer is None
or indexer is _Sentinels.ADV_NAME
or isinstance(indexer, str)
):
new_arr_axis = key_ind - deleted_to_left
new_axes.append((new_arr_axis, indexer))
Expand Down Expand Up @@ -823,7 +853,7 @@ def wrap_axes(axes: dict, obj):
return obj


def _compat_dict_append(
def compat_dict_append(
compat_dict: CompatDict[T],
key: str,
item_or_list: ItemOrList[T],
Expand All @@ -839,3 +869,11 @@ def _compat_dict_append(
if not isinstance(prev_val, list):
prev_val = [prev_val]
compat_dict[key] = prev_val + item_or_list


def fwd_from_names(names: List[str]) -> CompatDict[int]:
"""Create mapping of name: axis or name: [ax_1, ax_2, ...]"""
fwd_map: Dict[str, Sequence[int]] = {}
for ax_ind, name in enumerate(names):
compat_dict_append(fwd_map, name, [ax_ind])
return fwd_map

0 comments on commit 3ede6d0

Please sign in to comment.