Skip to content

Commit

Permalink
Use T_DataArray in Weighted
Browse files Browse the repository at this point in the history
Allows subtypes.

(I had this in my git stash, so commiting it...)
  • Loading branch information
max-sixty committed Jan 21, 2024
1 parent 35b7ab1 commit b68680f
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from xarray.core.alignment import align, broadcast
from xarray.core.computation import apply_ufunc, dot
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.types import Dims, T_Xarray
from xarray.core.types import Dims, T_DataArray, T_Xarray
from xarray.util.deprecation_helpers import _deprecate_positional_args

# Weighted quantile methods are a subset of the numpy supported quantile methods.
Expand Down Expand Up @@ -145,15 +145,15 @@ class Weighted(Generic[T_Xarray]):

__slots__ = ("obj", "weights")

def __init__(self, obj: T_Xarray, weights: DataArray) -> None:
def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None:
"""
Create a Weighted object
Parameters
----------
obj : DataArray or Dataset
obj : T_DataArray or Dataset
Object over which the weighted reduction operation is applied.
weights : DataArray
weights : T_DataArray
An array of weights associated with the values in the obj.
Each value in the obj contributes to the reduction operation
according to its associated weight.
Expand Down Expand Up @@ -189,7 +189,7 @@ def _weight_check(w):
_weight_check(weights.data)

self.obj: T_Xarray = obj
self.weights: DataArray = weights
self.weights: T_DataArray = weights

def _check_dim(self, dim: Dims):
"""raise an error if any dimension is missing"""
Expand All @@ -208,11 +208,11 @@ def _check_dim(self, dim: Dims):

@staticmethod
def _reduce(
da: DataArray,
weights: DataArray,
da: T_DataArray,
weights: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""reduce using dot; equivalent to (da * weights).sum(dim, skipna)
for internal use only
Expand All @@ -230,7 +230,7 @@ def _reduce(
# DataArray (if `weights` has additional dimensions)
return dot(da, weights, dim=dim)

def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
def _sum_of_weights(self, da: T_DataArray, dim: Dims = None) -> T_DataArray:
"""Calculate the sum of weights, accounting for missing values"""

# we need to mask data values that are nan; else the weights are wrong
Expand All @@ -255,10 +255,10 @@ def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:

def _sum_of_squares(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""

demeaned = da - da.weighted(self.weights).mean(dim=dim)
Expand All @@ -267,20 +267,20 @@ def _sum_of_squares(

def _weighted_sum(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""

return self._reduce(da, self.weights, dim=dim, skipna=skipna)

def _weighted_mean(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""

weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna)
Expand All @@ -291,10 +291,10 @@ def _weighted_mean(

def _weighted_var(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""

sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna)
Expand All @@ -305,21 +305,21 @@ def _weighted_var(

def _weighted_std(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""

return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))
return cast("T_DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))

def _weighted_quantile(
self,
da: DataArray,
da: T_DataArray,
q: ArrayLike,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Apply a weighted ``quantile`` to a DataArray along some dimension(s)."""

def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray:
Expand Down

0 comments on commit b68680f

Please sign in to comment.