Add min_weight param to rolling_exp functions #8285

Merged
merged 9 commits into from
Oct 14, 2023
Changes from 2 commits
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -26,6 +26,10 @@ New Features
   the ``other`` parameter, passing the object as the only argument. Previously,
   this was only valid for the ``cond`` parameter. (:issue:`8255`)
   By `Maximilian Roos <https://github.com/max-sixty>`_.
+- ``.rolling_exp`` functions can now take a ``min_weight`` parameter, to only
+  output values when there are sufficient recent non-nan values.
+  ``numbagg>=0.3.1`` is required. (:pull:`8285`)
+  By `Maximilian Roos <https://github.com/max-sixty>`_.
 - :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for
   the ``variables`` parameter, passing the object as the only argument.
   By `Maximilian Roos <https://github.com/max-sixty>`_.
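
To make the ``min_weight`` entry above concrete, here is a minimal pure-Python sketch of an exponentially weighted, NaN-skipping mean with a weight threshold. It is an illustration only, not numbagg's implementation: the name `ewm_nanmean` is hypothetical, and the weight bookkeeping (each non-NaN value adds `alpha`, and the accumulated weight decays by `1 - alpha` per step) is an assumption about how "sufficient recent non-nan values" could be measured.

```python
import math


def ewm_nanmean(values, alpha, min_weight=0.0):
    """Exponentially weighted mean that skips NaNs (hypothetical helper).

    A result is emitted only while the decayed weight of the observed
    (non-NaN) values is at least ``min_weight``; otherwise NaN is emitted.
    """
    decay = 1.0 - alpha
    numer = denom = weight = 0.0
    out = []
    for x in values:
        # Older observations lose influence at every step.
        numer *= decay
        denom *= decay
        weight *= decay
        if not math.isnan(x):
            numer += x
            denom += 1.0
            weight += alpha  # assumption: each valid value contributes alpha
        if denom > 0.0 and weight >= min_weight:
            out.append(numer / denom)
        else:
            out.append(math.nan)
    return out


data = [1.0, math.nan, math.nan, math.nan, 5.0]
no_threshold = ewm_nanmean(data, alpha=0.5)
with_threshold = ewm_nanmean(data, alpha=0.5, min_weight=0.2)
```

With the default `min_weight=0.0`, the stale value keeps being carried forward through the run of NaNs; with a positive threshold, the output turns to NaN once the remaining weight drops below it and recovers when a new value arrives.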
40 changes: 16 additions & 24 deletions xarray/core/rolling_exp.py
@@ -8,8 +8,15 @@
 from xarray.core.computation import apply_ufunc
 from xarray.core.options import _get_keep_attrs
 from xarray.core.pdcompat import count_not_none
-from xarray.core.pycompat import is_duck_dask_array
-from xarray.core.types import T_DataWithCoords, T_DuckArray
+from xarray.core.types import T_DataWithCoords

+try:
+    import numbagg
+    from numbagg import move_exp_nanmean, move_exp_nansum

+    has_numbagg = numbagg.__version__
+except ImportError:
+    has_numbagg = False
Comment on lines +13 to +19
Contributor

Is this a lazy import? Do our tests catch it?

Could you add numbagg here so we don't regress with regard to #6726:

https://github.com/pydata/xarray/blob/e8be4bbb961f58ba733852c998f2863f3ff644b1/xarray/tests/test_plugins.py#L216C9-L239

You can also use this for checking module availability:

def module_available(module: str) -> bool:
    """Checks whether a module is installed without importing it.
    Use this for a lightweight check and lazy imports.

Collaborator Author

It's not a lazy import per se (it used to be, but I'm not sure whether that was deliberate or coincidental).

Using module_available doesn't seem to get the version?

What's the standard way of getting an optional module and checking the version in xarray? I'm happy to abide by the standard if there is one. The current approach does seem to work well in isolation...
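
For context on the question above, one way to get both availability and a version string without importing the module is to combine `importlib.util.find_spec` with `importlib.metadata.version`. This is a stdlib-only sketch, not xarray's convention: the helper name `installed_version` is hypothetical, and it assumes the distribution name matches the import name (which may not hold for every package).

```python
from __future__ import annotations

from importlib import metadata
from importlib.util import find_spec


def installed_version(module: str) -> str | None:
    """Return the installed version string, or None if unavailable.

    Hypothetical helper: assumes a top-level module whose distribution
    name is the same as its import name.
    """
    # find_spec locates the module without executing it.
    if find_spec(module) is None:
        return None
    try:
        # Reads package metadata only; the module itself is not imported.
        return metadata.version(module)
    except metadata.PackageNotFoundError:
        return None


missing = installed_version("surely_not_an_installed_module_xyz")
```

The returned value is still a plain string, so it would need to be parsed before being compared against a minimum version.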



 def _get_alpha(
@@ -25,26 +32,6 @@ def _get_alpha(
     return 1 / (1 + com)


-def move_exp_nanmean(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
-    if is_duck_dask_array(array):
-        raise TypeError("rolling_exp is not currently support for dask-like arrays")
-    import numbagg

-    # No longer needed in numbag > 0.2.0; remove in time
-    if axis == ():
-        return array.astype(np.float64)
-    else:
-        return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha)


-def move_exp_nansum(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
-    if is_duck_dask_array(array):
-        raise TypeError("rolling_exp is not currently supported for dask-like arrays")
-    import numbagg

-    return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha)


 def _get_center_of_mass(
     comass: float | None,
     span: float | None,
@@ -110,11 +97,16 @@ def __init__(
         obj: T_DataWithCoords,
         windows: Mapping[Any, int | float],
         window_type: str = "span",
+        min_weight: float = 0.0,
     ):
+        if has_numbagg is False or has_numbagg < "0.3.1":
+            raise ImportError("numbagg >= 0.3.1 is required for rolling_exp")
Contributor

You can use

def mod_version(mod: ModType) -> Version:
    """Quick wrapper to get the version of the module."""
    return _get_cached_duck_array_module(mod).version

To get the version.

Collaborator Author

Though this isn't a duck array module — numba / numbagg operates on numpy arrays. Would I still use this?
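
One caveat worth noting about the check in `__init__`: `has_numbagg < "0.3.1"` compares strings lexicographically, so a hypothetical later release such as `"0.10.0"` would sort below `"0.3.1"`. Comparing parsed versions, as the `Version` object returned by `mod_version` would allow, avoids this. The sketch below uses a hypothetical `version_tuple` helper that handles only plain `X.Y.Z` strings; `packaging.version.Version` is the more robust choice when pre-release or local suffixes are possible.

```python
from __future__ import annotations


def version_tuple(version: str) -> tuple[int, ...]:
    """Parse a plain 'X.Y.Z' release string into a tuple of ints.

    Hypothetical helper: suffixes such as 'rc1' or '.dev0' are not handled.
    """
    return tuple(int(part) for part in version.split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """Numeric, component-wise comparison of two release strings."""
    return version_tuple(installed) >= version_tuple(minimum)


# String comparison misorders multi-digit components ...
string_says_too_old = "0.10.0" < "0.3.1"
# ... while the parsed comparison orders them numerically.
parsed_says_ok = meets_minimum("0.10.0", "0.3.1")
```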


         self.obj: T_DataWithCoords = obj
         dim, window = next(iter(windows.items()))
         self.dim = dim
         self.alpha = _get_alpha(**{window_type: window})
+        self.min_weight = min_weight

     def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
         """
@@ -145,7 +137,7 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
             move_exp_nanmean,
             self.obj,
             input_core_dims=[[self.dim]],
-            kwargs=dict(alpha=self.alpha, axis=-1),
+            kwargs=dict(alpha=self.alpha, min_weight=self.min_weight, axis=-1),
             output_core_dims=[[self.dim]],
             exclude_dims={self.dim},
             keep_attrs=keep_attrs,
@@ -181,7 +173,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
             move_exp_nansum,
             self.obj,
             input_core_dims=[[self.dim]],
-            kwargs=dict(alpha=self.alpha, axis=-1),
+            kwargs=dict(alpha=self.alpha, min_weight=self.min_weight, axis=-1),
             output_core_dims=[[self.dim]],
             exclude_dims={self.dim},
             keep_attrs=keep_attrs,
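
The `alpha` forwarded through `kwargs` comes from `_get_alpha`, which the diff shows ending in `1 / (1 + com)`. As a rough sketch of how a window could be mapped to `alpha`, the function below assumes the pandas-style center-of-mass conversions; the name `alpha_from_window` and the exact set of accepted `window_type` values are assumptions, not taken from the diff.

```python
import math


def alpha_from_window(window: float, window_type: str = "span") -> float:
    """Convert a window specification to a smoothing factor (hypothetical).

    Assumes pandas-style definitions of the center of mass ``com``.
    """
    if window_type == "span":
        com = (window - 1) / 2
    elif window_type == "com":
        com = window
    elif window_type == "halflife":
        com = 1 / (math.exp(math.log(2) / window) - 1)
    elif window_type == "alpha":
        com = (1 - window) / window
    else:
        raise ValueError(f"unknown window_type: {window_type!r}")
    # Same final step as the `_get_alpha` line shown in the diff.
    return 1 / (1 + com)


span_alpha = alpha_from_window(3, "span")
halflife_alpha = alpha_from_window(1, "halflife")
```

Under these assumptions, a span of 3 and a half-life of 1 both correspond to a smoothing factor of one half.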