From b3cd8eb975e850ba600cf4b233423ff2c0df9cd4 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 2 Oct 2023 15:17:35 +0100 Subject: [PATCH] Duck array ops for `all` and `any` --- xarray/core/duck_array_ops.py | 18 +++++++++++++----- xarray/core/weighted.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 7e7333fd8ea..bbf358ff48b 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -16,8 +16,6 @@ import numpy as np import pandas as pd -from numpy import all as array_all # noqa: F401 -from numpy import any as array_any # noqa: F401 from numpy import ( # noqa: F401 isclose, isnat, @@ -319,7 +317,7 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): if lazy_equiv is None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") - return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + return bool(array_all(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True))) else: return lazy_equiv @@ -333,7 +331,7 @@ def array_equiv(arr1, arr2): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "In the future, 'NAT == x'") flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2)) - return bool(flag_array.all()) + return bool(array_all(flag_array)) else: return lazy_equiv @@ -349,7 +347,7 @@ def array_notnull_equiv(arr1, arr2): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "In the future, 'NAT == x'") flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2) - return bool(flag_array.all()) + return bool(array_all(flag_array)) else: return lazy_equiv @@ -536,6 +534,16 @@ def f(values, axis=None, skipna=None, **kwargs): cumsum_1d.numeric_only = True +def array_all(array, axis=None, keepdims=False): + xp = get_array_namespace(array) + return xp.all(array, axis=axis, keepdims=keepdims) + + +def array_any(array, axis=None, keepdims=False): + xp = get_array_namespace(array) + return xp.any(array, axis=axis, keepdims=keepdims) + + _mean = _create_nan_agg_method("mean", invariant_0d=True) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 269cb49a2c1..cd24091b18e 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -171,7 +171,7 @@ def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None: def _weight_check(w): # Ref https://github.com/pydata/xarray/pull/4559/files#r515968670 - if duck_array_ops.isnull(w).any(): + if duck_array_ops.array_any(duck_array_ops.isnull(w)): raise ValueError( "`weights` cannot contain missing values. " "Missing values can be replaced by `weights.fillna(0)`."