Fix fractional max_samples (#47)
reidjohnson authored Apr 10, 2024
1 parent 32ddf35 commit 05cd3bd
Showing 3 changed files with 106 additions and 26 deletions.
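For context: ``max_samples`` may be ``None``, a float fraction of the training set, or an absolute integer count. With a fractional value, each tree's bootstrap draws fewer rows than the training set holds, and several code paths below assumed the two counts were equal. A minimal sketch of the resolution arithmetic, assuming scikit-learn's private helper ``_get_n_samples_bootstrap`` keeps its current behavior (exact rounding may vary across versions):

    # Hedged sketch: _get_n_samples_bootstrap is a private scikit-learn helper.
    from sklearn.ensemble._forest import _get_n_samples_bootstrap

    n_samples = 1000  # hypothetical training-set size
    for max_samples in (None, 0.5, 1.0, 100):
        n_bootstrap = _get_n_samples_bootstrap(n_samples, max_samples)
        print(max_samples, n_bootstrap)  # None -> 1000, 0.5 -> 500, 1.0 -> 1000, 100 -> 100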
quantile_forest/_quantile_forest.py (21 changes: 13 additions & 8 deletions)
@@ -52,11 +52,11 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a
sklearn_version = parse_version(sklearn.__version__)


def _generate_unsampled_indices(sample_indices, duplicates=None):
def _generate_unsampled_indices(sample_indices, n_total_samples, duplicates=None):
"""Private function used by forest._get_y_train_leaves function."""
if duplicates is None:
duplicates = []
return generate_unsampled_indices(sample_indices, duplicates)
return generate_unsampled_indices(sample_indices, n_total_samples, duplicates)


def _group_by_value(a):
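Why the extra parameter: the routine previously saw only the drawn indices, so the total number of training rows had to be inferred from the draw count, which undercounts whenever ``max_samples`` shrinks the bootstrap. A hypothetical pure-Python illustration (``duplicates`` ignored):

    sample_indices = [8, 1, 5, 0, 7]  # 5 draws from a 10-row set (max_samples=0.5)
    n_total_samples = 10              # must now be supplied by the caller
    unsampled = [i for i in range(n_total_samples) if i not in set(sample_indices)]
    print(unsampled)  # [2, 3, 4, 6, 9]; a range(len(sample_indices)) loop misses 6 and 9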
@@ -262,20 +262,23 @@ def _get_y_train_leaves(self, X, y_dim, sorter=None, sample_weight=None):
warnings.simplefilter("ignore", UserWarning)
X_leaves = self.apply(X)

shape = (n_samples, self.n_estimators)
if self.bootstrap:
n_samples_bootstrap = _get_n_samples_bootstrap(n_samples, self.max_samples)

shape = (n_samples if not self.bootstrap else n_samples_bootstrap, self.n_estimators)
bootstrap_indices = np.empty(shape, dtype=np.int64)
X_leaves_bootstrap = np.empty(shape, dtype=np.int64)
for i, estimator in enumerate(self.estimators_):
# Get bootstrap indices.
if self.bootstrap:
n_samples_bootstrap = _get_n_samples_bootstrap(n_samples, self.max_samples)
bootstrap_indices[:, i] = _generate_sample_indices(
estimator.random_state, n_samples, n_samples_bootstrap
)
else:
bootstrap_indices[:, i] = np.arange(n_samples)

# Get predictions on bootstrap indices.
X_leaves[:, i] = X_leaves[bootstrap_indices[:, i], i]
X_leaves_bootstrap[:, i] = X_leaves[bootstrap_indices[:, i], i]

if sorter is not None:
# Reassign bootstrap indices to account for target sorting.
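The new ``X_leaves_bootstrap`` array exists because, under a fractional ``max_samples``, the bootstrap has fewer rows than ``X_leaves``, so the old in-place write ``X_leaves[:, i] = X_leaves[bootstrap_indices[:, i], i]`` no longer fits. A toy sketch with hypothetical shapes:

    import numpy as np

    rng = np.random.default_rng(0)
    n_samples, n_bootstrap, n_estimators = 8, 4, 2  # as if max_samples=0.5
    X_leaves = rng.integers(1, 5, size=(n_samples, n_estimators))
    bootstrap_indices = rng.integers(0, n_samples, size=(n_bootstrap, n_estimators))

    # Writing back into X_leaves[:, i] would raise a shape error (4 != 8), so the
    # reindexed leaves get their own (n_bootstrap, n_estimators) array instead.
    X_leaves_bootstrap = np.empty((n_bootstrap, n_estimators), dtype=np.int64)
    for i in range(n_estimators):
        X_leaves_bootstrap[:, i] = X_leaves[bootstrap_indices[:, i], i]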
@@ -294,7 +297,7 @@ def _get_y_train_leaves(self, X, y_dim, sorter=None, sample_weight=None):
if node_count > max_node_count:
max_node_count = node_count
if not leaf_subsample:
sample_count = np.max(np.bincount(X_leaves[:, i]))
sample_count = np.max(np.bincount(X_leaves_bootstrap[:, i]))
if sample_count > max_samples_leaf:
max_samples_leaf = sample_count
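The ``bincount`` change follows from the same split: the per-leaf sample cap must count bootstrap rows per leaf, so it has to run on the bootstrap-sized array. A toy example:

    import numpy as np

    # Hypothetical leaf ids for one tree's bootstrap rows.
    leaves = np.array([3, 3, 5, 5, 5, 7])
    print(np.max(np.bincount(leaves)))  # 3: the fullest leaf holds three rows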

Expand All @@ -304,7 +307,7 @@ def _get_y_train_leaves(self, X, y_dim, sorter=None, sample_weight=None):

for i, estimator in enumerate(self.estimators_):
# Group training indices by leaf node.
leaf_indices, leaf_values_list = _group_by_value(X_leaves[:, i])
leaf_indices, leaf_values_list = _group_by_value(X_leaves_bootstrap[:, i])

if leaf_subsample:
random.seed(estimator.random_state)
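``_group_by_value`` buckets bootstrap row positions by leaf id, so it too must see the bootstrap-sized array. A rough pure-Python rendering of its apparent contract (the shipped helper is a vectorized private function; this sketch only illustrates the inferred semantics):

    import numpy as np
    from collections import defaultdict

    def group_by_value_py(a):
        # Map each distinct leaf id to the positions holding it.
        groups = defaultdict(list)
        for pos, val in enumerate(a):
            groups[val].append(pos)
        keys = sorted(groups)
        return np.array(keys), [np.array(groups[k]) for k in keys]

    leaf_ids, grouped = group_by_value_py(np.array([3, 5, 3, 7, 5]))
    # leaf_ids -> [3, 5, 7]; grouped -> [array([0, 2]), array([1, 4]), array([3])]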
@@ -439,7 +442,9 @@ def _get_unsampled_indices(self, estimator, duplicates=None):
sample_indices = _generate_sample_indices(
estimator.random_state, n_train_samples, n_samples_bootstrap
)
unsampled_indices = _generate_unsampled_indices(sample_indices, duplicates=duplicates)
unsampled_indices = _generate_unsampled_indices(
sample_indices, n_train_samples, duplicates=duplicates
)
return np.asarray(unsampled_indices)

def predict(
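Taken together, these changes make out-of-bag prediction consistent with a fractional ``max_samples``. A hedged usage sketch (the ``oob_score`` flag on ``predict`` is the behavior the tests below exercise):

    from sklearn.datasets import make_regression
    from quantile_forest import RandomForestQuantileRegressor

    X, y = make_regression(n_samples=200, n_features=4, noise=1.0, random_state=0)
    qrf = RandomForestQuantileRegressor(
        n_estimators=25, max_samples=0.5, bootstrap=True, oob_score=True, random_state=0
    )
    qrf.fit(X, y)
    # Out-of-bag quantile predictions on the training rows.
    y_oob = qrf.predict(X, quantiles=[0.25, 0.5, 0.75], oob_score=True)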
quantile_forest/_quantile_forest_fast.pyx (8 changes: 6 additions & 2 deletions)
@@ -507,14 +511,18 @@ cpdef double calc_quantile_rank(

cpdef vector[intp_t] generate_unsampled_indices(
vector[intp_t] sample_indices,
intp_t n_total_samples,
vector[set[intp_t]] duplicates,
) noexcept nogil:
"""Return a list of every unsampled index, accounting for duplicates.
Parameters
----------
sample_indices : array-like of shape (n_samples)
Sample indices for which to get duplicates.
Sampled indices.
n_total_samples : int
Number of total samples, sampled and unsampled.
duplicates : list of sets
List of sets of functionally identical indices.
@@ -543,7 +547,7 @@ cpdef vector[intp_t] generate_unsampled_indices(
sampled_set.insert(duplicates[i].begin(), duplicates[i].end())

# If the index is not in `sampled_set`, it is unsampled.
for i in range(n_samples):
for i in range(n_total_samples):
if not sampled_set.count(i):
unsampled_indices.push_back(i)

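The loop bound is the heart of this file's fix: iterating only up to the number of drawn indices misses every higher index whenever fewer rows are drawn than exist. A pure-Python sketch of the routine's documented behavior (not the compiled API; the ``duplicates`` handling mirrors the test at the bottom of this page):

    def generate_unsampled_indices_py(sample_indices, n_total_samples, duplicates=()):
        sampled = set(sample_indices)
        # A group of functionally identical indices counts as sampled
        # as soon as any one member is drawn.
        for group in duplicates:
            if sampled & set(group):
                sampled |= set(group)
        return [i for i in range(n_total_samples) if i not in sampled]

    # Three draws from six rows: a bound of len(sample_indices) would stop at
    # index 2 and silently drop 3, 4, and 5 from the unsampled set.
    print(generate_unsampled_indices_py([0, 2, 2], 6))  # [1, 3, 4, 5]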
quantile_forest/tests/test_quantile_forest.py (103 changes: 87 additions & 16 deletions)
@@ -272,6 +272,7 @@ def test_predict_quantiles_toy(name):

def check_predict_quantiles(
name,
max_samples,
max_samples_leaf,
quantiles,
weighted_quantile,
@@ -295,7 +296,13 @@ def check_predict_quantiles(
y_train = y_train.astype(np.float64)
y_test = y_test.astype(np.float64)

est = ForestRegressor(n_estimators=10, max_samples_leaf=max_samples_leaf, random_state=0)
est = ForestRegressor(
n_estimators=10,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est.fit(X_train, y_train)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
@@ -317,7 +324,13 @@ def check_predict_quantiles(
X_california, y_california, test_size=0.25, random_state=0
)

est = ForestRegressor(n_estimators=10, max_samples_leaf=max_samples_leaf, random_state=0)
est = ForestRegressor(
n_estimators=10,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est.fit(X_train, y_train)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
@@ -338,7 +351,13 @@ def check_predict_quantiles(
assert np.all(np.less_equal(y_pred[:, 1], y_pred[:, 2]))

# Check that weighted and unweighted quantiles are all equal.
est = ForestRegressor(n_estimators=10, max_samples_leaf=max_samples_leaf, random_state=0)
est = ForestRegressor(
n_estimators=10,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est.fit(X_train, y_train)
y_pred_1 = est.predict(
X_test,
@@ -357,7 +376,13 @@ def check_predict_quantiles(
assert_allclose(y_pred_1, y_pred_2)

# Check that weighted and unweighted leaves are all equal.
est = ForestRegressor(n_estimators=1, max_samples_leaf=max_samples_leaf, random_state=0)
est = ForestRegressor(
n_estimators=1,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est.fit(X_train, y_train)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
@@ -378,7 +403,13 @@ def check_predict_quantiles(
assert_allclose(y_pred_1, y_pred_2)

# Check that aggregated and unaggregated quantiles are all equal.
est = ForestRegressor(n_estimators=1, max_samples_leaf=max_samples_leaf, random_state=0)
est = ForestRegressor(
n_estimators=1,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est.fit(X_train, y_train)
y_pred_1 = est.predict(
X_test,
@@ -397,7 +428,13 @@ def check_predict_quantiles(
assert_allclose(y_pred_1, y_pred_2)

# Check that omitting quantiles is the same as setting to 0.5.
est = ForestRegressor(n_estimators=1, max_samples_leaf=max_samples_leaf, random_state=0)
est = ForestRegressor(
n_estimators=1,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est.fit(X_train, y_train)
y_pred_1 = est.predict(
X_test,
@@ -456,7 +493,13 @@ def check_predict_quantiles(
y[:, 1] = np.log1p(X + 1)
y[:, 1] += np.log1p(X + 1) * np.random.uniform(size=len(X))

est = ForestRegressor(n_estimators=1, max_samples_leaf=max_samples_leaf, random_state=0)
est = ForestRegressor(
n_estimators=1,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est.fit(X.reshape(-1, 1), y)

with warnings.catch_warnings():
@@ -471,24 +514,38 @@ def check_predict_quantiles(
assert y_pred.ndim == (3 if isinstance(quantiles, list) else 2)
assert y_pred.shape[1] == y.shape[1]
assert np.any(y_pred[:, 0, ...] != y_pred[:, 1, ...])
assert score > 0.97
assert score > 0.95

# Check that specifying `quantiles` overwrites `default_quantiles`.
est1 = ForestRegressor(n_estimators=1, max_samples_leaf=max_samples_leaf, random_state=0)
est1 = ForestRegressor(
n_estimators=1,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est1.fit(X_train, y_train)
y_pred_1 = est1.predict(X_test, quantiles=0.5)
est2 = ForestRegressor(
n_estimators=1,
default_quantiles=[0.25, 0.5, 0.75],
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est2.fit(X_train, y_train)
y_pred_2 = est2.predict(X_test, quantiles=0.5)
assert_allclose(y_pred_1, y_pred_2)

# Check that specifying `interpolation` changes outputs.
est = ForestRegressor(n_estimators=10, max_samples_leaf=max_samples_leaf, random_state=0)
est = ForestRegressor(
n_estimators=10,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
random_state=0,
)
est.fit(X_train, y_train)
y_pred_1 = est.predict(X_test, quantiles=0.5, interpolation="linear")
y_pred_2 = est.predict(X_test, quantiles=0.5, interpolation="nearest")
@@ -500,19 +557,21 @@ def check_predict_quantiles(


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@pytest.mark.parametrize("max_samples", [None, 0.5, 1.0, 100])
@pytest.mark.parametrize("max_samples_leaf", [None, 1])
@pytest.mark.parametrize("quantiles", [None, "mean", 0.5, [0.2, 0.5, 0.8]])
@pytest.mark.parametrize("weighted_quantile", [True, False])
@pytest.mark.parametrize("aggregate_leaves_first", [True, False])
def test_predict_quantiles(
name,
max_samples,
max_samples_leaf,
quantiles,
weighted_quantile,
aggregate_leaves_first,
):
check_predict_quantiles(
name, max_samples_leaf, quantiles, weighted_quantile, aggregate_leaves_first
name, max_samples, max_samples_leaf, quantiles, weighted_quantile, aggregate_leaves_first
)


@@ -829,6 +888,7 @@ def test_oob_samples_duplicates(name):

def check_predict_oob(
name,
max_samples,
max_samples_leaf,
quantiles,
weighted_quantile,
@@ -843,6 +903,7 @@ def check_predict_oob(
est = ForestRegressor(
n_estimators=20,
max_samples_leaf=max_samples_leaf,
max_samples=max_samples,
bootstrap=True,
oob_score=True,
random_state=0,
@@ -977,7 +1038,13 @@ def check_predict_oob(

# Check warning if not enough estimators.
with np.errstate(divide="ignore", invalid="ignore"):
est = ForestRegressor(n_estimators=4, bootstrap=True, oob_score=True, random_state=0)
est = ForestRegressor(
n_estimators=4,
max_samples=max_samples,
bootstrap=True,
oob_score=True,
random_state=0,
)
with pytest.warns():
est.fit(X, y)
est.predict(
@@ -1028,19 +1095,22 @@ def check_predict_oob(


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@pytest.mark.parametrize("max_samples", [None, 0.5, 1.0, 100])
@pytest.mark.parametrize("max_samples_leaf", [None, 1])
@pytest.mark.parametrize("quantiles", [None, "mean", 0.5, [0.2, 0.5, 0.8]])
@pytest.mark.parametrize("weighted_quantile", [True, False])
@pytest.mark.parametrize("aggregate_leaves_first", [True, False])
def test_predict_oob(
name,
max_samples,
max_samples_leaf,
quantiles,
weighted_quantile,
aggregate_leaves_first,
):
check_predict_oob(
name,
max_samples,
max_samples_leaf,
quantiles,
weighted_quantile,
@@ -1464,24 +1534,25 @@ def test_generate_unsampled_indices():
max_index = 20
duplicates = [[1, 4], [19, 10], [2, 3, 5], [6, 13]]

def _generate_unsampled_indices(sample_indices):
def _generate_unsampled_indices(sample_indices, n_total_samples):
return generate_unsampled_indices(
np.array(sample_indices, dtype=np.int64),
n_total_samples=n_total_samples,
duplicates=duplicates,
)

# If all indices are sampled, there are no unsampled indices.
indices = [idx for idx in range(max_index)]
expected = np.array([], dtype=np.int64)
assert_array_equal(_generate_unsampled_indices(indices), expected)
assert_array_equal(_generate_unsampled_indices(indices, max_index), expected)

# Index 7 has no duplicates, and thus should be the only unsampled index.
indices = [7 for _ in range(max_index)]
expected = np.array([idx for idx in range(max_index) if idx != 7])
assert_array_equal(_generate_unsampled_indices(indices), expected)
assert_array_equal(_generate_unsampled_indices(indices, max_index), expected)

# Check sample indices [0, 1, 2] with duplicates set(1, 4) + set(2, 3, 5),
# which excludes [0, 1, 2, 3, 4, 5] (i.e., range(6)) from unsampled.
indices = [idx % 3 for idx in range(max_index)]
expected = [x for x in range(max_index) if x not in range(6)]
assert_array_equal(_generate_unsampled_indices(indices), expected)
assert_array_equal(_generate_unsampled_indices(indices, max_index), expected)
