Skip to content

Commit

Permalink
Test model serialization (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson authored Feb 22, 2024
1 parent 928a489 commit c41cea3
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion quantile_forest/tests/test_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import math
import pickle
import warnings
from typing import Any, Dict

Expand All @@ -21,7 +22,7 @@
assert_array_equal,
assert_raises,
)
from sklearn.utils.validation import check_random_state
from sklearn.utils.validation import check_is_fitted, check_random_state

from quantile_forest import ExtraTreesQuantileRegressor, RandomForestQuantileRegressor
from quantile_forest._quantile_forest_fast import (
Expand Down Expand Up @@ -1185,6 +1186,29 @@ def test_proximity_counts_oob(name):
check_proximity_counts_oob(name)


def check_serialization(name):
# Check model serialization/deserialization.

X = X_california
y = y_california

ForestRegressor = FOREST_REGRESSORS[name]

est = ForestRegressor(n_estimators=10, random_state=0)
est.fit(X, y)

dumped = pickle.dumps(est)
est_loaded = pickle.loads(dumped)

assert check_is_fitted(est_loaded) is None
assert np.all(est.predict(X) == est_loaded.predict(X))


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
def test_serialization(name):
check_serialization(name)


def test_calc_quantile():
# Check quantile calculations.
quantiles = [0.0, 0.25, 0.5, 0.75, 1.0]
Expand Down

0 comments on commit c41cea3

Please sign in to comment.