Skip to content

Commit

Permalink
Merge pull request astropy#16663 from neutrinoceros/units/utils/compa…
Browse files Browse the repository at this point in the history
…t/np_cumulative_funcs

BUG: declare `np.cumulative_prod` and `np.cumulative_sum` as subclass-safe and test them (fix incompatibility with NumPy 2.1)
  • Loading branch information
pllim authored Jul 3, 2024
2 parents 9e2a087 + db0aaf6 commit 70e701b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
2 changes: 1 addition & 1 deletion astropy/units/quantity_helper/function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
# trapz was renamed to trapezoid
SUBCLASS_SAFE_FUNCTIONS |= {np.trapezoid}
if not NUMPY_LT_2_1:
SUBCLASS_SAFE_FUNCTIONS |= {np.unstack}
SUBCLASS_SAFE_FUNCTIONS |= {np.unstack, np.cumulative_prod, np.cumulative_sum}

# Implemented as methods on Quantity:
# np.ediff1d is from setops, but we support it anyway; the others
Expand Down
9 changes: 9 additions & 0 deletions astropy/units/tests/test_quantity_non_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,10 @@ def test_sum(self):
def test_cumsum(self):
self.check(np.cumsum)

@pytest.mark.skipif(NUMPY_LT_2_1, reason="np.cumulative_sum is new in NumPy 2.1")
def test_cumulative_sum(self):
self.check(np.cumulative_sum, axis=1)

def test_any(self):
with pytest.raises(TypeError):
np.any(self.q)
Expand Down Expand Up @@ -713,6 +717,11 @@ def test_cumproduct(self):
with pytest.raises(u.UnitsError):
np.cumproduct(self.q) # noqa: NPY003, NPY201

@pytest.mark.skipif(NUMPY_LT_2_1, reason="np.cumulative_prod is new in NumPy 2.1")
def test_cumulative_prod(self):
with pytest.raises(u.UnitsError):
np.cumulative_prod(self.q, axis=1)


class TestUfuncLike(InvariantUnitTestSetup):
def test_ptp(self):
Expand Down
25 changes: 25 additions & 0 deletions astropy/utils/masked/tests/test_function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,31 @@ def test_array_equiv(self):
assert np.array_equiv(self.mb, np.stack([self.mb, self.mb]))


class TestArrayAPI:
@classmethod
def setup_class(self):
self.a = np.tile(np.arange(5.0), 2).reshape(2, 5)
self.mask_a = np.array([[False] * 5, [True] * 4 + [False]])
self.ma = Masked(self.a, mask=self.mask_a)

def check(self, func, *args, **kwargs):
out = func(self.ma, *args, **kwargs)
expected = func(self.a, *args, **kwargs)
assert type(out) is MaskedNDArray
assert out.dtype.kind == "f"
assert_array_equal(out.unmasked, expected)
assert_array_equal(out.mask, self.mask_a)
assert not np.may_share_memory(out.mask, self.mask_a)

@pytest.mark.skipif(NUMPY_LT_2_1, reason="np.cumulative_prod is new in NumPy 2.1")
def test_cumulative_prod(self):
self.check(np.cumulative_prod, axis=0)

@pytest.mark.skipif(NUMPY_LT_2_1, reason="np.cumulative_sum is new in NumPy 2.1")
def test_cumulative_sum(self):
self.check(np.cumulative_sum, axis=0)


class TestOuterLikeFunctions(MaskedArraySetup):
def test_outer(self):
result = np.outer(self.ma, self.mb)
Expand Down

0 comments on commit 70e701b

Please sign in to comment.