Skip to content

Commit

Permalink
more comments and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
coxipi committed Jan 12, 2024
1 parent ffba740 commit f672e1b
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 25 deletions.
66 changes: 43 additions & 23 deletions xclim/sdba/_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,13 @@ def _npdft_train(ref, hist, rots, quantiles, method, extrap):
)
hist[iv] = u.apply_correction(hist[iv], af, "+")
hist = rots[-1].T @ hist
return af_q, hist
return af_q


def npdf_train(
ds,
rot_matrices,
pts_dims,
quantiles,
g_idxs,
gw_idxs,
Expand All @@ -151,9 +152,9 @@ def npdf_train(
hist : training data
n_iter : int
The number of iterations to perform. Defaults to 20.
pts_dim : str
The name of the "multivariate" dimension. Defaults to "multivar", which is the
normal case when using :py:func:`xclim.sdba.base.stack_variables`.
pts_dims : str
The name of the "multivariate" dimension and its primed counterpart. Defaults to "multivar"
and "multivar_prime", which is the normal case when using :py:func:`xclim.sdba.base.stack_variables`.
rot_matrices : xr.DataArray, optional
The rotation matrices as a 3D array ('iterations', <pts_dim>, <anything>), with shape (n_iter, <N>, <N>).
If left empty, random rotation matrices will be automatically generated.
Expand All @@ -167,39 +168,42 @@ def npdf_train(
# unload data
ref = ds.ref
hist = ds.hist
# npdf core
gr_dim = gw_idxs.attrs["group_dim"]

# npdf training core
af_q_l = []
for ib in range(gw_idxs[gr_dim].size):
# indices in a given time block
indices = gw_idxs[{gr_dim: ib}].astype(int).values
ind = indices[indices >= 0]
af_q, scen_npdft = xr.apply_ufunc(

# npdft training : multiple rotations on standardized datasets
# keep track of adjustment factors in each rotation for later use
af_q = xr.apply_ufunc(
_npdft_train,
standardize(ref[{"time": ind}])[0].copy(),
standardize(hist[{"time": ind}])[0].copy(),
rot_matrices,
quantiles,
input_core_dims=[
["multivar", "time"],
["multivar", "time"],
["iterations", "multivar_prime", "multivar"],
[pts_dims[0], "time"],
[pts_dims[0], "time"],
["iterations", pts_dims[1], pts_dims[0]],
["quantiles"],
],
output_core_dims=[
["iterations", "multivar_prime", "quantiles"],
["multivar", "time"],
["iterations", pts_dims[1], "quantiles"],
],
dask="parallelized",
output_dtypes=[hist.dtype, hist.dtype],
output_dtypes=[hist.dtype],
kwargs={"method": method, "extrap": extrap},
vectorize=True,
)
af_q_l.append(af_q.expand_dims({gr_dim: [ib]}))
af_q = xr.concat(af_q_l, dim=gr_dim).assign_coords(
{"quantiles": quantiles, gr_dim: gw_idxs[gr_dim].values}
)
out = xr.Dataset(dict(af_q=af_q))
return out
return af_q.to_dataset(name="af_q")


def _npdft_adjust(sim, af_q, rots, quantiles, method, extrap):
Expand Down Expand Up @@ -236,6 +240,7 @@ def npdf_adjust(
hist,
sim,
ds,
pts_dims,
method,
extrap,
base_kws_scen,
Expand Down Expand Up @@ -263,46 +268,61 @@ def npdf_adjust(
gw_idxs = ds.gw_idxs
gr_dim = gw_idxs.attrs["group_dim"]
win = gw_idxs.attrs["group"][1]

# this way of handling was letting open the possibility to perform
# interpolation for multiple periods in the simulation all at once
# in principle, avoiding redundancy. Need to test this on small data
# to confirm it works, and on big data to check performance.
dims = ["time"] if period_dim is None else [period_dim, "time"]

# arguments for univariate adjustment (step 1)
base = base_kws_scen["base"]
kinds = base_kws_scen["kinds"]
base_kws_scen.pop("base")
base_kws_scen.pop("kinds")
scen_mbcn = sim.copy()

# mbcn core
scen_mbcn = xr.zeros_like(sim)
for ib in range(gw_idxs[gr_dim].size):
# indices in a given time block (with and without the window)
indices_gw = gw_idxs[{gr_dim: ib}].astype(int).values
ind_gw = indices_gw[indices_gw >= 0]
indices_g = g_idxs[{gr_dim: ib}].astype(int).values
ind_g = indices_g[indices_g >= 0]

# 1. univariate adjustment of sim -> scen
# the kind may be differ depending on the variables
scen_block = xr.zeros_like(sim[{"time": ind_gw}])
for iv, v in enumerate(sim.multivar.values):
sl = {"time": ind_gw, "multivar": iv}
for iv, v in enumerate(sim[pts_dims[0]].values):
sl = {"time": ind_gw, pts_dims[0]: iv}
ADJ = base.train(ref[sl], hist[sl], kind=kinds[v], **base_kws_scen)
scen_block[{"multivar": iv}] = ADJ.adjust(sim[sl], **adj_kws_scen).scen
scen_block[{pts_dims[0]: iv}] = ADJ.adjust(sim[sl], **adj_kws_scen).scen

# 2. npdft adjustment of sim
npdft_block = xr.apply_ufunc(
_npdft_adjust,
standardize(sim[{"time": ind_gw}].copy(), dim="time")[0],
af_q[{gr_dim: ib}],
rot_matrices,
quantiles,
input_core_dims=[
["multivar"] + dims,
["iterations", "multivar_prime", "quantiles"],
["iterations", "multivar_prime", "multivar"],
[pts_dims[0]] + dims,
["iterations", pts_dims[1], "quantiles"],
["iterations", pts_dims[1], pts_dims[0]],
["quantiles"],
],
output_core_dims=[
["multivar"] + dims,
[pts_dims[0]] + dims,
],
dask="parallelized",
output_dtypes=[sim.dtype],
kwargs={"method": method, "extrap": extrap},
vectorize=True,
)
# reorder
# 3. reorder scen according to npdft results
reordered = reordering(ref=npdft_block, sim=scen_block)
if win > 1:
# keep central value of window (intersecting indices in gw_idxs and g_idxs)
scen_mbcn[{"time": ind_g}] = reordered[{"time": np.in1d(ind_gw, ind_g)}]
else:
scen_mbcn[{"time": ind_g}] = reordered
Expand Down
14 changes: 12 additions & 2 deletions xclim/sdba/adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,7 @@ def _train(
rot_dim = xr.core.utils.get_temp_dimname(
set(ref.dims).union(hist.dims), pts_dim + "_prime"
)
pts_dims = (pts_dim, rot_dim)
if rot_matrices is None:
rot_matrices = rand_rot_matrix(
ref[pts_dim], num=n_iter, new_dim=rot_dim
Expand All @@ -1206,6 +1207,7 @@ def _train(
out = npdf_train(
ds,
rot_matrices=rot_matrices,
pts_dims=pts_dims,
quantiles=quantiles,
g_idxs=g_idxs,
gw_idxs=gw_idxs,
Expand All @@ -1223,7 +1225,12 @@ def _train(
standard_name="Adjustment factors",
long_name="Quantile mapping adjustment factors",
)
return out, {"group": group, "interp": interp, "extrapolation": extrapolation}
return out, {
"group": group,
"interp": interp,
"extrapolation": extrapolation,
"pts_dims": pts_dims,
}

def _adjust(
self,
Expand All @@ -1240,7 +1247,9 @@ def _adjust(
base_kws_scen.setdefault("nquantiles", self.ds.af_q.quantiles.values)
base_kws_scen.setdefault("group", self.group)
# change this hardcoding multivar
base_kws_scen.setdefault("kinds", {v: "+" for v in sim["multivar"].values})
base_kws_scen.setdefault(
"kinds", {v: "+" for v in sim[self.pts_dims[0]].values}
)
base_kws_scen.setdefault("base", QuantileDeltaMapping)

if np.isscalar(base_kws_scen["nquantiles"]):
Expand All @@ -1266,6 +1275,7 @@ def _adjust(
hist,
sim,
self.ds,
self.pts_dims,
self.interp,
self.extrapolation,
base_kws_scen,
Expand Down

0 comments on commit f672e1b

Please sign in to comment.