diff --git a/xclim/sdba/_adjustment.py b/xclim/sdba/_adjustment.py
index 4463c61f9..0b3fad263 100644
--- a/xclim/sdba/_adjustment.py
+++ b/xclim/sdba/_adjustment.py
@@ -384,7 +384,6 @@ def mbcn_adjust(
     # to confirm it works, and on big data to check performance.
     dims = ["time"] if period_dim is None else [period_dim, "time"]
 
-    sim_u = sim[pts_dims[0]].attrs["_units"]
     # mbcn core
     scen_mbcn = xr.zeros_like(sim)
     for ib in range(gw_idxs[gr_dim].size):
@@ -400,13 +399,12 @@ def mbcn_adjust(
         for iv, v in enumerate(sim[pts_dims[0]].values):
             sl = {"time": ind_gw, pts_dims[0]: iv}
             with set_options(sdba_extra_output=False):
-                ADJ = base.train(
-                    ref[v][{"time": ind_gw}],
-                    hist[v][{"time": ind_gw}],
-                    **base_kws_vars[v],
-                )
+                ADJ = base.train(
+                    ref[sl], hist[sl], **base_kws_vars[v], skip_input_checks=True
+                )
                 scen_block[{pts_dims[0]: iv}] = ADJ.adjust(
-                    sim[sl].assign_attrs({"units": sim_u[iv]}), **adj_kws
+                    sim[sl], **adj_kws, skip_input_checks=True
                 )
 
     # 2. npdft adjustment of sim
@@ -439,7 +437,7 @@ def mbcn_adjust(
     else:
         scen_mbcn[{"time": ind_g}] = reordered
 
-    return scen_mbcn
+    return scen_mbcn.to_dataset(name="scen")
 
 
 @map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[])
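Review note: the loop above now trains the univariate `base` adjustment directly on integer slices of the stacked array and bypasses the redundant per-slice checks with the new `skip_input_checks` flag. A minimal, self-contained sketch of that pattern on toy data (the two-variable setup and keyword values are illustrative, not from this patch):

```python
import numpy as np
import pandas as pd
import xarray as xr
from xclim import sdba

time = pd.date_range("1981-01-01", periods=365)
rng = np.random.default_rng(0)

def stacked(mu):
    """Two fake variables stacked along 'multivar', as stack_variables would produce."""
    return xr.DataArray(
        rng.normal(mu, 1.0, (2, time.size)),
        dims=("multivar", "time"),
        coords={"multivar": ["tasmax", "tasmin"], "time": time},
        attrs={"units": "K"},
    )

ref, hist, sim = stacked(285.0), stacked(287.0), stacked(288.0)

scen = xr.zeros_like(sim)
for iv, v in enumerate(sim.multivar.values):
    sl = {"multivar": iv}  # integer slice along the stacked dimension
    ADJ = sdba.QuantileDeltaMapping.train(ref[sl], hist[sl], nquantiles=20, group="time")
    scen[sl] = ADJ.adjust(sim[sl], interp="nearest")
    # With this patch, passing skip_input_checks=True to train/adjust skips the
    # unit harmonization, which is redundant inside this per-variable loop.
```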
""" - def _convert_units_to(ds, target, context): - if isinstance(ds, DataArray): - return convert_units_to(ds, target, context) - else: - return xr.merge( - [convert_units_to(ds[v], target[v], context) for v in ds.data_vars] - ) + def _harmonize_units_multivariate( + *inputs, dim, target: dict[str] | None = None + ): + def _convert_units_to(inda, dim, target): + for iv, v in enumerate(inda[dim].values): + inda.attrs["units"], inda.attrs["standard_name"] = ( + inda[dim].attrs[a][iv] for a in ["_units", "_standard_name"] + ) + inda[{dim: iv}] = convert_units_to( + inda[{dim: iv}], target[v], context="infer" + ) + inda.attrs["units"], inda.attrs["standard_name"] = "", "" + return inda + + if target is None: + target = { + v: inputs[0][dim].attrs["_units"][iv] + for iv, v in enumerate(inputs[0][dim].values) + } + return ( + _convert_units_to(inda, dim=dim, target=target) for inda in inputs + ), target + + dim_multivariate = None + for _dim, _crd in inputs[0].coords.items(): + if _crd.attrs.get("is_variables"): + dim_multivariate = str(_dim) + break + if dim_multivariate is not None: + return _harmonize_units_multivariate( + *inputs, dim=dim_multivariate, target=target + ) if target is None: - if isinstance(inputs[0], DataArray): - target = inputs[0].units - else: - target = {v: inputs[0][v].units for v in inputs[0].data_vars} + target = inputs[0].units return ( - _convert_units_to(inp, target, context="infer") for inp in inputs + convert_units_to(inda, target, context="infer") for inda in inputs ), target @classmethod @@ -237,7 +257,6 @@ def adjust(self, sim: DataArray, *args, **kwargs): if "group" in self: self._check_inputs(sim, *args, group=self.group) - sim = convert_units_to(sim, self.train_units) out = self._adjust(sim, *args, **kwargs) if isinstance(out, xr.DataArray): @@ -336,168 +355,6 @@ def adjust( return scen -class MultivariateTrainAdjust(BaseAdjustment): - """Base class for adjustment objects obeying the train-adjust scheme. - - Children classes should implement these methods: - - - ``_train(ref, hist, **kwargs)``, classmethod receiving the training target and data, - returning a training dataset and parameters to store in the object. - - - ``_adjust(sim, **kwargs)``, receiving the projected data and some arguments, - returning the `scen` Dataset. - - """ - - _allow_diff_calendars = True - _attribute = "_xclim_adjustment" - _repr_hide_params = ["hist_calendar", "train_units"] - - @classmethod - def train(cls, ref: xr.Dataset, hist: xr.Dataset, **kwargs) -> TrainAdjust: - r"""Train the adjustment object. - - Refer to the class documentation for the algorithm details. - - Parameters - ---------- - ref : xr.Dataset - Training target, usually a reference time series drawn from observations. - hist : xr.Dataset - Training data, usually a model output whose biases are to be adjusted. - \*\*kwargs - Algorithm-specific keyword arguments, see class doc. 
- """ - kwargs = parse_group(cls._train, kwargs) - skip_checks = kwargs.pop("skip_input_checks", False) - - if not skip_checks: - (ref, hist), train_units = cls._harmonize_units(ref, hist) - - if "group" in kwargs: - cls._check_inputs(ref, hist, group=kwargs["group"]) - - # I don't understand the point of this line, this is done above already - # hist = convert_units_to(hist, ref) - else: - train_units = "" - - ds, params = cls._train(ref, hist, **kwargs) - obj = cls( - _trained=True, - hist_calendar=get_calendar(hist), - train_units=train_units, - **params, - ) - obj.set_dataset(ds) - return obj - - def adjust(self, sim: DataArray, *args, **kwargs): - r"""Return bias-adjusted data. - - Refer to the class documentation for the algorithm details. - - Parameters - ---------- - sim : xr.Dataset - Time series to be bias-adjusted, usually a model output. - args : xr.Dataset | xr.DataArray - Other Datasets/DataArrays needed for the adjustment (usually none). - \*\*kwargs - Algorithm-specific keyword arguments, see class doc. - """ - skip_checks = kwargs.pop("skip_input_checks", False) - if not skip_checks: - (sim, *args), _ = self._harmonize_units(sim, *args, target=self.train_units) - - if "group" in self: - self._check_inputs(sim, *args, group=self.group) - - # sim = convert_units_to(sim, self.train_units) - scen = self._adjust(sim, *args, **kwargs) - - # Keep attrs - scen.attrs.update(sim.attrs) - for name, crd in sim.coords.items(): - if name in scen.coords: - scen[name].attrs.update(crd.attrs) - params = gen_call_string("", **kwargs)[1:-1] # indexing to remove added ( ) - infostr = f"{str(self)}.adjust(sim, {params})" - scen.attrs["history"] = update_history(f"Bias-adjusted with {infostr}", sim) - scen.attrs["bias_adjustment"] = infostr - return scen - - def set_dataset(self, ds: xr.Dataset): - """Store an xarray dataset in the `ds` attribute. - - Useful with custom object initialization or if some external processing was performed. - """ - super().set_dataset(ds) - self.ds.attrs["adj_params"] = str(self) - - @classmethod - def _train(cls, ref: DataArray, hist: DataArray, *kwargs): - raise NotImplementedError() - - def _adjust(self, sim, **kwargs): - raise NotImplementedError() - - -class MultivariateAdjust(BaseAdjustment): - """Adjustment with no intermediate trained object. - - Children classes should implement a `_adjust` classmethod taking as input the three Datasets - and returning the scen dataset/array. - """ - - @classmethod - def adjust( - cls, - ref: xr.Dataset, - hist: xr.Dataset, - sim: xr.Dataset | None = None, - **kwargs, - ) -> xr.Dataset: - r"""Return bias-adjusted data. Refer to the class documentation for the algorithm details. - - Parameters - ---------- - ref : xr.Dataset - Training target, usually a reference time series drawn from observations. - hist : xr.Dataset - Training data, usually a model output whose biases are to be adjusted. - sim : xr.Dataset, optional. - Time series to be bias-adjusted, usually a model output. If `None`, `hist` will be used for adjustments. - \*\*kwargs - Algorithm-specific keyword arguments, see class doc. - - Returns - ------- - xr.Dataset - The bias-adjusted Dataset. 
- """ - kwargs = parse_group(cls._adjust, kwargs) - skip_checks = kwargs.pop("skip_input_checks", False) - - kwargs["adjust_hist"] = sim is None - if sim is None: - sim = hist.copy() - - if not skip_checks: - if "group" in kwargs: - cls._check_inputs(ref, hist, sim, group=kwargs["group"]) - - (ref, hist, sim), _ = cls._harmonize_units(ref, hist, sim) - - scen = cls._adjust(ref, hist, sim, **kwargs) - - params = ", ".join([f"{k}={repr(v)}" for k, v in kwargs.items()]) - infostr = f"{cls.__name__}.adjust(ref, hist, sim, {params})" - scen.attrs["history"] = update_history(f"Bias-adjusted with {infostr}", sim) - scen.attrs["bias_adjustment"] = infostr - return scen - - class EmpiricalQuantileMapping(TrainAdjust): """Empirical Quantile Mapping bias-adjustment. @@ -1388,11 +1245,7 @@ def _adjust( return out -# Right now, the training part only outputs af_q -# if it outputs also the corrected hist, it could be like NpdfTransform - - -class MBCn(MultivariateTrainAdjust): +class MBCn(TrainAdjust): r"""Multivariate bias correction function using the N-dimensional probability density function transform. A multivariate bias-adjustment algorithm described by :cite:t:`sdba-cannon_multivariate_2018` @@ -1514,8 +1367,8 @@ class MBCn(MultivariateTrainAdjust): @classmethod def _train( cls, - ref: xr.Dataset, - hist: xr.Dataset, + ref: xr.DataArray, + hist: xr.DataArray, *, base_kws: dict[str, Any] | None = None, adj_kws: dict[str, Any] | None = None, @@ -1524,7 +1377,6 @@ def _train( pts_dim: str = "multivar", rot_matrices: xr.DataArray | None = None, ): - """Training method of MBCn (temporary docstring)""" # set default values for non-specified parameters base_kws = base_kws if base_kws is not None else {} adj_kws = adj_kws if adj_kws is not None else {} @@ -1542,23 +1394,19 @@ def _train( "Received `group==time.month` in `base_kws`. Monthly grouping is not currently supported in the MBCn class." ) # stack variables and prepare rotations - if rot_matrices is None: - pts_dim = xr.core.utils.get_temp_dimname( - set(ref.dims).union(hist.dims), "multivar" - ) + if rot_matrices is not None: + if pts_dim != rot_matrices.attrs["crd_dim"]: + raise ValueError( + f"`crd_dim` attribute of `rot_matrices` ({rot_matrices.attrs['crd_dim']}) does not correspond to `pts_dim` ({pts_dim})." 
@@ -1542,23 +1394,19 @@ def _train(
                 "Received `group==time.month` in `base_kws`. Monthly grouping is not currently supported in the MBCn class."
             )
         # stack variables and prepare rotations
-        if rot_matrices is None:
-            pts_dim = xr.core.utils.get_temp_dimname(
-                set(ref.dims).union(hist.dims), "multivar"
-            )
+        if rot_matrices is not None:
+            if pts_dim != rot_matrices.attrs["crd_dim"]:
+                raise ValueError(
+                    f"`crd_dim` attribute of `rot_matrices` ({rot_matrices.attrs['crd_dim']}) does not correspond to `pts_dim` ({pts_dim})."
+                )
         else:
-            pts_dim = rot_matrices.attrs["crd_dim"]
-            rot_dim = rot_matrices.attrs["new_dim"]
-        ref = stack_variables(ref, pts_dim)
-        hist = stack_variables(hist, pts_dim)
-        if rot_matrices is None:
             rot_dim = xr.core.utils.get_temp_dimname(
                 set(ref.dims).union(hist.dims), pts_dim + "_prime"
             )
             rot_matrices = rand_rot_matrix(
                 ref[pts_dim], num=n_iter, new_dim=rot_dim
             ).rename(matrices="iterations")
-        pts_dims = (pts_dim, rot_dim)
+        pts_dims = [rot_matrices.attrs[d] for d in ["crd_dim", "new_dim"]]
 
         # time indices corresponding to group and windowed group
         # used to divide datasets as map_blocks or groupby would do
@@ -1591,19 +1439,18 @@ def _train(
 
     def _adjust(
         self,
-        sim: xr.Dataset,
+        sim: xr.DataArray,
+        ref: xr.DataArray,
+        hist: xr.DataArray,
         *,
-        ref: xr.Dataset | None = None,
-        hist: xr.Dataset | None = None,
         base: TrainAdjust = QuantileDeltaMapping,
         base_kws_vars: dict[str, Any] | None = None,
         adj_kws: dict[str, Any] | None = None,
         period_dim=None,
     ):
-        """Adjusting method of MBCn (temporary docstring)"""
         # set default values for non-specified parameters
         base_kws_vars = base_kws_vars or {}
-        for v in sim.data_vars:
+        for v in sim[self.pts_dims[0]].values:
             if v not in base_kws_vars.keys():
                 base_kws_vars[v] = {}
 
@@ -1628,7 +1475,6 @@ def _adjust(
         adj_kws.setdefault("extrapolation", self.extrapolation)
 
         # adjust (adjust for npdf transform, train/adjust for univariate bias correction)
-        sim = stack_variables(sim, self.pts_dims[0])
         out = mbcn_adjust(
             ref=ref,
             hist=hist,
@@ -1643,286 +1489,9 @@ def _adjust(
             period_dim=period_dim,
         )
 
-        # postprocess
-        out = unstack_variables(out)
        return out
 
 
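Review note: with stacking moved out of `_train`/`_adjust`, callers stack once up front. An end-to-end sketch inferred from the new signatures in this diff (toy Datasets and keyword values are placeholders, not prescribed by the patch):

```python
import numpy as np
import pandas as pd
import xarray as xr
from xclim import sdba
from xclim.sdba.processing import stack_variables, unstack_variables

time = pd.date_range("1981-01-01", periods=730)
rng = np.random.default_rng(0)

def toy(mu_t, mu_p):
    """Placeholder two-variable Dataset; real use would load model/obs data."""
    return xr.Dataset({
        "tasmax": xr.DataArray(rng.normal(mu_t, 5, time.size), dims="time",
                               coords={"time": time}, attrs={"units": "K"}),
        "pr": xr.DataArray(rng.gamma(2, mu_p, time.size), dims="time",
                           coords={"time": time}, attrs={"units": "mm/d"}),
    })

# Stacking now happens outside train/adjust (the stack_variables calls were removed above).
ref, hist, sim = (stack_variables(d) for d in (toy(285, 1.5), toy(288, 2.0), toy(289, 2.2)))

ADJ = sdba.MBCn.train(
    ref, hist,
    base_kws={"nquantiles": 20, "group": sdba.Grouper("time", 1)},
    adj_kws={"interp": "nearest", "extrapolation": "constant"},
    n_iter=5,
)
# ref and hist are passed again: the univariate step trains inside adjust.
scen = unstack_variables(ADJ.adjust(sim, ref, hist, base=sdba.QuantileDeltaMapping))
```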
-# class MBCn(BaseAdjustment):
-#     r"""Multivariate bias correction function using the N-dimensional probability density function transform.
-
-#     A multivariate bias-adjustment algorithm described by :cite:t:`sdba-cannon_multivariate_2018`
-#     based on a color-correction algorithm described by :cite:t:`sdba-pitie_n-dimensional_2005`.
-
-#     This algorithm in itself, when used with QuantileDeltaMapping, is NOT trend-preserving.
-#     The full MBCn algorithm includes a reordering step provided here by :py:func:`xclim.sdba.processing.reordering`.
-
-#     See notes for an explanation of the algorithm.
-
-#     Attributes
-#     ----------
-#     Train step
-
-#     ref : xr.DataArray
-#         Reference dataset.
-#     hist : xr.DataArray
-#         Historical dataset.
-#     base_kws : dict, optional
-#         Arguments passed to the training in the npdf transform.
-#     adj_kws : dict, optional
-#         Arguments passed to the adjusting in the npdf transform.
-#     n_escore : int
-#         The number of elements to send to the escore function. The default, 0, means all elements are included.
-#         Pass -1 to skip computing the escore completely.
-#         Small numbers result in less significant scores, but the execution time goes up quickly with large values.
-#     n_iter : int
-#         The number of iterations to perform. Defaults to 20.
-#     pts_dim : str
-#         The name of the "multivariate" dimension. Defaults to "multivar", which is the
-#         normal case when using :py:func:`xclim.sdba.base.stack_variables`.
-#     rot_matrices : xr.DataArray, optional
-#         The rotation matrices as a 3D array ('iterations', <pts_dims[0]>, <pts_dims[1]>), with shape (n_iter, <N>, <N>).
-#         If left empty, random rotation matrices will be automatically generated.
-
-#     Adjust step
-
-#     ref : xr.DataArray
-#         Target reference dataset also needed for univariate bias correction preceding npdf transform
-#     hist : xr.DataArray
-#         Source dataset also needed for univariate bias correction preceding npdf transform
-#     sim : xr.DataArray
-#         Source dataset to adjust.
-#     base : BaseAdjustment
-#         Bias-adjustment class used for the univariate bias correction.
-#     base_kws : dict, optional
-#         Arguments passed to the training in the univariate bias correction
-#     adj_kws : dict, optional
-#         Arguments passed to the adjusting in the univariate bias correction
-#     period_dim : str, optional
-#         Name of the period dimension used when stacking time periods of `sim` using :py:func:`xclim.core.calendar.stack_periods`.
-#         If specified, the interpolation of the npdf transform is performed only once and applied on all periods simultaneously.
-#         This should be more performant, but also more memory intensive.
-
-#     Notes
-#     -----
-#     The historical reference (:math:`T`, for "target"), simulated historical (:math:`H`) and simulated projected (:math:`S`)
-#     datasets are constructed by stacking the timeseries of N variables together. The algorithm is broken into the
-#     following steps:
-
-#     Training (only npdf transform training)
-
-#     1. Standardize `ref` and `hist` (see ``xclim.sdba.processing.standardize``).
-
-#     2. Rotate the datasets in the N-dimensional variable space with :math:`\mathbf{R}`, a random rotation NxN matrix.
-
-#     .. math::
-
-#         \tilde{\mathbf{T}} = \mathbf{T}\mathbf{R} \\
-#         \tilde{\mathbf{H}} = \mathbf{H}\mathbf{R}
-
-#     3. QuantileDeltaMapping is used to perform bias adjustment :math:`\mathcal{F}` on the rotated datasets.
-#     The adjustment factor is conserved for later use in the adjusting step. The adjustments are made in additive mode,
-#     for each variable :math:`i`.
-
-#     .. math::
-
-#         \hat{\mathbf{H}}_i, \hat{\mathbf{S}}_i = \mathcal{F}\left(\tilde{\mathbf{T}}_i, \tilde{\mathbf{H}}_i, \tilde{\mathbf{S}}_i\right)
-
-#     4. The bias-adjusted datasets are rotated back.
-
-#     .. math::
-
-#         \mathbf{H}' = \hat{\mathbf{H}}\mathbf{R} \\
-#         \mathbf{S}' = \hat{\mathbf{S}}\mathbf{R}
-
-#     5. Repeat steps 2, 3 and 4 ``n_iter`` times, i.e. the number of randomly generated rotation matrices.
-
-#     Adjusting
-
-#     1. Perform the same steps as in training, with `ref, hist` replaced with `sim`. Step 3 of the training is modified:
-#     here we simply reuse the adjustment factors previously found in the training step to bias correct the standardized `sim` directly.
-
-#     2. Using the original (unstandardized) `ref, hist, sim`, perform a univariate bias adjustment using the ``base_scen`` class
-#     on `sim`.
-
-#     3. Reorder the dataset found in step 2 according to the ranks of the dataset found in step 1.
-
-#     The original algorithm :cite:p:`sdba-pitie_n-dimensional_2005` stops the iteration when some distance score converges.
-#     Following :cite:t:`sdba-cannon_multivariate_2018` and the MBCn implementation in :cite:t:`sdba-cannon_mbc_2020`, we
-#     instead fix the number of iterations.
-
-#     As done by :cite:t:`sdba-cannon_multivariate_2018`, the distance score chosen is the "Energy distance" from
-#     :cite:t:`sdba-szekely_testing_2004` (see :py:func:`xclim.sdba.processing.escore`).
-
-#     The random matrices are generated following a method laid out by :cite:t:`sdba-mezzadri_how_2007`.
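Review note: a pure-numpy sketch of one npdf-transform iteration described in the notes above (rotate, quantile-map each rotated variable additively, rotate back). All names and data are illustrative; the sign-corrected QR follows the Mezzadri (2007) recipe the notes cite:

```python
import numpy as np

rng = np.random.default_rng(42)
N, T = 2, 500  # variables, time steps; rows are variables, columns time
ref, hist = rng.normal(0, 1, (N, T)), rng.normal(0.5, 1.3, (N, T))

# Haar-distributed rotation via QR with sign correction (Mezzadri, 2007).
q, r = np.linalg.qr(rng.normal(size=(N, N)))
R = q * np.sign(np.diag(r))

def qm_additive(tgt, src):
    """Map src onto tgt by matching empirical quantiles (additive mode)."""
    ranks = src.argsort().argsort()
    return src + (np.sort(tgt) - np.sort(src))[ranks]

ref_r, hist_r = R @ ref, R @ hist                                        # step 2: rotate
hist_r = np.stack([qm_additive(ref_r[i], hist_r[i]) for i in range(N)])  # step 3: adjust
hist_adj = R.T @ hist_r  # step 4: rotate back (inverse of an orthogonal R is its transpose)
# Steps 2-4 would then repeat n_iter times with fresh rotations.
```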
-
-#     References
-#     ----------
-#     :cite:cts:`sdba-cannon_multivariate_2018,sdba-cannon_mbc_2020,sdba-pitie_n-dimensional_2005,sdba-mezzadri_how_2007,sdba-szekely_testing_2004`
-
-#     Notes
-#     -----
-#     Only "time" and "time.dayofyear" (with a suitable window) are implemented as possible values for `group`.
-#     """
-
-#     @classmethod
-#     def train(
-#         cls,
-#         ref: xr.Dataset,
-#         hist: xr.Dataset,
-#         *,
-#         base_kws: dict[str, Any] | None = None,
-#         adj_kws: dict[str, Any] | None = None,
-#         n_escore: int = -1,
-#         n_iter: int = 20,
-#         pts_dim: str = "multivar",
-#         rot_matrices: xr.DataArray | None = None,
-#     ):
-#         """Training method of MBCn (temporary docstring)"""
-#         # check units
-#         # skip_checks = kwargs.pop("skip_input_checks", False)
-#         # if not skip_checks:
-#         #     if "group" in kwargs:
-#         #         cls._check_inputs(ref, hist, sim, group=kwargs["group"])
-#         (ref, hist), train_units = cls._harmonize_units_multivariate(ref, hist)
-#         # set default values for non-specified parameters
-#         base_kws = base_kws if base_kws is not None else {}
-#         adj_kws = adj_kws if adj_kws is not None else {}
-#         base_kws.setdefault("nquantiles", 20)
-#         base_kws.setdefault("group", Grouper("time", 1))
-#         adj_kws.setdefault("interp", "nearest")
-#         adj_kws.setdefault("extrapolation", "constant")
-
-#         if np.isscalar(base_kws["nquantiles"]):
-#             base_kws["nquantiles"] = equally_spaced_nodes(base_kws["nquantiles"])
-#         if isinstance(base_kws["group"], str):
-#             base_kws["group"] = Grouper(base_kws["group"], 1)
-#         if base_kws["group"].name == "time.month":
-#             NotImplementedError(
-#                 "Received `group==time.month` in `base_kws`. Monthly grouping is not currently supported in the MBCn class."
-#             )
-#         # stack variables and prepare rotations
-#         if rot_matrices is None:
-#             pts_dim = xr.core.utils.get_temp_dimname(
-#                 set(ref.dims).union(hist.dims), "multivar"
-#             )
-#         else:
-#             pts_dim = rot_matrices.attrs["crd_dim"]
-#             rot_dim = rot_matrices.attrs["new_dim"]
-#         ref = stack_variables(ref, pts_dim)
-#         hist = stack_variables(hist, pts_dim)
-#         if rot_matrices is None:
-#             rot_dim = xr.core.utils.get_temp_dimname(
-#                 set(ref.dims).union(hist.dims), pts_dim + "_prime"
-#             )
-#             rot_matrices = rand_rot_matrix(
-#                 ref[pts_dim], num=n_iter, new_dim=rot_dim
-#             ).rename(matrices="iterations")
-#         pts_dims = (pts_dim, rot_dim)
-
-#         # time indices corresponding to group and windowed group
-#         # used to divide datasets as map_blocks or groupby would do
-#         g_idxs, gw_idxs = grouped_time_indexes(ref.time, base_kws["group"])
-
-#         # training, obtain adjustment factors of the npdf transform
-#         ds = xr.Dataset(dict(ref=ref, hist=hist))
-#         params = {
-#             "quantiles": base_kws["nquantiles"],
-#             "interp": adj_kws["interp"],
-#             "extrapolation": adj_kws["extrapolation"],
-#             "pts_dims": pts_dims,
-#             "n_escore": n_escore,
-#         }
-#         out = mbcn_train(
-#             ds, rot_matrices=rot_matrices, g_idxs=g_idxs, gw_idxs=gw_idxs, **params
-#         )
-#         params["group"] = base_kws["group"]
-
-#         # postprocess
-#         out["rot_matrices"] = rot_matrices
-#         out["g_idxs"] = g_idxs
-#         out["gw_idxs"] = gw_idxs
-
-#         out.af_q.attrs.update(
-#             standard_name="Adjustment factors",
-#             long_name="Quantile mapping adjustment factors",
-#         )
-#         obj = cls(
-#             _trained=True,
-#             hist_calendar=get_calendar(hist),
-#             train_units=train_units,
-#             **params,
-#         )
-#         obj.set_dataset(out)
-#         return obj
-
-#     def adjust(
-#         self,
-#         sim: xr.Dataset,
-#         *,
-#         ref: xr.Dataset | None = None,
-#         hist: xr.Dataset | None = None,
-#         base: TrainAdjust = QuantileDeltaMapping,
-#         base_kws_vars: dict[str, Any] | None = None,
-#         adj_kws: dict[str, Any] | None = None,
-#         period_dim=None,
-#     ):
-#         """Adjusting method of MBCn (temporary docstring)"""
-#         # units
-#         # skip_checks = kwargs.pop("skip_input_checks", False)
-#         # if not skip_checks:
-#         (ref, hist, sim), _ = cls._harmonize_units_multivariate(ref, hist, sim, self.train_units)
-
-
-#         # set default values for non-specified parameters
-#         base_kws_vars = base_kws_vars or {}
-#         for v in sim.data_vars:
-#             if v not in base_kws_vars.keys():
-#                 base_kws_vars[v] = {}
-
-#             base_kws_vars[v].setdefault("group", self.group)
-#             if isinstance(base_kws_vars[v]["group"], str):
-#                 base_kws_vars[v]["group"] = Grouper(base_kws_vars[v]["group"], 1)
-#             if base_kws_vars[v]["group"] != self.group:
-#                 raise ValueError(
-#                     "Expected usage of the same group in adjust and training steps."
-#                 )
-#             base_kws_vars[v].pop("group")
-
-#             base_kws_vars[v].setdefault("nquantiles", self.ds.af_q.quantiles.values)
-#             if np.isscalar(base_kws_vars[v]["nquantiles"]):
-#                 base_kws_vars[v]["nquantiles"] = equally_spaced_nodes(
-#                     base_kws_vars[v]["nquantiles"]
-#                 )
-#             # should maybe require that quantiles here match those from training?
-
-#         adj_kws = adj_kws if adj_kws is not None else {}
-#         adj_kws.setdefault("interp", self.interp)
-#         adj_kws.setdefault("extrapolation", self.extrapolation)
-
-#         # adjust (adjust for npdf transform, train/adjust for univariate bias correction)
-#         sim = stack_variables(sim, self.pts_dims[0])
-#         out = mbcn_adjust(
-#             ref=ref,
-#             hist=hist,
-#             sim=sim,
-#             ds=self.ds,
-#             pts_dims=self.pts_dims,
-#             interp=self.interp,
-#             extrapolation=self.extrapolation,
-#             base=base,
-#             base_kws_vars=base_kws_vars,
-#             adj_kws=adj_kws,
-#             period_dim=period_dim,
-#         )
-
-#         # postprocess
-#         out = unstack_variables(out)
-#         return out
-
-
 try:
     import SBCK
 except ImportError:
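Review note: the final MBCn step reorders the univariate scenario according to the ranks of the npdf-transformed one; xclim provides this as :py:func:`xclim.sdba.processing.reordering`. A toy numpy equivalent of that rank shuffle, with made-up values:

```python
import numpy as np

npdft = np.array([0.3, 0.9, 0.1, 0.5])    # ranks: 1, 3, 0, 2
scen_u = np.array([10.0, 11.0, 12.0, 13.0])
# Rearrange the univariate scenario so it follows npdft's rank structure.
reordered = np.sort(scen_u)[npdft.argsort().argsort()]
# -> [11., 13., 10., 12.]
```

This keeps the marginal distributions from the univariate adjustment while inheriting the multivariate dependence structure from the npdf transform.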