diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index ed27957aa..25c42db8f 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -346,11 +346,23 @@ def __array_finalize__(self, obj) -> None: elif all( ( isinstance(obj, AxesArray), + hasattr(obj, "_ax_map"), not hasattr(self, "_ax_map"), self.shape == obj.shape, ) ): self._ax_map = _AxisMapping(obj.axes, obj.ndim) + # Using a poorly-initialized AxesArray + # Occurs in MaskedArray.ravel, used in some plotting. MaskedArray views + # of AxesArray lose the axes attributes, and then the _ax_map attributes. + # See numpy.ma.core:asanyarray + elif all( + ( + isinstance(obj, AxesArray), + not hasattr(obj, "_ax_map"), + ) + ): + self._ax_map = _AxisMapping({"ax_unk": 0}, in_ndim=1) # maybe add errors for incompatible views? def __array_ufunc__( @@ -418,6 +430,16 @@ def decorator(func): return decorator +@_implements(np.ravel) +def ravel(a, order="C"): + out = np.ravel(np.asarray(a), order=order) + is_1d_already = len(a.shape) == 1 + if is_1d_already: + return AxesArray(out, a.axes) + else: + return AxesArray(out, {"ax_unk": 0}) + + @_implements(np.ix_) def ix_(*args: AxesArray): calc = np.ix_(*(np.asarray(arg) for arg in args)) diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index b26a73890..6e85cd38b 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -629,6 +629,27 @@ def test_tensordot_list_axes(): assert_array_equal(result, super_result) +def test_ravel_1d(): + arr = AxesArray(np.array([1, 2]), axes={"ax_a": 0}) + result = np.ravel(arr) + assert_array_equal(result, arr) + assert result.axes == arr.axes + + +def test_ravel_nd(): + arr = AxesArray(np.array([[1, 2], [3, 4]]), axes={"ax_a": 0, "ax_b": 1}) + result = np.ravel(arr) + expected = np.ravel(np.asarray(arr)) + assert_array_equal(result, expected) + assert result.axes == {"ax_unk": 0} + + +def test_ma_ravel(): + arr = AxesArray(np.array([1, 2]), axes={"ax_a": 0}) + marr = np.ma.MaskedArray(arr) + np.ma.ravel(marr) + + @pytest.mark.skip def test_einsum_implicit(): ...