From b68680fc5b54b241f2a1f5184a7f50918ad65847 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sat, 20 Jan 2024 17:18:10 -0800 Subject: [PATCH] Use `T_DataArray` in `Weighted` Allows subtypes. (I had this in my git stash, so commiting it...) --- xarray/core/weighted.py | 44 ++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 53ff6db5f28..2a969aab1bd 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -10,7 +10,7 @@ from xarray.core.alignment import align, broadcast from xarray.core.computation import apply_ufunc, dot from xarray.core.pycompat import is_duck_dask_array -from xarray.core.types import Dims, T_Xarray +from xarray.core.types import Dims, T_DataArray, T_Xarray from xarray.util.deprecation_helpers import _deprecate_positional_args # Weighted quantile methods are a subset of the numpy supported quantile methods. @@ -145,15 +145,15 @@ class Weighted(Generic[T_Xarray]): __slots__ = ("obj", "weights") - def __init__(self, obj: T_Xarray, weights: DataArray) -> None: + def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None: """ Create a Weighted object Parameters ---------- - obj : DataArray or Dataset + obj : T_DataArray or Dataset Object over which the weighted reduction operation is applied. - weights : DataArray + weights : T_DataArray An array of weights associated with the values in the obj. Each value in the obj contributes to the reduction operation according to its associated weight. @@ -189,7 +189,7 @@ def _weight_check(w): _weight_check(weights.data) self.obj: T_Xarray = obj - self.weights: DataArray = weights + self.weights: T_DataArray = weights def _check_dim(self, dim: Dims): """raise an error if any dimension is missing""" @@ -208,11 +208,11 @@ def _check_dim(self, dim: Dims): @staticmethod def _reduce( - da: DataArray, - weights: DataArray, + da: T_DataArray, + weights: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """reduce using dot; equivalent to (da * weights).sum(dim, skipna) for internal use only @@ -230,7 +230,7 @@ def _reduce( # DataArray (if `weights` has additional dimensions) return dot(da, weights, dim=dim) - def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray: + def _sum_of_weights(self, da: T_DataArray, dim: Dims = None) -> T_DataArray: """Calculate the sum of weights, accounting for missing values""" # we need to mask data values that are nan; else the weights are wrong @@ -255,10 +255,10 @@ def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray: def _sum_of_squares( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s).""" demeaned = da - da.weighted(self.weights).mean(dim=dim) @@ -267,20 +267,20 @@ def _sum_of_squares( def _weighted_sum( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``sum`` along some dimension(s).""" return self._reduce(da, self.weights, dim=dim, skipna=skipna) def _weighted_mean( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``mean`` along some dimension(s).""" weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna) @@ -291,10 +291,10 @@ def _weighted_mean( def _weighted_var( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``var`` along some dimension(s).""" sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna) @@ -305,21 +305,21 @@ def _weighted_var( def _weighted_std( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``std`` along some dimension(s).""" - return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna))) + return cast("T_DataArray", np.sqrt(self._weighted_var(da, dim, skipna))) def _weighted_quantile( self, - da: DataArray, + da: T_DataArray, q: ArrayLike, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Apply a weighted ``quantile`` to a DataArray along some dimension(s).""" def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray: