Skip to content

Commit

Permalink
fix _npdf_train, mbcn in train_adjust
Browse files Browse the repository at this point in the history
  • Loading branch information
coxipi committed Jan 12, 2024
1 parent 1d7a50a commit cea04a3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
31 changes: 24 additions & 7 deletions xclim/sdba/_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def _npdft_train(ref, hist, rots, quantiles, method, extrap):
extrap=extrap,
)
hist[iv] = u.apply_correction(hist[iv], af, "+")
return af_q
hist = rots[-1].T @ hist
return af_q, hist


def npdf_train(
Expand Down Expand Up @@ -169,14 +170,18 @@ def npdf_train(
# npdf core
gr_dim = gw_idxs.attrs["group_dim"]
af_q_l = []
print(gw_idxs[gr_dim].size)
win = gw_idxs.attrs["group"][1]
adjust_scen = "scen" in ds.data_vars
if adjust_scen:
scen = ds.scen
scen_mbcn = scen.copy()
for ib in range(gw_idxs[gr_dim].size):
indices = gw_idxs[{gr_dim: ib}].astype(int).values
ind = indices[indices >= 0]
af_q = xr.apply_ufunc(
af_q, scen_npdft = xr.apply_ufunc(
_npdft_train,
standardize(ref[{"time": ind}].copy(), dim="time")[0],
standardize(hist[{"time": ind}].copy(), dim="time")[0],
standardize(ref[{"time": ind}])[0].copy(),
standardize(hist[{"time": ind}])[0].copy(),
rot_matrices,
quantiles,
input_core_dims=[
Expand All @@ -187,17 +192,29 @@ def npdf_train(
],
output_core_dims=[
["iterations", "multivar_prime", "quantiles"],
["multivar", "time"],
],
dask="parallelized",
output_dtypes=[hist.dtype],
output_dtypes=[hist.dtype, hist.dtype],
kwargs={"method": method, "extrap": extrap},
vectorize=True,
)
af_q_l.append(af_q.expand_dims({gr_dim: [ib]}))
if adjust_scen:
reordered = reordering(ref=scen_npdft, sim=scen[{"time": ind}].copy())
if win > 1:
indices_g = g_idxs[{gr_dim: ib}].astype(int).values
ind_g = indices_g[indices_g >= 0]
scen_mbcn[{"time": ind_g}] = reordered[{"time": np.in1d(ind, ind_g)}]
else:
scen_mbcn[{"time": ind}] = reordered
af_q = xr.concat(af_q_l, dim=gr_dim).assign_coords(
{"quantiles": quantiles, gr_dim: gw_idxs[gr_dim].values}
)
return xr.Dataset(dict(af_q=af_q))
out = xr.Dataset(dict(af_q=af_q))
if adjust_scen:
out["scen"] = scen_mbcn
return out


def _npdft_adjust(sim, af_q, rots, quantiles, method, extrap):
Expand Down
3 changes: 3 additions & 0 deletions xclim/sdba/adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,7 @@ def _train(
ref: xr.DataArray,
hist: xr.DataArray,
*,
scen: xr.DataArray | None = None,
base_kws: dict[str, Any] | None = None,
adj_kws: dict[str, Any] | None = None,
n_escore: int = -1,
Expand Down Expand Up @@ -1201,6 +1202,8 @@ def _train(

# prepare input dataset
ds = xr.Dataset(dict(ref=ref, hist=hist)) # , hist_npdf=hist_npdf))
if scen is not None:
ds["scen"] = scen

# train
out = npdf_train(
Expand Down

0 comments on commit cea04a3

Please sign in to comment.