Skip to content

Commit

Permalink
fix for potential empty bins in multi-dimensional features
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixWick committed Oct 28, 2023
1 parent 8f5d039 commit 350caf4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 16 deletions.
17 changes: 12 additions & 5 deletions cyclic_boosting/generic_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,18 @@ def calc_parameters(
"""
sorting = feature.lex_binned_data.argsort()
sorted_bins = feature.lex_binned_data[sorting]
splits_indices = np.unique(sorted_bins, return_index=True)[1][1:]
bins, split_indices = np.unique(sorted_bins, return_index=True)
split_indices = split_indices[1:]

y_pred = np.hstack((y[..., np.newaxis], self.unlink_func(pred.predict_link())[..., np.newaxis]))
y_pred = np.hstack((y_pred, self.weights[..., np.newaxis]))
y_pred_bins = np.split(y_pred[sorting], splits_indices)
y_pred_bins = np.split(y_pred[sorting], split_indices)

# keep potential empty bins in multi-dimensional features
all_bins = range(max(feature.lex_binned_data) + 1)
empty_bins = list(set(bins) ^ set(all_bins))
for i in empty_bins:
y_pred_bins.insert(i, np.zeros((0, 3)))

n_bins = len(y_pred_bins)
parameters = np.zeros(n_bins)
Expand Down Expand Up @@ -380,14 +387,14 @@ def quantile_costs(prediction: np.ndarray, y: np.ndarray, weights: np.ndarray, q
float
calculated quantile costs
"""
if not len(y) > 0:
raise ValueError("Loss cannot be computed on empty data")
else:
if len(y) > 0:
sum_weighted_error = np.nansum(
((y < prediction) * (1 - quantile) * (prediction - y) + (y >= prediction) * quantile * (y - prediction))
* weights
)
return sum_weighted_error / np.nansum(weights)
else:
return 0


def quantile_global_scale(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cyclic-boosting"
version = "1.2.0"
version = "1.2.1"
description = "Implementation of Cyclic Boosting machine learning algorithms"
authors = ["Blue Yonder GmbH"]
packages = [{include = "cyclic_boosting"}]
Expand Down
20 changes: 10 additions & 10 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,9 @@ def test_multiplicative_quantile_regression_90(is_plot, prepare_data, features,
def test_multiplicative_quantile_regression_pdf_J_QPD_S(is_plot, prepare_data, features, feature_properties):
X, y = prepare_data

# empty bin check
X["P_ID"].iloc[1] = 20

quantiles = []
quantile_values = []
for quantile in [0.2, 0.5, 0.8]:
Expand All @@ -513,23 +516,20 @@ def test_multiplicative_quantile_regression_pdf_J_QPD_S(is_plot, prepare_data, f
np.testing.assert_almost_equal(j_qpd_s.ppf(0.2), quantile_values[0, i], 3)
np.testing.assert_almost_equal(j_qpd_s.ppf(0.5), quantile_values[1, i], 3)
np.testing.assert_almost_equal(j_qpd_s.ppf(0.8), quantile_values[2, i], 3)
if i == 24:
np.testing.assert_almost_equal(j_qpd_s.ppf(0.1), 0.457, 3)
np.testing.assert_almost_equal(j_qpd_s.ppf(0.9), 5.509, 3)

if is_plot:
if is_plot:
cdf_truth = smear_discrete_cdftruth(j_qpd_s.cdf, y[i])
cdf_truth_list.append(cdf_truth)

if i == 24:
plt.plot([0.2, 0.5, 0.8], [quantile_values[0, i], quantile_values[1, i], quantile_values[2, i]], "ro")
xs = np.linspace(0.0, 1.0, 100)
plt.plot(xs, j_qpd_s.ppf(xs))
plt.savefig("J_QPD_S_integration_" + str(i) + ".png")
plt.clf()

if is_plot:
cdf_truth = smear_discrete_cdftruth(j_qpd_s.cdf, y[i])
cdf_truth_list.append(cdf_truth)

cdf_truth = np.asarray(cdf_truth_list)
if is_plot:
cdf_truth = np.asarray(cdf_truth_list)
plt.hist(cdf_truth[cdf_truth > 0], bins=30)
plt.savefig("J_QPD_S_cdf_truth_histo.png")
plt.clf()
Expand All @@ -555,7 +555,7 @@ def test_multiplicative_quantile_regression_spline(is_plot, prepare_data, featur

i = 24
spl_fit = quantile_fit_spline(quantiles, quantile_values[:, i])
np.testing.assert_almost_equal(spl_fit(0.2), 0.527, 3)
np.testing.assert_almost_equal(spl_fit(0.2), 0.529, 3)
np.testing.assert_almost_equal(spl_fit(0.5), 2.193, 3)
np.testing.assert_almost_equal(spl_fit(0.8), 4.21, 3)

Expand Down

0 comments on commit 350caf4

Please sign in to comment.