diff --git a/xclim/sdba/_adjustment.py b/xclim/sdba/_adjustment.py index cfb247a24..8126f2952 100644 --- a/xclim/sdba/_adjustment.py +++ b/xclim/sdba/_adjustment.py @@ -122,7 +122,8 @@ def _npdft_train(ref, hist, rots, quantiles, method, extrap): extrap=extrap, ) hist[iv] = u.apply_correction(hist[iv], af, "+") - return af_q + hist = rots[-1].T @ hist + return af_q, hist def npdf_train( @@ -169,14 +170,18 @@ def npdf_train( # npdf core gr_dim = gw_idxs.attrs["group_dim"] af_q_l = [] - print(gw_idxs[gr_dim].size) + win = gw_idxs.attrs["group"][1] + adjust_scen = "scen" in ds.data_vars + if adjust_scen: + scen = ds.scen + scen_mbcn = scen.copy() for ib in range(gw_idxs[gr_dim].size): indices = gw_idxs[{gr_dim: ib}].astype(int).values ind = indices[indices >= 0] - af_q = xr.apply_ufunc( + af_q, scen_npdft = xr.apply_ufunc( _npdft_train, - standardize(ref[{"time": ind}].copy(), dim="time")[0], - standardize(hist[{"time": ind}].copy(), dim="time")[0], + standardize(ref[{"time": ind}])[0].copy(), + standardize(hist[{"time": ind}])[0].copy(), rot_matrices, quantiles, input_core_dims=[ @@ -187,17 +192,29 @@ def npdf_train( ], output_core_dims=[ ["iterations", "multivar_prime", "quantiles"], + ["multivar", "time"], ], dask="parallelized", - output_dtypes=[hist.dtype], + output_dtypes=[hist.dtype, hist.dtype], kwargs={"method": method, "extrap": extrap}, vectorize=True, ) af_q_l.append(af_q.expand_dims({gr_dim: [ib]})) + if adjust_scen: + reordered = reordering(ref=scen_npdft, sim=scen[{"time": ind}].copy()) + if win > 1: + indices_g = g_idxs[{gr_dim: ib}].astype(int).values + ind_g = indices_g[indices_g >= 0] + scen_mbcn[{"time": ind_g}] = reordered[{"time": np.in1d(ind, ind_g)}] + else: + scen_mbcn[{"time": ind}] = reordered af_q = xr.concat(af_q_l, dim=gr_dim).assign_coords( {"quantiles": quantiles, gr_dim: gw_idxs[gr_dim].values} ) - return xr.Dataset(dict(af_q=af_q)) + out = xr.Dataset(dict(af_q=af_q)) + if adjust_scen: + out["scen"] = scen_mbcn + return out def _npdft_adjust(sim, af_q, rots, quantiles, method, extrap): diff --git a/xclim/sdba/adjustment.py b/xclim/sdba/adjustment.py index ae784d7c7..ce26ef29e 100644 --- a/xclim/sdba/adjustment.py +++ b/xclim/sdba/adjustment.py @@ -1162,6 +1162,7 @@ def _train( ref: xr.DataArray, hist: xr.DataArray, *, + scen: xr.DataArray | None = None, base_kws: dict[str, Any] | None = None, adj_kws: dict[str, Any] | None = None, n_escore: int = -1, @@ -1201,6 +1202,8 @@ def _train( # prepare input dataset ds = xr.Dataset(dict(ref=ref, hist=hist)) # , hist_npdf=hist_npdf)) + if scen is not None: + ds["scen"] = scen # train out = npdf_train(