From 46e2c7292ebef5719eba8b85fc6c975131c61494 Mon Sep 17 00:00:00 2001 From: Reid Johnson Date: Tue, 10 Sep 2024 10:12:55 -0700 Subject: [PATCH] Fix multi-target serialization (#88) --- quantile_forest/_quantile_forest_fast.pyx | 2 +- quantile_forest/tests/test_quantile_forest.py | 21 ++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/quantile_forest/_quantile_forest_fast.pyx b/quantile_forest/_quantile_forest_fast.pyx index 617fb14..2d3cdfe 100755 --- a/quantile_forest/_quantile_forest_fast.pyx +++ b/quantile_forest/_quantile_forest_fast.pyx @@ -540,7 +540,7 @@ cdef class QuantileForest: d = {} if self.sparse_pickle: matrix1 = kwargs["y_train_leaves"] - reshape1 = (matrix1.shape[2], matrix1.shape[0] * matrix1.shape[1] * matrix1.shape[2]) + reshape1 = (matrix1.shape[3], matrix1.shape[0] * matrix1.shape[1] * matrix1.shape[2]) d["shape1"] = matrix1.shape d["matrix1"] = sparse.csc_matrix(matrix1.reshape(reshape1)) diff --git a/quantile_forest/tests/test_quantile_forest.py b/quantile_forest/tests/test_quantile_forest.py index f05280e..96b1076 100755 --- a/quantile_forest/tests/test_quantile_forest.py +++ b/quantile_forest/tests/test_quantile_forest.py @@ -1334,15 +1334,24 @@ def test_monotonic_constraints(name, max_samples_leaf): check_monotonic_constraints(name, max_samples_leaf) -def check_serialization(name, sparse_pickle): +def check_serialization(name, sparse_pickle, monotonic_cst, multi_target): # Check model serialization/deserialization. X = X_california - y = y_california + + if multi_target: + y = np.vstack([y_california, y_california]).T + else: + y = y_california + + if monotonic_cst and not multi_target: + monotonic_cst = [1] * X.shape[1] + else: + monotonic_cst = None ForestRegressor = FOREST_REGRESSORS[name] - est = ForestRegressor(n_estimators=10, random_state=0) + est = ForestRegressor(n_estimators=10, monotonic_cst=monotonic_cst, random_state=0) est.fit(X, y, sparse_pickle=sparse_pickle) dumped = pickle.dumps(est) @@ -1354,8 +1363,10 @@ def check_serialization(name, sparse_pickle): @pytest.mark.parametrize("name", FOREST_REGRESSORS) @pytest.mark.parametrize("sparse_pickle", [False, True]) -def test_serialization(name, sparse_pickle): - check_serialization(name, sparse_pickle) +@pytest.mark.parametrize("monotonic_cst", [False, True]) +@pytest.mark.parametrize("multi_target", [False, True]) +def test_serialization(name, sparse_pickle, monotonic_cst, multi_target): + check_serialization(name, sparse_pickle, monotonic_cst, multi_target) def test_calc_quantile():