Skip to content

Commit

Permalink
feat(axes): Add tensordot function for AxesArrays
Browse files Browse the repository at this point in the history
dispatches to einsum, which is apparently faster anyways

Still need to write tests for linalg_solve, einsum, and tensordot
  • Loading branch information
Jacob-Stevens-Haas committed Jan 15, 2024
1 parent f0fc6b3 commit 8084fd4
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 26 deletions.
10 changes: 5 additions & 5 deletions pysindy/feature_library/weak_pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,11 +474,11 @@ def _set_up_weights(self):
)
weights1 = weights1 + [weights2]

# TODO: get rest of code to work with AxesArray
deaxify = lambda arr_list: [np.asarray(arr) for arr in arr_list]
tweights = deaxify(tweights)
weights0 = deaxify(weights0)
weights1 = deaxify(weights1)
# TODO: get rest of code to work with AxesArray. Too unsure of
# which axis labels to use at this point to continue
tweights = [np.asarray(arr) for arr in tweights]
weights0 = [np.asarray(arr) for arr in weights0]
weights1 = [[np.asarray(arr) for arr in sublist] for sublist in weights1]
# Product weights over the axes for time derivatives, shaped as inds_k
self.fulltweights = []
deriv = np.zeros(self.grid_ndim)
Expand Down
65 changes: 44 additions & 21 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ class Sentinels(Enum):
ADV_REMOVE = object()


Literal[Sentinels.ADV_NAME]


class _AxisMapping:
"""Convenience wrapper for a two-way map between axis names and indexes."""

Expand Down Expand Up @@ -479,7 +476,7 @@ def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None)


@implements(np.einsum)
def _einsum(
def einsum(
subscripts: str, *operands: AxesArray, out: Optional[NDArray] = None, **kwargs
) -> AxesArray:
calc = np.einsum(
Expand Down Expand Up @@ -550,7 +547,7 @@ def _join_unique_names(l_of_s: List[str]) -> str:


@implements(np.linalg.solve)
def solve(a: AxesArray, b: AxesArray) -> AxesArray:
def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray:
result = np.linalg.solve(np.asarray(a), np.asarray(b))
a_rev = a._ax_map.reverse_map
contracted_axis_name = a_rev[sorted(a_rev)[-1]]
Expand All @@ -562,6 +559,34 @@ def solve(a: AxesArray, b: AxesArray) -> AxesArray:
return AxesArray(result, axes)


@implements(np.tensordot)
def tensordot(
a: AxesArray, b: AxesArray, axes: Union[int, Sequence[Sequence[int]]] = 2
) -> AxesArray:
sub = _tensordot_to_einsum(a.ndim, b.ndim, axes)
return einsum(sub, a, b)


def _tensordot_to_einsum(
a_ndim: int, b_ndim: int, axes: Union[int, Sequence[Sequence[int]]]
) -> str:
lc_ord = range(97, 123)
if isinstance(axes, int):
if axes > 26:
raise ValueError("Too many axes")
sub_a = f"...{[chr(code) for code in lc_ord[:axes]]}"
sub_b_li = f"{[chr(code) for code in lc_ord[:axes]]}..."
sub = sub_a + sub_b_li
else:
sub_a = f"{[chr(code) for code in lc_ord[:a_ndim]]}"
sub_b_li = [chr(code) for code in lc_ord[a_ndim : a_ndim + b_ndim]]
for a_ind, b_ind in zip(*axes):
sub_b_li[b_ind] - sub_a[a_ind]
sub_b = "".join(sub_b_li)
sub = f"{sub_a},{sub_b}"
return sub


def standardize_indexer(
arr: np.ndarray, key: Indexer | Sequence[Indexer]
) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]:
Expand Down Expand Up @@ -748,26 +773,24 @@ def wrap_axes(axes: dict, obj):
return obj


T = TypeVar("T") # TODO: Bind to a non-sequence after type-negation PEP
T = TypeVar("T", bound=int) # TODO: Bind to a non-sequence after type-negation PEP
ItemOrList = Union[T, list[T]]
CompatDict = dict[str, ItemOrList[T]]


def _compat_dict_append(
compat_dict: dict[str, Union[T, list[T]]],
compat_dict: CompatDict[T],
key: str,
item_or_list: Union[T, list[T]],
item_or_list: ItemOrList[T],
) -> None:
"""Add an element or list of elements to a dictionary, preserving old values"""
try:
prev_val = compat_dict[key]
except KeyError:
compat_dict[key] = item_or_list
return
if not isinstance(item_or_list, list):
try:
compat_dict[key].append(item_or_list)
except KeyError:
compat_dict[key] = item_or_list
except AttributeError:
compat_dict[key] = [compat_dict[key], item_or_list]
else:
try:
compat_dict[key] += item_or_list
except KeyError:
compat_dict[key] = item_or_list
except AttributeError:
compat_dict[key] = [compat_dict[key], *item_or_list]
item_or_list = [item_or_list]
if not isinstance(prev_val, list):
prev_val = [prev_val]
compat_dict[key] = prev_val + item_or_list

0 comments on commit 8084fd4

Please sign in to comment.