Skip to content

Commit

Permalink
add loc estimates in _fit_start for 'gamma' and 'fisk'
Browse files Browse the repository at this point in the history
  • Loading branch information
coxipi committed Mar 22, 2024
1 parent b989dbf commit 52bff86
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 22 deletions.
10 changes: 8 additions & 2 deletions xclim/indices/_agro.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,7 @@ def standardized_precipitation_index(
window: int = 1,
dist: str = "gamma",
method: str = "APP",
fitkwargs: dict = {},
cal_start: DateStr | None = None,
cal_end: DateStr | None = None,
params: Quantified | None = None,
Expand All @@ -1120,6 +1121,8 @@ def standardized_precipitation_index(
method : {'APP', 'ML'}
Name of the fitting method, such as `ML` (maximum likelihood), `APP` (approximate). The approximate method
uses a deterministic function that doesn't involve any optimization.
fitkwargs : dict
Kwargs passed to ``xclim.indices.stats.fit`` used to impose values of certain parameters (`floc`, `fscale`).
cal_start : DateStr, optional
Start date of the calibration period. A `DateStr` is expected, that is a `str` in format `"YYYY-MM-DD"`.
Default option `None` means that the calibration period begins at the start of the input dataset.
Expand Down Expand Up @@ -1189,7 +1192,7 @@ def standardized_precipitation_index(
else:
raise NotImplementedError(f"{dist} distribution is not implemented yet")
spi = standardized_index(
pr, freq, window, dist, method, cal_start, cal_end, params, **indexer
pr, freq, window, dist, method, fitkwargs, cal_start, cal_end, params, **indexer
)
return spi

Expand All @@ -1205,6 +1208,7 @@ def standardized_precipitation_evapotranspiration_index(
window: int = 1,
dist: str = "gamma",
method: str = "APP",
fitkwargs: dict = {},
offset: Quantified = "",
cal_start: DateStr | None = None,
cal_end: DateStr | None = None,
Expand Down Expand Up @@ -1234,6 +1238,8 @@ def standardized_precipitation_evapotranspiration_index(
`PWM` (probability weighted moments).
The approximate method uses a deterministic function that doesn't involve any optimization. Available methods
vary with the distribution: 'gamma':{'APP', 'ML', 'PWM'}, 'fisk':{'APP', 'ML'}
fitkwargs : dict
Kwargs passed to ``xclim.indices.stats.fit`` used to impose values of certain parameters (`floc`, `fscale`).
offset : Quantified
For distributions bounded by zero (e.g. "gamma", "fisk"), the two-parameters distributions only accept positive
values. An offset can be added to make sure this is the case. This option will be removed in xclim >=0.49.0, ``xclim``
Expand Down Expand Up @@ -1302,7 +1308,7 @@ def standardized_precipitation_evapotranspiration_index(
else:
raise NotImplementedError(f"{dist} distribution is not implemented yet")
spei = standardized_index(
wb, freq, window, dist, method, cal_start, cal_end, params, **indexer
wb, freq, window, dist, method, fitkwargs, cal_start, cal_end, params, **indexer
)

return spei
Expand Down
73 changes: 53 additions & 20 deletions xclim/indices/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,26 +490,36 @@ def _fit_start(x, dist: str, **fitkwargs: Any) -> tuple[tuple, dict]:
return (chat,), {"loc": loc, "scale": scale}

if dist in ["gamma"]:
# not sure if the approximation holds for xmin < 0
xmin = x.min()
x_pos = x - (xmin if xmin <= 0 else 0)
x_pos = x_pos[x_pos > 0]
if "floc" in fitkwargs:
loc0 = fitkwargs["floc"]
else:
xs = sorted(x)
x1, x2, xn = xs[0], xs[1], xs[-1]
n = len(x)
cv = x.std() / x.mean()
p = 100 * (0.48265 + 0.32967 * cv) * n ** (-0.2984 * cv)
xp = np.percentile(x, p)
loc0 = (x1 * xn - xp**2) / (x1 + xn - 2 * xp)
loc0 = loc0 if loc0 < x1 else 0.9999 * x1
x_pos = x - loc0
m = x_pos.mean()
log_of_mean = np.log(m)
mean_of_logs = np.log(x_pos).mean()
a = log_of_mean - mean_of_logs
alpha = (1 + np.sqrt(1 + 4 * a / 3)) / (4 * a)
beta = m / alpha
kwargs = {"scale": beta}
if xmin < 0:
kwargs["loc"] = xmin
return (alpha,), kwargs
A = log_of_mean - mean_of_logs
a0 = (1 + np.sqrt(1 + 4 * A / 3)) / (4 * A)
scale0 = m / a0
kwargs = {"scale": scale0, "loc": loc0}
return (a0,), kwargs

if dist in ["fisk"]:
# not sure if the approximation holds for xmin < 0
xmin = x.min()
x_pos = x - (xmin if xmin <= 0 else 0)
x_pos = x_pos[x_pos > 0]
if "floc" in fitkwargs:
loc0 = fitkwargs["floc"]
else:
xs = sorted(x)
x1, x2, xn = xs[0], xs[1], xs[-1]
loc0 = (x1 * xn - x2**2) / (x1 + xn - 2 * x2)
loc0 = loc0 if loc0 < x1 else 0.9999 * x1
x_pos = x - loc0
m = x_pos.mean()
m2 = (x_pos**2).mean()
# method of moments:
Expand All @@ -519,10 +529,10 @@ def _fit_start(x, dist: str, **fitkwargs: Any) -> tuple[tuple, dict]:
# <x> = m
# <x^2> / <x>^2 = m2/m**2
# solving these equations yields
alpha = 2 * m**3 / (m2 + m**2)
beta = np.pi * m / np.sqrt(3) / np.sqrt(m2 - m**2)
kwargs = {"scale": alpha}
return (beta,), kwargs
scale0 = 2 * m**3 / (m2 + m**2)
c0 = np.pi * m / np.sqrt(3) / np.sqrt(m2 - m**2)
kwargs = {"scale": scale0, "loc": loc0}
return (c0,), kwargs
return (), {}


Expand Down Expand Up @@ -673,6 +683,7 @@ def standardized_index_fit_params(
window: int,
dist: str | scipy.stats.rv_continuous,
method: str,
fitkwargs: dict = {},
offset: Quantified | None = None,
**indexer,
) -> xr.DataArray:
Expand All @@ -698,6 +709,8 @@ def standardized_index_fit_params(
method : {'ML', 'APP', 'PWM'}
Name of the fitting method, such as `ML` (maximum likelihood), `APP` (approximate). The approximate method
uses a deterministic function that doesn't involve any optimization.
fitkwargs : dict
Kwargs passed to ``xclim.indices.stats.fit`` used to impose values of certain parameters (`floc`, `fscale`).
\*\*indexer
Indexing parameters to compute the indicator on a temporal subset of the data.
It accepts the same arguments as :py:func:`xclim.indices.generic.select_time`.
Expand Down Expand Up @@ -729,7 +742,24 @@ def standardized_index_fit_params(
f"The method `{method}` is not supported for distribution `{dist.name}`."
)
da, group = preprocess_standardized_index(da, freq, window, **indexer)
params = da.groupby(group).map(fit, dist=dist, method=method)

# convert floc units if needed
# this should be passed to scipy eventually, so it should not support strings
# with units. Perhaps using `fitkwargs` in this context is misleading?
for fpar in ["floc", "fscale"]:
if fpar in fitkwargs.keys():
if np.isscalar(fitkwargs[fpar]) is False:
fitkwargs[fpar] = convert_units_to(fitkwargs[fpar], da, context="hydro")

# Use zero inflated distributions
# Idea: Perhaps having specific zero-inflated distributions or an option passed
# to the standardized index (SPI/SPEI) to state we want to use zero-inflated
# would be a better way to organize this
if da.min() == 0 and dist.name in ["gamma", "fisk"]:
da0 = da.where(da > 0).copy()
else:
da0 = da
params = da0.groupby(group).map(fit, dist=dist, method=method, **fitkwargs)
cal_range = (
da.time.min().dt.strftime("%Y-%m-%d").item(),
da.time.max().dt.strftime("%Y-%m-%d").item(),
Expand Down Expand Up @@ -757,6 +787,7 @@ def standardized_index(
window: int,
dist: str | scipy.stats.rv_continuous | None,
method: str,
fitkwargs: dict,
cal_start: DateStr | None,
cal_end: DateStr | None,
params: Quantified | None,
Expand Down Expand Up @@ -837,6 +868,7 @@ def standardized_index(
window=1,
dist=dist,
method=method,
fitkwargs=fitkwargs,
)

# If params only contains a subset of main dataset time grouping
Expand All @@ -860,6 +892,7 @@ def reindex_time(da, da_ref):
lambda x: (x == 0).sum("time") / x.notnull().sum("time")
)
params, probs_of_zero = (reindex_time(dax, da) for dax in [params, probs_of_zero])
# should `da` below exclude zeros?
dist_probs = dist_method("cdf", params, da, dist=dist)
probs = probs_of_zero + ((1 - probs_of_zero) * dist_probs)
params_norm = xr.DataArray(
Expand Down

0 comments on commit 52bff86

Please sign in to comment.