From 9fccd854200d7a5dd3958dafa2fb2ba2271fe48d Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Wed, 28 Aug 2024 21:35:34 -0700 Subject: [PATCH] Clean up unit test warnings --- quantile_forest/tests/test_quantile_forest.py | 341 ++++++++---------- 1 file changed, 156 insertions(+), 185 deletions(-) diff --git a/quantile_forest/tests/test_quantile_forest.py b/quantile_forest/tests/test_quantile_forest.py index a3ab611..9da68a3 100755 --- a/quantile_forest/tests/test_quantile_forest.py +++ b/quantile_forest/tests/test_quantile_forest.py @@ -69,26 +69,22 @@ def check_regression_toy(name, weighted_quantile): # Check aggregated quantile predictions. y_true = [[0.0, 0.5, 1.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred = regr.predict( - y_test, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=True, - ) + y_pred = regr.predict( + y_test, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=True, + ) assert_allclose(y_pred, y_true) # Check unaggregated quantile predictions. y_true = [[0.25, 0.5, 0.75], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred = regr.predict( - y_test, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=False, - ) + y_pred = regr.predict( + y_test, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=False, + ) assert_allclose(y_pred, y_true) assert regr._more_tags() @@ -208,60 +204,54 @@ def check_predict_quantiles_toy(name): est.fit(X, y) # Check that weighted and unweighted quantiles are approximately equal. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred1 = est.predict( - X, - quantiles=quantiles, - weighted_quantile=True, - weighted_leaves=False, - oob_score=oob_score, - ) - y_pred2 = est.predict( - X, - quantiles=quantiles, - weighted_quantile=False, - weighted_leaves=False, - oob_score=oob_score, - ) + y_pred1 = est.predict( + X, + quantiles=quantiles, + weighted_quantile=True, + weighted_leaves=False, + oob_score=oob_score, + ) + y_pred2 = est.predict( + X, + quantiles=quantiles, + weighted_quantile=False, + weighted_leaves=False, + oob_score=oob_score, + ) assert_allclose(y_pred1, y_pred2) # Check that weighted and unweighted leaves are not equal. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred1 = est.predict( - X, - quantiles=quantiles, - weighted_quantile=True, - weighted_leaves=True, - oob_score=oob_score, - ) - y_pred2 = est.predict( - X, - quantiles=quantiles, - weighted_quantile=True, - weighted_leaves=False, - oob_score=oob_score, - ) + y_pred1 = est.predict( + X, + quantiles=quantiles, + weighted_quantile=True, + weighted_leaves=True, + oob_score=oob_score, + ) + y_pred2 = est.predict( + X, + quantiles=quantiles, + weighted_quantile=True, + weighted_leaves=False, + oob_score=oob_score, + ) assert_raises(AssertionError, assert_allclose, y_pred1, y_pred2) # Check that leaf weighting without weighted quantiles does nothing. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred1 = est.predict( - X, - quantiles=quantiles, - weighted_quantile=False, - weighted_leaves=True, - oob_score=oob_score, - ) - y_pred2 = est.predict( - X, - quantiles=quantiles, - weighted_quantile=False, - weighted_leaves=False, - oob_score=oob_score, - ) + y_pred1 = est.predict( + X, + quantiles=quantiles, + weighted_quantile=False, + weighted_leaves=True, + oob_score=oob_score, + ) + y_pred2 = est.predict( + X, + quantiles=quantiles, + weighted_quantile=False, + weighted_leaves=False, + oob_score=oob_score, + ) assert_array_equal(y_pred1, y_pred2) @@ -304,14 +294,12 @@ def check_predict_quantiles( random_state=0, ) est.fit(X_train, y_train) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred = est.predict( - X_test, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - ) + y_pred = est.predict( + X_test, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + ) if isinstance(quantiles, list): assert y_pred.shape == (X_test.shape[0], len(quantiles)) assert_array_almost_equal(y_pred[:, 1], y_test, -e1_high) @@ -332,14 +320,12 @@ def check_predict_quantiles( random_state=0, ) est.fit(X_train, y_train) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred = est.predict( - X_test, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - ) + y_pred = est.predict( + X_test, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + ) if isinstance(quantiles, list): assert y_pred.shape == (X_test.shape[0], len(quantiles)) else: @@ -384,22 +370,20 @@ def check_predict_quantiles( random_state=0, ) est.fit(X_train, y_train) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred_1 = est.predict( - X_test, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - weighted_leaves=True, - aggregate_leaves_first=False, - ) - y_pred_2 = est.predict( - X_test, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - weighted_leaves=False, - aggregate_leaves_first=False, - ) + y_pred_1 = est.predict( + X_test, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + weighted_leaves=True, + aggregate_leaves_first=False, + ) + y_pred_2 = est.predict( + X_test, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + weighted_leaves=False, + aggregate_leaves_first=False, + ) assert_allclose(y_pred_1, y_pred_2) # Check that aggregated and unaggregated quantiles are all equal. @@ -495,15 +479,12 @@ def check_predict_quantiles( random_state=0, ) est.fit(X.reshape(-1, 1), y) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred = est.predict( - X.reshape(-1, 1), - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - ) + y_pred = est.predict( + X.reshape(-1, 1), + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + ) score = est.score(X.reshape(-1, 1), y, quantiles=0.5) assert y_pred.ndim == (3 if isinstance(quantiles, list) else 2) assert y_pred.shape[1] == y.shape[1] @@ -935,15 +916,13 @@ def check_predict_oob( median_idx = quantiles.index(0.5) # Check that `R^2` score from OOB predictions is close to `oob_score_`. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred = est.predict( - X, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - oob_score=True, - ) + y_pred = est.predict( + X, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + oob_score=True, + ) y_pred_score = r2_score(y, y_pred[:, median_idx]) if n_quantiles is not None: assert y_pred.shape == (len(X), n_quantiles) @@ -962,16 +941,14 @@ def check_predict_oob( perm = np.random.permutation(len(X)) for indices in np.split(np.arange(len(X)), range(100, len(X), 100)): X_chunk = X[perm[indices]] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred_chunk = est.predict( - X_chunk, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - oob_score=True, - indices=perm[indices], - ) + y_pred_chunk = est.predict( + X_chunk, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + oob_score=True, + indices=perm[indices], + ) y_pred_chunks[indices, ...] = y_pred_chunk if n_quantiles is not None: assert y_pred_chunk.shape == (len(X_chunk), n_quantiles) @@ -987,48 +964,42 @@ def check_predict_oob( for i, estimator in enumerate(est.estimators_): unsampled_indices[i] = est._get_unsampled_indices(estimator) est.unsampled_indices_ = unsampled_indices - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred_precomputed_indices = est.predict( - X, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - oob_score=True, - ) + y_pred_precomputed_indices = est.predict( + X, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + oob_score=True, + ) assert np.all(y_pred == y_pred_precomputed_indices) # Check single-row OOB scoring. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred_single_row = est.predict( - X[:1], - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - oob_score=True, - indices=np.zeros(1), - ) + y_pred_single_row = est.predict( + X[:1], + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + oob_score=True, + indices=np.zeros(1), + ) assert np.all(y_pred[:1] == y_pred_single_row) # Check that OOB predictions indexed by -1 return IB predictions. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred_ib = est.predict( - X, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - oob_score=False, - ) - y_pred_oob = est.predict( - X, - quantiles=quantiles, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - oob_score=True, - indices=-np.ones(len(X)), - ) + y_pred_ib = est.predict( + X, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + oob_score=False, + ) + y_pred_oob = est.predict( + X, + quantiles=quantiles, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + oob_score=True, + indices=-np.ones(len(X)), + ) assert np.all(y_pred_ib == y_pred_oob) # Check OOB predictions with `default_quantiles`. @@ -1036,10 +1007,8 @@ def check_predict_oob( est1.fit(X, y) est2 = ForestRegressor(n_estimators=1, default_quantiles=quantiles, random_state=0) est2.fit(X, y) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - y_pred_oob1 = est1.predict(X, quantiles=quantiles) - y_pred_oob2 = est2.predict(X) + y_pred_oob1 = est1.predict(X, quantiles=quantiles) + y_pred_oob2 = est2.predict(X) assert_allclose(y_pred_oob1, y_pred_oob2) # Check error if OOB score without `indices` do not match training count. @@ -1084,32 +1053,30 @@ def check_predict_oob( # Check error if no bootstrapping. est = ForestRegressor(n_estimators=1, bootstrap=False) est.fit(X, y) + assert_raises( + ValueError, + est.predict, + X, + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + oob_score=True, + ) with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) - assert_raises( - ValueError, - est.predict, - X, - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - oob_score=True, - ) assert np.all(est._get_unsampled_indices(est.estimators_[0]) == np.array([])) # Check error if number of scoring and training samples are different. est = ForestRegressor(n_estimators=1, bootstrap=True) est.fit(X, y) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - assert_raises( - ValueError, - est.predict, - X[:1], - y[:1], - weighted_quantile=weighted_quantile, - aggregate_leaves_first=aggregate_leaves_first, - oob_score=True, - ) + assert_raises( + ValueError, + est.predict, + X[:1], + y[:1], + weighted_quantile=weighted_quantile, + aggregate_leaves_first=aggregate_leaves_first, + oob_score=True, + ) @pytest.mark.parametrize("name", FOREST_REGRESSORS) @@ -1318,11 +1285,15 @@ def check_monotonic_constraints(name, max_samples_leaf): assert score > 0.75 # Check the monotonic increase constraint. - y_incr = est.predict(X_test_incr, oob_score=oob_score) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + y_incr = est.predict(X_test_incr, oob_score=oob_score) assert np.all(y_incr >= y) # Check the monotonic decrease constraint. - y_decr = est.predict(X_test_decr, oob_score=oob_score) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + y_decr = est.predict(X_test_decr, oob_score=oob_score) assert np.all(y_decr <= y) # Check error if `max_samples_leaf` != 1.