Skip to content

Commit

Permalink
Test OOB score monotonicity
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 14, 2024
1 parent 02fe264 commit b1f50a1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion quantile_forest/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def predict(
train_indices = np.zeros(len(X), dtype=int)
train_indices[indices] = y_train_leaves[tree, leaves, output, 0]
if self.monotonic_cst is not None:
clip_min, clip_max = np.zeros(len(X)), np.zeros(len(X))
clip_min, clip_max = np.full(len(X), -np.inf), np.full(len(X), np.inf)
clip_min[indices] = y_bound_leaves[tree, leaves, 0]
clip_max[indices] = y_bound_leaves[tree, leaves, 1]
leaf_values[:, tree] = y_train[train_indices - 1, output]
Expand Down
25 changes: 15 additions & 10 deletions quantile_forest/tests/test_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,7 @@ def check_monotonic_constraints(name, max_samples_leaf):
X_train = X[train]
y_train = y[train]
X_test = np.copy(X[test])
y_test = np.copy(y[test])
X_test_incr = np.copy(X_test)
X_test_decr = np.copy(X_test)
X_test_incr[:, 0] += 10
Expand All @@ -1283,23 +1284,27 @@ def check_monotonic_constraints(name, max_samples_leaf):
monotonic_cst[1] = -1

est = ForestRegressor(
n_estimators=5,
max_depth=8,
max_samples_leaf=max_samples_leaf,
monotonic_cst=monotonic_cst,
max_leaf_nodes=n_samples_train,
bootstrap=True,
)

est.fit(X_train, y_train)
y = est.predict(X_test)
for oob_score in [True]:
if not oob_score:
est.fit(X_train, y_train)

Check warning on line 1295 in quantile_forest/tests/test_quantile_forest.py

View check run for this annotation

Codecov / codecov/patch

quantile_forest/tests/test_quantile_forest.py#L1295

Added line #L1295 was not covered by tests
else:
est.fit(X_test, y_test)

y = est.predict(X_test, oob_score=oob_score)

# Check the monotonic increase constraint.
y_incr = est.predict(X_test_incr)
assert np.all(y_incr >= y)
# Check the monotonic increase constraint.
y_incr = est.predict(X_test_incr, oob_score=oob_score)
assert np.all(y_incr >= y)

# Check the monotonic decrease constraint.
y_decr = est.predict(X_test_decr)
assert np.all(y_decr <= y)
# Check the monotonic decrease constraint.
y_decr = est.predict(X_test_decr, oob_score=oob_score)
assert np.all(y_decr <= y)


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
Expand Down

0 comments on commit b1f50a1

Please sign in to comment.