Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix data flags var name generation #1507

Merged
merged 2 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Bug fixes
* Fixed ``xclim.indices.run_length.lazy_indexing`` which would sometimes trigger the loading of auxiliary coordinates. (:issue:`1483`, :pull:`1484`).
* Indicators ``snd_season_length`` and ``snw_season_length`` will return 0 instead of NaN if all inputs have a (non-NaN) zero snow depth (or water-equivalent thickness). (:pull:`1492`, :issue:`1491`)
* Fixed a bug in the `pytest` configuration that could prevent testing data caching from occurring in systems where the platform-dependent cache directory is not found in the user's home. (:issue:`1468`, :pull:`1473`).
* Fix ``xclim.core.dataflags.data_flags`` variable name generation (:pull:`1507`).

Breaking changes
^^^^^^^^^^^^^^^^
Expand Down
17 changes: 17 additions & 0 deletions tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,20 @@ def test_era5_ecad_qc_flag(self, open_dataset):

df_flagged = df.ecad_compliant(bad_ds)
np.testing.assert_array_equal(df_flagged.ecad_qc_flag, False)

def test_names(self, pr_series):
pr = pr_series(np.zeros(365), start="1971-01-01")
flgs = df.data_flags(
pr,
flags={
"values_op_thresh_repeating_for_n_or_more_days": {
"op": "==",
"n": 5,
"thresh": "-5.1 mm d-1",
}
},
)
assert (
list(flgs.data_vars.keys())[0]
== "values_eq_minus5point1_repeating_for_5_or_more_days"
)
118 changes: 69 additions & 49 deletions xclim/core/dataflags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from typing import Sequence

import numpy as np
import pint
import xarray

from ..indices.generic import binary_ops
from ..indices.run_length import suspicious_run
from .calendar import climatological_mean_doy, within_bnds_doy
from .formatting import update_xclim_history
Expand All @@ -23,6 +23,7 @@
VARIABLES,
InputKind,
MissingVariableError,
Quantified,
infer_kind_from_parameter,
raise_warn_or_log,
)
Expand All @@ -41,6 +42,8 @@ class DataQualityException(Exception):
Message prepended to the error messages.
"""

flag_array: xarray.Dataset = None

def __init__(
self,
flag_array: xarray.Dataset,
Expand Down Expand Up @@ -81,10 +84,20 @@ def __str__(self):
]


def register_methods(func):
"""Summarize all methods used in dataflags checks."""
_REGISTRY[func.__name__] = func
return func
def register_methods(variable_name=None):
"""Register a data flag functioné.

Argument can be the output variable name template. The template may use any of the stringable input arguments.
If not given, the function name is used instead, which may create variable conflicts.
"""

def _register_methods(func):
"""Summarize all methods used in dataflags checks."""
func.__dict__["variable_name"] = variable_name or func.__name__
_REGISTRY[func.__name__] = func
return func

return _register_methods


def _sanitize_attrs(da: xarray.DataArray) -> xarray.DataArray:
Expand All @@ -97,7 +110,7 @@ def _sanitize_attrs(da: xarray.DataArray) -> xarray.DataArray:
return da


@register_methods
@register_methods()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are parentheses needed here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a decorator level to the function: @register_methods() actually returns the real decorator, using the default value (None, thus func.__name__) as variable name.

Copy link
Collaborator

@Zeitsperre Zeitsperre Oct 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand, but I'll give it a look on Monday with fresher eyes.

@update_xclim_history
@declare_units(tasmax="[temperature]", tasmin="[temperature]")
def tasmax_below_tasmin(
Expand Down Expand Up @@ -130,7 +143,7 @@ def tasmax_below_tasmin(
return tasmax_lt_tasmin


@register_methods
@register_methods()
@update_xclim_history
@declare_units(tas="[temperature]", tasmax="[temperature]")
def tas_exceeds_tasmax(
Expand Down Expand Up @@ -163,7 +176,7 @@ def tas_exceeds_tasmax(
return tas_gt_tasmax


@register_methods
@register_methods()
@update_xclim_history
@declare_units(tas="[temperature]", tasmin="[temperature]")
def tas_below_tasmin(
Expand Down Expand Up @@ -195,11 +208,11 @@ def tas_below_tasmin(
return tas_lt_tasmin


@register_methods
@register_methods()
@update_xclim_history
@declare_units(da="[temperature]")
@declare_units(da="[temperature]", thresh="[temperature]")
def temperature_extremely_low(
da: xarray.DataArray, *, thresh: str = "-90 degC"
da: xarray.DataArray, *, thresh: Quantified = "-90 degC"
) -> xarray.DataArray:
"""Check if temperatures values are below -90 degrees Celsius for any given day.

Expand Down Expand Up @@ -229,11 +242,11 @@ def temperature_extremely_low(
return extreme_low


@register_methods
@register_methods()
@update_xclim_history
@declare_units(da="[temperature]")
@declare_units(da="[temperature]", thresh="[temperature]")
def temperature_extremely_high(
da: xarray.DataArray, *, thresh: str = "60 degC"
da: xarray.DataArray, *, thresh: Quantified = "60 degC"
) -> xarray.DataArray:
"""Check if temperatures values exceed 60 degrees Celsius for any given day.

Expand Down Expand Up @@ -263,7 +276,7 @@ def temperature_extremely_high(
return extreme_high


@register_methods
@register_methods()
@update_xclim_history
def negative_accumulation_values(
da: xarray.DataArray,
Expand Down Expand Up @@ -293,11 +306,11 @@ def negative_accumulation_values(
return negative_accumulations


@register_methods
@register_methods()
@update_xclim_history
@declare_units(da="[precipitation]")
@declare_units(da="[precipitation]", thresh="[precipitation]")
def very_large_precipitation_events(
da: xarray.DataArray, *, thresh="300 mm d-1"
da: xarray.DataArray, *, thresh: Quantified = "300 mm d-1"
) -> xarray.DataArray:
"""Check if precipitation values exceed 300 mm/day for any given day.

Expand Down Expand Up @@ -329,10 +342,10 @@ def very_large_precipitation_events(
return very_large_events


@register_methods
@register_methods("values_{op}_{thresh}_repeating_for_{n}_or_more_days")
@update_xclim_history
def values_op_thresh_repeating_for_n_or_more_days(
da: xarray.DataArray, *, n: int, thresh: str, op: str = "=="
da: xarray.DataArray, *, n: int, thresh: Quantified, op: str = "=="
) -> xarray.DataArray:
"""Check if array values repeat at a given threshold for `N` or more days.

Expand Down Expand Up @@ -377,11 +390,14 @@ def values_op_thresh_repeating_for_n_or_more_days(
return repetitions


@register_methods
@register_methods()
@update_xclim_history
@declare_units(da="[speed]")
@declare_units(da="[speed]", lower="[speed]", upper="[speed]")
def wind_values_outside_of_bounds(
da: xarray.DataArray, *, lower: str = "0 m s-1", upper: str = "46 m s-1"
da: xarray.DataArray,
*,
lower: Quantified = "0 m s-1",
upper: Quantified = "46 m s-1",
) -> xarray.DataArray:
"""Check if variable values fall below 0% or rise above 100% for any given day.

Expand Down Expand Up @@ -419,7 +435,7 @@ def wind_values_outside_of_bounds(
# TODO: 'Many excessive dry days' = the amount of dry days lies outside a 14·bivariate standard deviation


@register_methods
@register_methods("outside_{n}_standard_deviations_of_climatology")
@update_xclim_history
def outside_n_standard_deviations_of_climatology(
da: xarray.DataArray,
Expand Down Expand Up @@ -475,7 +491,7 @@ def outside_n_standard_deviations_of_climatology(
return ~within_bounds


@register_methods
@register_methods("values_repeating_for_{n}_or_more_days")
@update_xclim_history
def values_repeating_for_n_or_more_days(
da: xarray.DataArray, *, n: int
Expand Down Expand Up @@ -508,7 +524,7 @@ def values_repeating_for_n_or_more_days(
return repetition


@register_methods
@register_methods()
@update_xclim_history
def percentage_values_outside_of_bounds(da: xarray.DataArray) -> xarray.DataArray:
"""Check if variable values fall below 0% or rise above 100% for any given day.
Expand Down Expand Up @@ -589,28 +605,35 @@ def data_flags( # noqa: C901
... )
"""

def _convert_value_to_str(var_name, val) -> str:
"""Convert variable units to an xarray data variable-like string."""
if isinstance(val, str):
try:
# Use pint to
val = str2pint(val).magnitude
if isinstance(val, float):
def get_variable_name(function, kwargs):
fmtargs = {}
kwargs = kwargs or {}
for arg, param in signature(function).parameters.items():
val = kwargs.get(arg, param.default)
kind = infer_kind_from_parameter(param)
if arg == "op":
fmtargs[arg] = val if val not in binary_ops else binary_ops[val]
elif kind in [
InputKind.FREQ_STR,
InputKind.NUMBER,
InputKind.STRING,
InputKind.DAY_OF_YEAR,
InputKind.DATE,
InputKind.BOOL,
]:
fmtargs[arg] = val
elif kind == InputKind.QUANTIFIED:
if isinstance(val, xarray.DataArray):
fmtargs[arg] = "array"
else:
val = str2pint(val).magnitude
if Decimal(val) % 1 == 0:
val = str(int(val))
else:
val = "point".join(str(val).split("."))
except pint.UndefinedUnitError:
pass

if isinstance(val, (int, str)):
# Replace spaces between units with underlines
var_name = var_name.replace(f"_{param}_", f"_{str(val).replace(' ', '_')}_")
# Change hyphens in units into the word "_minus_"
if "-" in var_name:
var_name = var_name.replace("-", "_minus_")

return var_name
val = str(val).replace(".", "point")
val = val.replace("-", "minus")
fmtargs[arg] = str(val)
return function.variable_name.format(**fmtargs)

def _missing_vars(function, dataset: xarray.Dataset, var_provided: str):
"""Handle missing variables in passed datasets."""
Expand Down Expand Up @@ -659,12 +682,9 @@ def _missing_vars(function, dataset: xarray.Dataset, var_provided: str):
for flag_func in flag_funcs:
for name, kwargs in flag_func.items():
func = _REGISTRY[name]
variable_name = str(name)
variable_name = get_variable_name(func, kwargs)
named_da_variable = None

if kwargs:
for param, value in kwargs.items():
variable_name = _convert_value_to_str(variable_name, value)
try:
extras = _missing_vars(func, ds, str(da.name))
# Entries in extras implies that there are two variables being compared
Expand Down
1 change: 1 addition & 0 deletions xclim/indices/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

__all__ = [
"aggregate_between_dates",
"binary_ops",
"compare",
"count_level_crossings",
"count_occurrences",
Expand Down