-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add min_weight
param to rolling_exp
functions
#8285
Changes from 2 commits
18f4690
fe7c58e
4f54090
ab102c9
8a8e56d
44287fd
7302880
d9843f8
b463968
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -8,8 +8,15 @@ | |||||||
from xarray.core.computation import apply_ufunc | ||||||||
from xarray.core.options import _get_keep_attrs | ||||||||
from xarray.core.pdcompat import count_not_none | ||||||||
from xarray.core.pycompat import is_duck_dask_array | ||||||||
from xarray.core.types import T_DataWithCoords, T_DuckArray | ||||||||
from xarray.core.types import T_DataWithCoords | ||||||||
|
||||||||
try: | ||||||||
import numbagg | ||||||||
from numbagg import move_exp_nanmean, move_exp_nansum | ||||||||
|
||||||||
has_numbagg = numbagg.__version__ | ||||||||
except ImportError: | ||||||||
has_numbagg = False | ||||||||
|
||||||||
|
||||||||
def _get_alpha( | ||||||||
|
@@ -25,26 +32,6 @@ def _get_alpha( | |||||||
return 1 / (1 + com) | ||||||||
|
||||||||
|
||||||||
def move_exp_nanmean(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray: | ||||||||
if is_duck_dask_array(array): | ||||||||
raise TypeError("rolling_exp is not currently support for dask-like arrays") | ||||||||
import numbagg | ||||||||
|
||||||||
# No longer needed in numbag > 0.2.0; remove in time | ||||||||
if axis == (): | ||||||||
return array.astype(np.float64) | ||||||||
else: | ||||||||
return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha) | ||||||||
|
||||||||
|
||||||||
def move_exp_nansum(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray: | ||||||||
if is_duck_dask_array(array): | ||||||||
raise TypeError("rolling_exp is not currently supported for dask-like arrays") | ||||||||
import numbagg | ||||||||
|
||||||||
return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha) | ||||||||
|
||||||||
|
||||||||
def _get_center_of_mass( | ||||||||
comass: float | None, | ||||||||
span: float | None, | ||||||||
|
@@ -110,11 +97,16 @@ def __init__( | |||||||
obj: T_DataWithCoords, | ||||||||
windows: Mapping[Any, int | float], | ||||||||
window_type: str = "span", | ||||||||
min_weight: float = 0.0, | ||||||||
): | ||||||||
if has_numbagg is False or has_numbagg < "0.3.1": | ||||||||
raise ImportError("numbagg >= 0.3.1 is required for rolling_exp") | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can use xarray/xarray/core/pycompat.py Lines 81 to 83 in e8be4bb
To get the version. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Though this isn't a duck array module — numba / numbagg operates on numpy arrays. Would I still use this? |
||||||||
|
||||||||
self.obj: T_DataWithCoords = obj | ||||||||
dim, window = next(iter(windows.items())) | ||||||||
self.dim = dim | ||||||||
self.alpha = _get_alpha(**{window_type: window}) | ||||||||
self.min_weight = min_weight | ||||||||
|
||||||||
def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: | ||||||||
""" | ||||||||
|
@@ -145,7 +137,7 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: | |||||||
move_exp_nanmean, | ||||||||
self.obj, | ||||||||
input_core_dims=[[self.dim]], | ||||||||
kwargs=dict(alpha=self.alpha, axis=-1), | ||||||||
kwargs=dict(alpha=self.alpha, min_weight=self.min_weight, axis=-1), | ||||||||
output_core_dims=[[self.dim]], | ||||||||
exclude_dims={self.dim}, | ||||||||
keep_attrs=keep_attrs, | ||||||||
|
@@ -181,7 +173,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: | |||||||
move_exp_nansum, | ||||||||
self.obj, | ||||||||
input_core_dims=[[self.dim]], | ||||||||
kwargs=dict(alpha=self.alpha, axis=-1), | ||||||||
kwargs=dict(alpha=self.alpha, min_weight=self.min_weight, axis=-1), | ||||||||
output_core_dims=[[self.dim]], | ||||||||
exclude_dims={self.dim}, | ||||||||
keep_attrs=keep_attrs, | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a lazy import? Does our tests catch it?
Could you add numbagg here so we don't regress with regards to #6726:
https://github.com/pydata/xarray/blob/e8be4bbb961f58ba733852c998f2863f3ff644b1/xarray/tests/test_plugins.py#L216C9-L239
You can also use this for checking module availability:
xarray/xarray/namedarray/utils.py
Lines 81 to 84 in e8be4bb
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not a lazy import per se (it used to be, but I'm not sure that was necessarily deliberate vs. coincidental)
Using
module_available
doesn't seem to get the version?What's the standard way of getting an optional module and checking the version in xarray? I'm happy to abide by the standard if there is one. The current approach does seem to work well in isolation...