Skip to content

Commit

Permalink
Merge pull request #54 from Blue-Yonder-OSS/quantile_matching_qpd
Browse files Browse the repository at this point in the history
quantile-parameterized distributions
  • Loading branch information
FelixWick authored Oct 19, 2023
2 parents bb8a3bc + 31b2fb6 commit ab2d958
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 8 deletions.
166 changes: 159 additions & 7 deletions cyclic_boosting/quantile_matching.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,163 @@
import numpy as np
from numpy import exp, log, sinh, arcsinh, arccosh
from scipy.optimize import curve_fit
from scipy.stats import norm, gamma, nbinom
from scipy.stats import norm, gamma, nbinom, logistic
from scipy.interpolate import InterpolatedUnivariateSpline

from typing import Optional


class J_QPD_S:
"""
Implementation of the semi-bounded mode of Johnson Quantile-Parameterized
Distributions (J-QPD), see https://repositories.lib.utexas.edu/bitstream/handle/2152/63037/HADLOCK-DISSERTATION-2017.pdf
(Due to the Python keyword, the parameter lambda from this reference is named kappa below.).
A distribution is parameterized by a symmetric-percentile triplet (SPT).
Parameters
----------
alpha : float
lower quantile of SPT (upper is ``1 - alpha``)
qv_low : float
quantile function value of ``alpha``
qv_median : float
quantile function value of quantile 0.5
qv_high : float
quantile function value of quantile ``1 - alpha``
l : float
lower bound of semi-bounded range (default is 0)
version: str
options are ``normal`` (default) or ``logistic``
"""

def __init__(
self,
alpha: float,
qv_low: float,
qv_median: float,
qv_high: float,
l: Optional[float] = 0,
version: Optional[str] = "normal",
):
if version == "normal":
self.phi = norm()
elif version == "logistic":
self.phi = logistic()
else:
raise Exception("Invalid version.")

self.l = l

self.c = self.phi.ppf(1 - alpha)

self.L = log(qv_low - l)
self.H = log(qv_high - l)
self.B = log(qv_median - l)

if self.L + self.H - 2 * self.B > 0:
self.n = 1
self.theta = qv_low - l
elif self.L + self.H - 2 * self.B < 0:
self.n = -1
self.theta = qv_high - l
else:
self.n = 0
self.theta = qv_median - l

self.delta = 1.0 / self.c * sinh(arccosh((self.H - self.L) / (2 * min(self.B - self.L, self.H - self.B))))

self.kappa = 1.0 / (self.delta * self.c) * min(self.H - self.B, self.B - self.L)

def ppf(self, x):
return self.l + self.theta * exp(
self.kappa * sinh(arcsinh(self.delta * self.phi.ppf(x)) + arcsinh(self.n * self.c * self.delta))
)

def cdf(self, x):
return self.phi.cdf(
1.0
/ self.delta
* sinh(arcsinh(1.0 / self.kappa * log((x - self.l) / self.theta)) - arcsinh(self.n * self.c * self.delta))
)


class J_QPD_B:
"""
Implementation of the bounded mode of Johnson Quantile-Parameterized
Distributions (J-QPD), see https://repositories.lib.utexas.edu/bitstream/handle/2152/63037/HADLOCK-DISSERTATION-2017.pdf.
(Due to the Python keyword, the parameter lambda from this reference is named kappa below.)
A distribution is parameterized by a symmetric-percentile triplet (SPT).
Parameters
----------
alpha : float
lower quantile of SPT (upper is ``1 - alpha``)
qv_low : float
quantile function value of ``alpha``
qv_median : float
quantile function value of quantile 0.5
qv_high : float
quantile function value of quantile ``1 - alpha``
l : float
lower bound of supported range
u : float
upper bound of supported range
version: str
options are ``normal`` (default) or ``logistic``
"""

def __init__(
self,
alpha: float,
qv_low: float,
qv_median: float,
qv_high: float,
l: float,
u: float,
version: Optional[str] = "normal",
):
if version == "normal":
self.phi = norm()
elif version == "logistic":
self.phi = logistic()
else:
raise Exception("Invalid version.")

self.l = l
self.u = u

self.c = self.phi.ppf(1 - alpha)

self.L = self.phi.ppf((qv_low - l) / (u - l))
self.H = self.phi.ppf((qv_high - l) / (u - l))
self.B = self.phi.ppf((qv_median - l) / (u - l))

if self.L + self.H - 2 * self.B > 0:
self.n = 1
self.xi = self.L
elif self.L + self.H - 2 * self.B < 0:
self.n = -1
self.xi = self.H
else:
self.n = 0
self.xi = self.B

self.delta = 1.0 / self.c * arccosh((self.H - self.L) / (2 * min(self.B - self.L, self.H - self.B)))

self.kappa = (self.H - self.L) / sinh(2 * self.delta * self.c)

def ppf(self, x):
return self.l + (self.u - self.l) * self.phi.cdf(
self.xi + self.kappa * sinh(self.delta * (self.phi.ppf(x) + self.n * self.c))
)

def cdf(self, x):
return self.phi.cdf(
1.0 / self.delta * arcsinh(1.0 / self.kappa * (self.phi.ppf((x - self.l) / (self.u - self.l)) - self.xi))
- self.n * self.c
)


def quantile_fit_gaussian(quantiles: np.ndarray, quantile_values: np.ndarray, mode: Optional[str] = "ppf") -> callable:
"""
Interpolation of a quantile function (with quantiles estimated, e.g., by
Expand All @@ -21,8 +173,8 @@ def quantile_fit_gaussian(quantiles: np.ndarray, quantile_values: np.ndarray, mo
mode : str
decides about kind of returned callable, possible values are:
- ``ppf``: quantile (default)
- ``dist``: fitted negative binomial (scipy function)
- ``ppf``: quantile function (default)
- ``dist``: fitted Gaussian function (scipy function)
- ``cdf``: CDF function
Returns
Expand Down Expand Up @@ -60,8 +212,8 @@ def quantile_fit_gamma(quantiles: np.ndarray, quantile_values: np.ndarray, mode:
mode : str
decides about kind of returned callable, possible values are:
- ``ppf``: quantile (default)
- ``dist``: fitted negative binomial (scipy function)
- ``ppf``: quantile function (default)
- ``dist``: fitted Gamma function (scipy function)
- ``cdf``: CDF function
Returns
Expand Down Expand Up @@ -123,8 +275,8 @@ def quantile_fit_nbinom(quantiles: np.ndarray, quantile_values: np.ndarray, mode
mode : str
decides about kind of returned callable, possible values are:
- ``ppf``: quantile (default)
- ``dist``: fitted negative binomial (scipy function)
- ``ppf``: quantile function (default)
- ``dist``: fitted negative binomial function (scipy function)
- ``cdf``: CDF function
Returns
Expand Down
26 changes: 26 additions & 0 deletions docs/source/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,29 @@ the individual predictions of a given data set:
```python
CB_est.get_feature_contributions(X_test)
```


## Quantile Regression
Below you can find an example of a quantile regression model for three
different quantiles, with a subsequent quantile matching (to get a full
individual probability distribution from the estimated quantiles) by means of a
quantile-parameterized distribution for an arbitrary test sample:
```python
from cyclic_boosting.pipelines import pipeline_CBMultiplicativeQuantileRegressor
from cyclic_boosting.quantile_matching import J_QPD_S

CB_est_qlow = pipeline_CBMultiplicativeQuantileRegressor(quantile=0.2)
CB_est_qlow.fit(X_train.copy(), y)
yhat_qlow = CB_est_qlow.predict(X_test.copy())

CB_est_qmedian = pipeline_CBMultiplicativeQuantileRegressor(quantile=0.5)
CB_est_qmedian.fit(X_train.copy(), y)
yhat_qmedian = CB_est_qmedian.predict(X_test.copy())

CB_est_qhigh = pipeline_CBMultiplicativeQuantileRegressor(quantile=0.8)
CB_est_qhigh.fit(X_train.copy(), y)
yhat_qhigh = CB_est_qhigh.predict(X_test.copy())

j_qpd_s_42 = J_QPD_S(0.2, yhat_qlow[42], yhat_qmedian[42], yhat_qhigh[42])
yhat_42_percentile95 = j_qpd_s_42.ppf(0.95)
```
49 changes: 48 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
pipeline_CBAdditiveGenericCRegressor,
pipeline_CBGenericClassifier,
)
from cyclic_boosting.quantile_matching import quantile_fit_gamma, quantile_fit_nbinom, quantile_fit_spline
from cyclic_boosting.quantile_matching import quantile_fit_gamma, quantile_fit_nbinom, quantile_fit_spline, J_QPD_S
from cyclic_boosting.utils import smear_discrete_cdftruth
from tests.utils import plot_CB, costs_mad, costs_mse

Expand Down Expand Up @@ -486,6 +486,53 @@ def test_multiplicative_quantile_regression_90(is_plot, prepare_data, features,
np.testing.assert_almost_equal(quantile_acc, 0.9015, 3)


@pytest.mark.skip(reason="Long running time")
def test_multiplicative_quantile_regression_pdf_J_QPD_S(is_plot, prepare_data, features, feature_properties):
X, y = prepare_data

quantiles = []
quantile_values = []
for quantile in [0.2, 0.5, 0.8]:
CB_est = cb_multiplicative_quantile_regressor_model(
quantile=quantile, features=features, feature_properties=feature_properties
)
CB_est.fit(X.copy(), y)
yhat = CB_est.predict(X.copy())
quantile_values.append(yhat)
quantiles.append(quantile)

quantiles = np.asarray(quantiles)
quantile_values = np.asarray(quantile_values)

cdf_truth_list = []
n_samples = len(X)
for i in range(n_samples):
j_qpd_s = J_QPD_S(0.2, quantile_values[0, i], quantile_values[1, i], quantile_values[2, i])
np.testing.assert_almost_equal(j_qpd_s.ppf(0.2), quantile_values[0, i], 3)
np.testing.assert_almost_equal(j_qpd_s.ppf(0.5), quantile_values[1, i], 3)
np.testing.assert_almost_equal(j_qpd_s.ppf(0.8), quantile_values[2, i], 3)
if i == 24:
np.testing.assert_almost_equal(j_qpd_s.ppf(0.1), 0.592, 3)
np.testing.assert_almost_equal(j_qpd_s.ppf(0.9), 5.783, 3)

if is_plot:
plt.plot([0.2, 0.5, 0.8], [quantile_values[0, i], quantile_values[1, i], quantile_values[2, i]], "ro")
xs = np.linspace(0.0, 1.0, 100)
plt.plot(xs, j_qpd_s.ppf(xs))
plt.savefig("J_QPD_S_integration_" + str(i) + ".png")
plt.clf()

if is_plot:
cdf_truth = smear_discrete_cdftruth(j_qpd_s.cdf, y[i])
cdf_truth_list.append(cdf_truth)

cdf_truth = np.asarray(cdf_truth_list)
if is_plot:
plt.hist(cdf_truth[cdf_truth > 0], bins=30)
plt.savefig("J_QPD_S_cdf_truth_histo.png")
plt.clf()


@pytest.mark.skip(reason="Long running time")
def test_multiplicative_quantile_regression_spline(is_plot, prepare_data, features, feature_properties):
X, y = prepare_data
Expand Down
Loading

0 comments on commit ab2d958

Please sign in to comment.