diff --git a/CHANGES.rst b/CHANGES.rst index b2fe9195..820e457e 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -20,6 +20,8 @@ New features and enhancements * Better ``xs.extract.resample`` : support for weighted resampling operations when starting with frequencies coarser than daily and missing timesteps/values handling. (:issue:`80`, :issue:`93`, :pull:`265`). * New argument ``attribute_weights`` to ``generate_weights`` to allow for custom weights. (:pull:`252`). * ``xs.io.round_bits`` to round floating point variable up to a number of bits, allowing for a better compression. This can be combined with the saving step through argument ``"bitround"`` of ``save_to_netcdf`` and ``save_to_zarr``. (:pull:`266`). +* Added the ability to directly provide an ensemble dataset to ``xs.ensemble_stats``. (:pull:`299`). +* Added support in ``xs.ensemble_stats`` for the new robustness-related functions available in `xclim`. (:pull:`299`). Breaking changes ^^^^^^^^^^^^^^^^ diff --git a/tests/test_ensembles.py b/tests/test_ensembles.py index e8005762..d0e5cffe 100644 --- a/tests/test_ensembles.py +++ b/tests/test_ensembles.py @@ -1,3 +1,4 @@ +import logging from copy import deepcopy import numpy as np @@ -7,6 +8,8 @@ import xscen as xs +LOGGER = logging.getLogger(__name__) + class TestEnsembleStats: @staticmethod @@ -34,6 +37,27 @@ def make_ensemble(n): return ens + def test_input_type(self, tmpdir): + ens_dict = self.make_ensemble(10) + out_dict = xs.ensemble_stats( + ens_dict, statistics={"ensemble_mean_std_max_min": None} + ) + ens_ds = xr.concat(ens_dict, dim="realization") + out_ds = xs.ensemble_stats( + ens_ds, statistics={"ensemble_mean_std_max_min": None} + ) + paths = [] + for i, ds in enumerate(ens_dict): + paths.append(tmpdir / f"ens{i}.zarr") + ds.to_zarr(paths[-1]) + out_zarr = xs.ensemble_stats( + paths, statistics={"ensemble_mean_std_max_min": None} + ).compute() + + assert out_dict.equals(out_ds) + # The zarr introduced some rounding errors + assert np.round(out_zarr, 10).equals(np.round(out_dict, 10)) + @pytest.mark.parametrize( "weights, to_level", [(None, None), (np.arange(1, 11), "for_tests")] ) @@ -88,9 +112,99 @@ def test_weights(self, weights, to_level): decimal=2, ) - def test_change_significance(self): - # TODO: Possible nearby changes to xclim.ensembles.change_significance would break this test. - pass + # FIXME: This function is deprecated, so this test should be removed eventually. + @pytest.mark.parametrize("p_vals", [True, False]) + def test_change_significance(self, p_vals): + ens = self.make_ensemble(10) + with pytest.warns( + FutureWarning, + match="Function change_significance is deprecated as of xclim 0.47", + ): + out = xs.ensemble_stats( + ens, + statistics={"change_significance": {"test": None, "p_vals": p_vals}}, + ) + + assert len(out.data_vars) == 3 if p_vals else 2 + + @pytest.mark.parametrize("fractions", ["only", "both", "nested", "missing"]) + def test_robustness_input(self, fractions): + ens = self.make_ensemble(10) + + weights = None + if fractions == "only": + statistics = { + "robustness_fractions": {"test": "threshold", "abs_thresh": 2.5} + } + elif fractions == "both": + statistics = { + "robustness_fractions": {"test": "threshold", "abs_thresh": 2.5}, + "robustness_categories": { + "categories": ["robust", "non-robust"], + "thresholds": [(0.5, 0.5), (0.5, 0.5)], + "ops": [(">=", ">="), ("<", "<")], + }, + } + elif fractions == "nested": + weights = xr.DataArray( + [0] * 5 + [1] * 5, dims="realization" + ) # Mostly to check that this is put in the nested dict + statistics = { + "robustness_categories": { + "robustness_fractions": {"test": "threshold", "abs_thresh": 2.5}, + "categories": ["robust", "non-robust"], + "thresholds": [(0.5, 0.5), (0.5, 0.5)], + "ops": [(">=", ">="), ("<", "<")], + } + } + elif fractions == "missing": + statistics = { + "robustness_categories": { + "categories": ["robust", "non-robust"], + "thresholds": [(0.5, 0.5), (0.5, 0.5)], + "ops": [(">=", ">="), ("<", "<")], + } + } + with pytest.raises( + ValueError, + match="'robustness_categories' requires 'robustness_fractions'", + ): + xs.ensemble_stats(ens, statistics=statistics) + return + + out = xs.ensemble_stats(ens, statistics=statistics, weights=weights) + + assert len(out.data_vars) == {"only": 5, "both": 6, "nested": 1}[fractions] + if fractions in ["only", "both"]: + np.testing.assert_array_equal(out.tg_mean_changed, [0, 0.4, 1, 1]) + np.testing.assert_array_equal(out.tg_mean_agree, [1, 1, 1, 1]) + if fractions in ["both", "nested"]: + np.testing.assert_array_equal( + out.tg_mean_robustness_categories, + {"both": [99, 99, 1, 1], "nested": [99, 1, 1, 1]}[fractions], + ) + np.testing.assert_array_equal( + out.tg_mean_robustness_categories.attrs["flag_descriptions"], + ["robust", "non-robust"], + ) + + @pytest.mark.parametrize("symbol", ["rel.", "relative", "*", "/", "pct.", "abs."]) + def test_robustness_reldelta(self, caplog, symbol): + ens = self.make_ensemble(10) + for e in ens: + e["tg_mean"].attrs["delta_kind"] = symbol + + with caplog.at_level(logging.INFO): + xs.ensemble_stats( + ens, + statistics={ + "robustness_fractions": {"test": "threshold", "abs_thresh": 2.5} + }, + ) + if symbol in ["rel.", "relative", "*", "/"]: + assert "Relative delta detected" in caplog.text + else: + assert "Relative delta detected" not in caplog.text @pytest.mark.parametrize("common_attrs_only", [True, False]) def test_common_attrs_only(self, common_attrs_only): @@ -107,6 +221,53 @@ def test_common_attrs_only(self, common_attrs_only): assert out.attrs.get("foo", None) == "bar" assert ("bar0" not in out.attrs) == common_attrs_only + def test_errors(self): + ens = self.make_ensemble(10) + + # Warning if the statistic does not support weighting + weights = xr.DataArray([0] * 5 + [1] * 5, dims="realization") + with pytest.warns(UserWarning, match="Weighting is not supported"): + with pytest.raises( + TypeError + ): # kkz is not actually supported here, but it's one of the few that will not support weighting + xs.ensemble_stats( + ens, statistics={"kkz_reduce_ensemble": None}, weights=weights + ) + + # Error if you try to use a relative delta with a reference dataset + for e in ens: + e["tg_mean"].attrs["delta_kind"] = "rel." + ref = weights + with pytest.raises( + ValueError, match="is a delta, but 'ref' was still specified." + ): + xs.ensemble_stats( + ens, + statistics={ + "robustness_fractions": { + "test": "threshold", + "abs_thresh": 2.5, + "ref": ref, + } + }, + ) + + # Error if you try to use a robustness_fractions with a reference dataset, but also specify other statistics + with pytest.raises( + ValueError, match="The input requirements for 'robustness_fractions'" + ): + xs.ensemble_stats( + ens, + statistics={ + "robustness_fractions": { + "test": "threshold", + "abs_thresh": 2.5, + "ref": ref, + }, + "ensemble_mean_std_max_min": None, + }, + ) + class TestGenerateWeights: @staticmethod diff --git a/xscen/ensembles.py b/xscen/ensembles.py index 01b3e296..1053f6e5 100644 --- a/xscen/ensembles.py +++ b/xscen/ensembles.py @@ -21,9 +21,13 @@ @parse_config -def ensemble_stats( +def ensemble_stats( # noqa: C901 datasets: Union[ - dict, list[Union[str, os.PathLike]], list[xr.Dataset], list[xr.DataArray] + dict, + list[Union[str, os.PathLike]], + list[xr.Dataset], + list[xr.DataArray], + xr.Dataset, ], statistics: dict, *, @@ -36,14 +40,17 @@ def ensemble_stats( Parameters ---------- - datasets : dict or list of str, Path, Dataset or DataArray + datasets : dict or list of [str, os.PathLike, Dataset or DataArray], or Dataset List of file paths or xarray Dataset/DataArray objects to include in the ensemble. A dictionary can be passed instead of a list, in which case the keys are used as coordinates along the new `realization` axis. Tip: With a project catalog, you can do: `datasets = pcat.search(**search_dict).to_dataset_dict()`. + If a single Dataset is passed, it is assumed to already be an ensemble and will be used as is. The 'realization' dimension is required. statistics : dict xclim.ensembles statistics to be called. Dictionary in the format {function: arguments}. - If a function requires 'ref', the dictionary entry should be the inputs of a .loc[], e.g. {"ref": {"horizon": "1981-2010"}} + If a function requires 'weights', you can leave it out of this dictionary and + it will be applied automatically if the 'weights' argument is provided. + See the Notes section for more details on robustness statistics, which are more complex in their usage. create_kwargs : dict, optional Dictionary of arguments for xclim.ensembles.create_ensemble. weights : xr.DataArray, optional @@ -59,62 +66,150 @@ def ensemble_stats( xr.Dataset Dataset with ensemble statistics + Notes + ----- + * The positive fraction in 'change_significance' and 'robustness_fractions' is calculated by + xclim using 'v > 0', which is not appropriate for relative deltas. + This function will attempt to detect relative deltas by using the 'delta_kind' attribute ('rel.', 'relative', '*', or '/') + and will apply 'v - 1' before calling the function. + * The 'robustness_categories' statistic requires the outputs of 'robustness_fractions'. + Thus, there are two ways to build the 'statistics' dictionary: + + 1. Having 'robustness_fractions' and 'robustness_categories' as separate entries in the dictionary. + In this case, all outputs will be returned. + 2. Having 'robustness_fractions' as a nested dictionary under 'robustness_categories'. + In this case, only the robustness categories will be returned. + + * A 'ref' DataArray can be passed to 'change_significance' and 'robustness_fractions', which will be used by xclim to compute deltas + and perform some significance tests. However, this supposes that both 'datasets' and 'ref' are still timeseries (e.g. annual means), + not climatologies where the 'time' dimension represents the period over which the climatology was computed. Thus, + using 'ref' is only accepted if 'robustness_fractions' (or 'robustness_categories') is the only statistic being computed. + * If you want to use compute a robustness statistic on a climatology, you should first compute the climatologies and deltas yourself, + then leave 'ref' as None and pass the deltas as the 'datasets' argument. This will be compatible with other statistics. + See Also -------- xclim.ensembles._base.create_ensemble, xclim.ensembles._base.ensemble_percentiles, - xclim.ensembles._base.ensemble_mean_std_max_min, xclim.ensembles._robustness.change_significance, + xclim.ensembles._base.ensemble_mean_std_max_min, + xclim.ensembles._robustness.robustness_fractions, xclim.ensembles._robustness.robustness_categories, xclim.ensembles._robustness.robustness_coefficient, """ create_kwargs = create_kwargs or {} + statistics = deepcopy(statistics) # to avoid modifying the original dictionary # if input files are .zarr, change the engine automatically - if isinstance(datasets, list) and isinstance(datasets[0], (str, Path)): + if isinstance(datasets, list) and isinstance(datasets[0], (str, os.PathLike)): path = Path(datasets[0]) - if path.suffix == ".zarr" and "engine" not in create_kwargs: - create_kwargs["engine"] = "zarr" + if path.suffix == ".zarr": + create_kwargs.setdefault("engine", "zarr") - ens = ensembles.create_ensemble(datasets, **create_kwargs) + if not isinstance(datasets, xr.Dataset): + ens = ensembles.create_ensemble(datasets, **create_kwargs) + else: + ens = datasets ens_stats = xr.Dataset(attrs=ens.attrs) - for stat, stats_kwargs in statistics.items(): - stats_kwargs = deepcopy(stats_kwargs or {}) + + # "robustness_categories" requires "robustness_fractions", but we want to compute things only once if both are requested. + statistics_to_compute = list(statistics.keys()) + if "robustness_categories" in statistics_to_compute: + if "robustness_fractions" in statistics_to_compute: + statistics_to_compute.remove("robustness_fractions") + elif "robustness_fractions" not in statistics["robustness_categories"]: + raise ValueError( + "'robustness_categories' requires 'robustness_fractions' to be computed. " + "Either add 'robustness_fractions' to the statistics dictionary or " + "add 'robustness_fractions' under the 'robustness_categories' dictionary." + ) + + for stat in statistics_to_compute: + stats_kwargs = deepcopy(statistics.get(stat) or {}) logger.info( - f"Creating ensemble with {len(datasets)} simulations and calculating {stat}." + f"Calculating {stat} from an ensemble of {len(ens.realization)} simulations." ) - if ( - weights is not None - and "weights" in inspect.getfullargspec(getattr(ensembles, stat))[0] - ): - stats_kwargs["weights"] = weights.reindex_like(ens.realization) - if "ref" in stats_kwargs: - stats_kwargs["ref"] = ens.loc[stats_kwargs["ref"]] - - if stat == "change_significance": + + # Workaround for robustness_categories + real_stat = None + if stat == "robustness_categories": + real_stat = "robustness_categories" + stat = "robustness_fractions" + categories_kwargs = deepcopy(stats_kwargs) + categories_kwargs.pop("robustness_fractions", None) + stats_kwargs = deepcopy( + stats_kwargs.get("robustness_fractions", None) + or statistics.get("robustness_fractions", {}) + ) + + if weights is not None: + if "weights" in inspect.getfullargspec(getattr(ensembles, stat))[0]: + stats_kwargs["weights"] = weights.reindex_like(ens.realization) + else: + warnings.warn( + f"Weighting is not supported for '{stat}'. The results may be incorrect." + ) + + # FIXME: change_significance is deprecated and will be removed in xclim 0.49. + if stat in [ + "change_significance", + "robustness_fractions", + "robustness_categories", + ]: + # FIXME: This can be removed once change_significance is removed. + # It's here because the 'ref' default was removed for change_significance in xclim 0.47. + stats_kwargs.setdefault("ref", None) + if (stats_kwargs.get("ref") is not None) and len(statistics_to_compute) > 1: + raise ValueError( + f"The input requirements for '{stat}' when 'ref' is specified are not compatible with other statistics." + ) + + # These statistics only work on DataArrays for v in ens.data_vars: with xr.set_options(keep_attrs=True): - deltak = ens[v].attrs.get("delta_kind", None) - if stats_kwargs.get("ref") is not None and deltak is not None: + # Support for relative deltas [0, inf], where positive fraction is 'v > 1' instead of 'v > 0'. + delta_kind = ens[v].attrs.get("delta_kind") + if stats_kwargs.get("ref") is not None and delta_kind is not None: raise ValueError( f"{v} is a delta, but 'ref' was still specified." ) - if deltak in ["relative", "*", "/"]: + if delta_kind in ["rel.", "relative", "*", "/"]: logging.info( f"Relative delta detected for {v}. Applying 'v - 1' before change_significance." ) ens_v = ens[v] - 1 else: ens_v = ens[v] + + # Call the function tmp = getattr(ensembles, stat)(ens_v, **stats_kwargs) - if len(tmp) == 2: + + # Manage the multiple outputs of change_significance + # FIXME: change_significance is deprecated and will be removed in xclim 0.49. + if ( + stat == "change_significance" + and stats_kwargs.get("p_vals", False) is False + ): ens_stats[f"{v}_change_frac"], ens_stats[f"{v}_pos_frac"] = tmp - elif len(tmp) == 3: + elif stat == "change_significance" and stats_kwargs.get( + "p_vals", False + ): ( ens_stats[f"{v}_change_frac"], ens_stats[f"{v}_pos_frac"], ens_stats[f"{v}_p_vals"], ) = tmp - else: - raise ValueError(f"Unexpected number of outputs from {stat}.") + + # Robustness categories + if real_stat == "robustness_categories": + categories = ensembles.robustness_categories( + tmp, **categories_kwargs + ) + ens_stats[f"{v}_robustness_categories"] = categories + + # Only return the robustness fractions if they were requested. + if "robustness_fractions" in statistics.keys(): + tmp = tmp.rename({s: f"{v}_{s}" for s in tmp.data_vars}) + ens_stats = ens_stats.merge(tmp) + else: ens_stats = ens_stats.merge(getattr(ensembles, stat)(ens, **stats_kwargs)) @@ -123,7 +218,7 @@ def ensemble_stats( ens_stats = ens_stats.drop_vars("realization") # delete attrs that are not common to all dataset - if common_attrs_only: + if common_attrs_only and not isinstance(datasets, xr.Dataset): # if they exist remove attrs specific to create_ensemble create_kwargs.pop("mf_flag", None) create_kwargs.pop("resample_freq", None) diff --git a/xscen/utils.py b/xscen/utils.py index aa6c3e75..a6302b04 100644 --- a/xscen/utils.py +++ b/xscen/utils.py @@ -1011,7 +1011,7 @@ def publish_release_notes( changes = re.sub(search, replacement, changes) if not file: - return + return changes if isinstance(file, (Path, os.PathLike)): file = Path(file).open("w") print(changes, file=file)