Commit fdd2f56

Lint unit tests

reidjohnson committed Sep 14, 2024
1 parent e20cbcd commit fdd2f56
Showing 2 changed files with 37 additions and 36 deletions.
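This lint pass makes two mechanical changes across the test suite: the fitted estimator variable is renamed from regr to est, and the leading "# Check ..." comment in each test function becomes a docstring. A minimal before/after sketch of the pattern (the helper shown here is illustrative, not one of the actual tests):

# Before: leading comment, estimator bound to `regr`.
def check_example(name):
    # Check a property of the estimator.
    regr = FOREST_REGRESSORS[name](n_estimators=10, random_state=0)
    ...

# After: docstring, estimator bound to `est`.
def check_example(name):
    """Check a property of the estimator."""
    est = FOREST_REGRESSORS[name](n_estimators=10, random_state=0)
    ...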
67 changes: 34 additions & 33 deletions quantile_forest/tests/test_quantile_forest.py
@@ -58,17 +58,17 @@ def check_regression_toy(name, weighted_quantile):
 
     ForestRegressor = FOREST_REGRESSORS[name]
 
-    regr = ForestRegressor(n_estimators=10, max_samples_leaf=None, bootstrap=False, random_state=0)
-    regr.fit(X, y)
+    est = ForestRegressor(n_estimators=10, max_samples_leaf=None, bootstrap=False, random_state=0)
+    est.fit(X, y)
 
     # Check model and apply outputs shape.
-    leaf_indices = regr.apply(X)
-    assert leaf_indices.shape == (len(X), regr.n_estimators)
-    assert 10 == len(regr)
+    leaf_indices = est.apply(X)
+    assert leaf_indices.shape == (len(X), est.n_estimators)
+    assert 10 == len(est)
 
     # Check aggregated quantile predictions.
     y_true = [[0.0, 0.5, 1.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
-    y_pred = regr.predict(
+    y_pred = est.predict(
         y_test,
         quantiles=quantiles,
         weighted_quantile=weighted_quantile,
@@ -78,15 +78,15 @@ def check_regression_toy(name, weighted_quantile):
 
     # Check unaggregated quantile predictions.
     y_true = [[0.25, 0.5, 0.75], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
-    y_pred = regr.predict(
+    y_pred = est.predict(
         y_test,
         quantiles=quantiles,
         weighted_quantile=weighted_quantile,
         aggregate_leaves_first=False,
     )
     assert_allclose(y_pred, y_true)
 
-    assert regr._more_tags()
+    assert est._more_tags()
 
 
 @pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -96,25 +96,25 @@ def test_regression_toy(name, weighted_quantile):
 
 
 def check_california_criterion(name, criterion):
-    # Check for consistency on the California Housing Prices dataset.
+    """Check for consistency on the California Housing dataset."""
     ForestRegressor = FOREST_REGRESSORS[name]
 
-    regr = ForestRegressor(n_estimators=5, criterion=criterion, max_features=None, random_state=0)
-    regr.fit(X_california, y_california)
-    score = regr.score(X_california, y_california, quantiles=0.5)
+    est = ForestRegressor(n_estimators=5, criterion=criterion, max_features=None, random_state=0)
+    est.fit(X_california, y_california)
+    score = est.score(X_california, y_california, quantiles=0.5)
     assert score > 0.9, f"Failed with max_features=None, criterion {criterion} and score={score}."
 
     # Test maximum features.
-    regr = ForestRegressor(n_estimators=5, criterion=criterion, max_features=6, random_state=0)
-    regr.fit(X_california, y_california)
-    score = regr.score(X_california, y_california, quantiles=0.5)
+    est = ForestRegressor(n_estimators=5, criterion=criterion, max_features=6, random_state=0)
+    est.fit(X_california, y_california)
+    score = est.score(X_california, y_california, quantiles=0.5)
     assert score > 0.9, f"Failed with max_features=6, criterion {criterion} and score={score}."
 
     # Test sample weights.
-    regr = ForestRegressor(n_estimators=5, criterion=criterion, random_state=0)
+    est = ForestRegressor(n_estimators=5, criterion=criterion, random_state=0)
     sample_weight = np.concatenate([np.zeros(1), np.ones(len(y_california) - 1)])
-    regr.fit(X_california, y_california, sample_weight=sample_weight)
-    score = regr.score(X_california, y_california, quantiles=0.5)
+    est.fit(X_california, y_california, sample_weight=sample_weight)
+    score = est.score(X_california, y_california, quantiles=0.5)
     assert score > 0.9, f"Failed with criterion {criterion}, sample weight and score={score}."
 
 
@@ -125,7 +125,7 @@ def test_california(name, criterion):
 
 
 def check_predict_quantiles_toy(name):
-    # Check quantile predictions on toy data.
+    """Check quantile predictions on toy data."""
     quantiles = [0.25, 0.5, 0.75]
 
     ForestRegressor = FOREST_REGRESSORS[name]
@@ -267,6 +267,7 @@ def check_predict_quantiles(
     weighted_quantile,
     aggregate_leaves_first,
 ):
+    """Check quantile predictions."""
     ForestRegressor = FOREST_REGRESSORS[name]
 
     # Check predicted quantiles on (semi-)random data.
@@ -577,7 +578,7 @@ def test_predict_quantiles(
 
 
 def check_quantile_ranks_toy(name):
-    # Check rank predictions on toy data.
+    """Check quantile ranks on toy data."""
     ForestRegressor = FOREST_REGRESSORS[name]
 
     # Check predicted ranks on toy sample.
@@ -650,7 +651,7 @@ def test_quantile_ranks_toy(name):
 
 
 def check_quantile_ranks(name):
-    # Check rank predictions.
+    """Check quantile ranks."""
     ForestRegressor = FOREST_REGRESSORS[name]
 
     # Check predicted ranks on (semi-)random data.
@@ -698,7 +699,7 @@ def test_quantile_ranks(name):
 
 
 def check_proximity_counts(name):
-    # Check proximity counts.
+    """Check proximity counts."""
     ForestRegressor = FOREST_REGRESSORS[name]
 
     # Check proximity counts on toy sample.
@@ -795,7 +796,7 @@ def test_proximity_counts(name):
 
 
 def check_max_samples_leaf(name):
-    # Check that the `max_samples_leaf` parameter correctly samples leaves.
+    """Check that the `max_samples_leaf` parameter correctly samples leaves."""
     X = X_california
     y = y_california
 
@@ -849,7 +850,7 @@ def test_max_samples_leaf(name):
 
 
 def check_oob_samples(name):
-    # Check OOB sample generation.
+    """Check OOB sample generation."""
     X = X_california
     y = y_california
 
@@ -874,7 +875,7 @@ def test_oob_samples(name):
 
 
 def check_oob_samples_duplicates(name):
-    # Check OOB sampling with duplicates.
+    """Check OOB sampling with duplicates."""
     X = np.array(
         [
             [1, 2, 3],
@@ -915,7 +916,7 @@ def check_predict_oob(
     weighted_quantile,
     aggregate_leaves_first,
 ):
-    # Check OOB predictions.
+    """Check OOB predictions."""
     X = X_california
     y = y_california
 
@@ -1126,7 +1127,7 @@ def test_predict_oob(
 
 
 def check_quantile_ranks_oob(name):
-    # Check OOB quantile rank predictions.
+    """Check OOB quantile ranks."""
     X = X_california
     y = y_california
 
@@ -1183,7 +1184,7 @@ def test_quantile_ranks_oob(name):
 
 
 def check_proximity_counts_oob(name):
-    # Check OOB proximity counts.
+    """Check OOB proximity counts."""
     X = X_california
     y = y_california
 
@@ -1262,6 +1263,7 @@ def test_proximity_counts_oob(name):
 
 
 def check_monotonic_constraints(name, max_samples_leaf):
+    """Check monotonic constraints."""
     ForestRegressor = FOREST_REGRESSORS[name]
 
     n_samples = 1000
@@ -1335,8 +1337,7 @@ def test_monotonic_constraints(name, max_samples_leaf):
 
 
 def check_serialization(name, sparse_pickle, monotonic_cst, multi_target):
-    # Check model serialization/deserialization.
-
+    """Check model serialization/deserialization."""
     X = X_california
 
     if multi_target:
@@ -1370,7 +1371,7 @@ def test_serialization(name, sparse_pickle, monotonic_cst, multi_target):
 
 
 def test_calc_quantile():
-    # Check quantile calculations.
+    """Check quantile calculations."""
     quantiles = [0.0, 0.25, 0.5, 0.75, 1.0]
     interpolations = [b"linear", b"lower", b"higher", b"midpoint", b"nearest"]
 
@@ -1440,7 +1441,7 @@ def test_calc_quantile():
 
 
 def test_calc_weighted_quantile():
-    # Check weighted quantile calculations.
+    """Check weighted quantile calculations."""
     quantiles = [0.0, 0.25, 0.5, 0.75, 1.0]
     interpolations = [b"linear", b"lower", b"higher", b"midpoint", b"nearest"]
 
@@ -1559,7 +1560,7 @@ def _dicts_to_input_pairs(input_dicts):
 
 
 def test_calc_quantile_rank():
-    # Check quantile rank calculations.
+    """Check quantile rank calculations."""
    kinds = [b"rank", b"weak", b"strict", b"mean"]
 
     inputs = [
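For context, these helpers exercise the package's quantile prediction API: fit a forest, then request one or more quantiles at predict time. A minimal usage sketch, assuming RandomForestQuantileRegressor is one of the estimators FOREST_REGRESSORS maps to and using the quantiles and max_samples_leaf parameters that appear in the diff:

import numpy as np
from quantile_forest import RandomForestQuantileRegressor  # assumed public import

rng = np.random.RandomState(0)
X = rng.uniform(size=(100, 4))
y = X.sum(axis=1) + rng.normal(scale=0.1, size=100)

# max_samples_leaf=None retains all training samples in each leaf.
qrf = RandomForestQuantileRegressor(n_estimators=10, max_samples_leaf=None, random_state=0)
qrf.fit(X, y)

# One output column per requested quantile: lower quartile, median, upper quartile.
y_pred = qrf.predict(X, quantiles=[0.25, 0.5, 0.75])
print(y_pred.shape)  # expected: (100, 3)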
6 changes: 3 additions & 3 deletions quantile_forest/tests/test_utils.py
@@ -11,7 +11,7 @@
 
 
 def test_generate_unsampled_indices():
-    # Check unsampled indices generation.
+    """Check unsampled indices generation."""
     max_index = 20
     duplicates = [[1, 4], [19, 10], [2, 3, 5], [6, 13]]
 
@@ -40,7 +40,7 @@ def _generate_unsampled_indices(sample_indices, n_total_samples):
 
 
 def test_group_indices_by_value():
-    # Check grouping indices by value.
+    """Check grouping indices by value."""
     inputs = np.array([1, 3, 2, 2, 5, 4, 5, 5], dtype=np.int64)
 
     actual_indices, actual_values = group_indices_by_value(inputs)
@@ -58,7 +58,7 @@ def test_group_indices_by_value():
 
 
 def test_map_indices_to_leaves():
-    # Check mapping of indices to leaf nodes.
+    """Check mapping of indices to leaf nodes."""
     y_train_leaves = np.zeros((3, 1, 3), dtype=np.int64)
     bootstrap_indices = np.array([[1], [2], [3], [4], [5]], dtype=np.int64)
     leaf_indices = np.array([1, 2])
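Worth noting for readers of the diff: the check_* helpers above are not collected by pytest themselves; each is driven by a thin test_* wrapper parametrized over the estimator names, as the unchanged @pytest.mark.parametrize("name", FOREST_REGRESSORS) context line shows. A rough sketch of that dispatch pattern, with the FOREST_REGRESSORS mapping and the weighted_quantile values assumed rather than copied from the test module:

import pytest
from quantile_forest import ExtraTreesQuantileRegressor, RandomForestQuantileRegressor

# Assumed: maps estimator names to the quantile forest classes under test.
FOREST_REGRESSORS = {
    "ExtraTreesQuantileRegressor": ExtraTreesQuantileRegressor,
    "RandomForestQuantileRegressor": RandomForestQuantileRegressor,
}


def check_regression_toy(name, weighted_quantile):
    """Check quantile predictions on toy data for the named estimator."""
    est = FOREST_REGRESSORS[name](n_estimators=10, max_samples_leaf=None, bootstrap=False, random_state=0)
    ...  # fit and assert, as in the diff above


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@pytest.mark.parametrize("weighted_quantile", [True, False])  # values assumed
def test_regression_toy(name, weighted_quantile):
    check_regression_toy(name, weighted_quantile)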
