From 8084fd47d1dffe18b48ffef57cc1c9ff35208df9 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 15 Jan 2024 20:12:15 +0000 Subject: [PATCH] feat(axes): Add tensordot function for AxesArrays dispatches to einsum, which is apparently faster anyways Still need to write tests for linalg_solve, einsum, and tensordot --- pysindy/feature_library/weak_pde_library.py | 10 ++-- pysindy/utils/axes.py | 65 ++++++++++++++------- 2 files changed, 49 insertions(+), 26 deletions(-) diff --git a/pysindy/feature_library/weak_pde_library.py b/pysindy/feature_library/weak_pde_library.py index 02ed2851f..1566a2bca 100644 --- a/pysindy/feature_library/weak_pde_library.py +++ b/pysindy/feature_library/weak_pde_library.py @@ -474,11 +474,11 @@ def _set_up_weights(self): ) weights1 = weights1 + [weights2] - # TODO: get rest of code to work with AxesArray - deaxify = lambda arr_list: [np.asarray(arr) for arr in arr_list] - tweights = deaxify(tweights) - weights0 = deaxify(weights0) - weights1 = deaxify(weights1) + # TODO: get rest of code to work with AxesArray. Too unsure of + # which axis labels to use at this point to continue + tweights = [np.asarray(arr) for arr in tweights] + weights0 = [np.asarray(arr) for arr in weights0] + weights1 = [[np.asarray(arr) for arr in sublist] for sublist in weights1] # Product weights over the axes for time derivatives, shaped as inds_k self.fulltweights = [] deriv = np.zeros(self.grid_ndim) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 334887b9b..0c592a9dc 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -40,9 +40,6 @@ class Sentinels(Enum): ADV_REMOVE = object() -Literal[Sentinels.ADV_NAME] - - class _AxisMapping: """Convenience wrapper for a two-way map between axis names and indexes.""" @@ -479,7 +476,7 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None) @implements(np.einsum) -def _einsum( +def einsum( subscripts: str, *operands: AxesArray, out: Optional[NDArray] = None, **kwargs ) -> AxesArray: calc = np.einsum( @@ -550,7 +547,7 @@ def _join_unique_names(l_of_s: List[str]) -> str: @implements(np.linalg.solve) -def solve(a: AxesArray, b: AxesArray) -> AxesArray: +def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray: result = np.linalg.solve(np.asarray(a), np.asarray(b)) a_rev = a._ax_map.reverse_map contracted_axis_name = a_rev[sorted(a_rev)[-1]] @@ -562,6 +559,34 @@ def solve(a: AxesArray, b: AxesArray) -> AxesArray: return AxesArray(result, axes) +@implements(np.tensordot) +def tensordot( + a: AxesArray, b: AxesArray, axes: Union[int, Sequence[Sequence[int]]] = 2 +) -> AxesArray: + sub = _tensordot_to_einsum(a.ndim, b.ndim, axes) + return einsum(sub, a, b) + + +def _tensordot_to_einsum( + a_ndim: int, b_ndim: int, axes: Union[int, Sequence[Sequence[int]]] +) -> str: + lc_ord = range(97, 123) + if isinstance(axes, int): + if axes > 26: + raise ValueError("Too many axes") + sub_a = f"...{[chr(code) for code in lc_ord[:axes]]}" + sub_b_li = f"{[chr(code) for code in lc_ord[:axes]]}..." + sub = sub_a + sub_b_li + else: + sub_a = f"{[chr(code) for code in lc_ord[:a_ndim]]}" + sub_b_li = [chr(code) for code in lc_ord[a_ndim : a_ndim + b_ndim]] + for a_ind, b_ind in zip(*axes): + sub_b_li[b_ind] - sub_a[a_ind] + sub_b = "".join(sub_b_li) + sub = f"{sub_a},{sub_b}" + return sub + + def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: @@ -748,26 +773,24 @@ def wrap_axes(axes: dict, obj): return obj -T = TypeVar("T") # TODO: Bind to a non-sequence after type-negation PEP +T = TypeVar("T", bound=int) # TODO: Bind to a non-sequence after type-negation PEP +ItemOrList = Union[T, list[T]] +CompatDict = dict[str, ItemOrList[T]] def _compat_dict_append( - compat_dict: dict[str, Union[T, list[T]]], + compat_dict: CompatDict[T], key: str, - item_or_list: Union[T, list[T]], + item_or_list: ItemOrList[T], ) -> None: """Add an element or list of elements to a dictionary, preserving old values""" + try: + prev_val = compat_dict[key] + except KeyError: + compat_dict[key] = item_or_list + return if not isinstance(item_or_list, list): - try: - compat_dict[key].append(item_or_list) - except KeyError: - compat_dict[key] = item_or_list - except AttributeError: - compat_dict[key] = [compat_dict[key], item_or_list] - else: - try: - compat_dict[key] += item_or_list - except KeyError: - compat_dict[key] = item_or_list - except AttributeError: - compat_dict[key] = [compat_dict[key], *item_or_list] + item_or_list = [item_or_list] + if not isinstance(prev_val, list): + prev_val = [prev_val] + compat_dict[key] = prev_val + item_or_list