From 679b4dcabd46e6c4f5b2d6b3a17335f472382207 Mon Sep 17 00:00:00 2001 From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com> Date: Mon, 8 Jan 2024 20:57:46 +0000 Subject: [PATCH] BUG(axes): Make concat out param work --- pysindy/utils/axes.py | 7 +++++-- test/utils/test_axes.py | 7 +++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index bad10d55c..ad0f79040 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -117,13 +117,16 @@ def decorator(func): @implements(np.concatenate) -def concatenate(arrays, axis=0): +def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"): parents = [np.asarray(obj) for obj in arrays] ax_list = [obj.__dict__ for obj in arrays if isinstance(obj, AxesArray)] for ax1, ax2 in zip(ax_list[:-1], ax_list[1:]): if ax1 != ax2: raise TypeError("Concatenating >1 AxesArray with incompatible axes") - return AxesArray(np.concatenate(parents, axis), axes=ax_list[0]) + result = np.concatenate(parents, axis, out=out, dtype=dtype, casting=casting) + if isinstance(out, AxesArray): + out.__dict__ = ax_list[0] + return AxesArray(result, axes=ax_list[0]) def comprehend_axes(x): diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index b1a38e6f4..e5d9a8385 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -7,6 +7,13 @@ from pysindy import AxesArray +def test_concat_out(): + arr = AxesArray(np.arange(3).reshape(1, 3), {"ax_a": 0, "ax_b": 1}) + arr_out = np.empty((2, 3)).view(AxesArray) + result = np.concatenate((arr, arr), axis=0, out=arr_out) + assert_equal(result, arr_out) + + def test_reduce_mean_noinf_recursion(): arr = AxesArray(np.array([[1]]), {}) np.mean(arr, axis=0)