Skip to content

Commit

Permalink
BUG(axes): Make concat out param work
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jan 8, 2024
1 parent 638e4bb commit 679b4dc
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,16 @@ def decorator(func):


@implements(np.concatenate)
def concatenate(arrays, axis=0):
def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"):
parents = [np.asarray(obj) for obj in arrays]
ax_list = [obj.__dict__ for obj in arrays if isinstance(obj, AxesArray)]
for ax1, ax2 in zip(ax_list[:-1], ax_list[1:]):
if ax1 != ax2:
raise TypeError("Concatenating >1 AxesArray with incompatible axes")
return AxesArray(np.concatenate(parents, axis), axes=ax_list[0])
result = np.concatenate(parents, axis, out=out, dtype=dtype, casting=casting)
if isinstance(out, AxesArray):
out.__dict__ = ax_list[0]
return AxesArray(result, axes=ax_list[0])


def comprehend_axes(x):
Expand Down
7 changes: 7 additions & 0 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
from pysindy import AxesArray


def test_concat_out():
arr = AxesArray(np.arange(3).reshape(1, 3), {"ax_a": 0, "ax_b": 1})
arr_out = np.empty((2, 3)).view(AxesArray)
result = np.concatenate((arr, arr), axis=0, out=arr_out)
assert_equal(result, arr_out)


def test_reduce_mean_noinf_recursion():
arr = AxesArray(np.array([[1]]), {})
np.mean(arr, axis=0)
Expand Down

0 comments on commit 679b4dc

Please sign in to comment.