Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 3, 2024
1 parent 2901cc9 commit 024962c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
34 changes: 24 additions & 10 deletions examples/plot_proximity_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,33 @@
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.utils.validation import check_random_state

from quantile_forest import RandomForestQuantileRegressor

alt.data_transformers.disable_max_rows()

n_test_samples = 10
corrupt_frac = 0.5
rng = check_random_state(0)

n_examples = 5
corrupt_frac = 0.33

# Load the Digits dataset.
X, y = datasets.load_digits(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=n_test_samples, random_state=0)

perm = rng.permutation(len(X))
X = X.iloc[perm]
y = y[perm]

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

def randomly_mask_values(df, n=None, frac=None, seed=0):

def randomly_mask_values(df, n=None, frac=None, random_state=None):
"""Randomly mask a fraction of the values in a data frame with NaN."""
np.random.seed(seed)
if random_state is None:
rng = check_random_state(0)
else:
rng = random_state

df = df.copy()

Expand All @@ -43,8 +54,8 @@ def randomly_mask_values(df, n=None, frac=None, seed=0):
elif frac is not None:
num_nan = int(df.size * frac)

random_rows = np.random.randint(0, df.shape[0], num_nan)
random_cols = np.random.randint(0, df.shape[1], num_nan)
random_rows = rng.randint(0, df.shape[0], num_nan)
random_cols = rng.randint(0, df.shape[1], num_nan)

df.values[random_rows, random_cols] = np.nan

Expand Down Expand Up @@ -84,8 +95,11 @@ def fillna(df):


# Randomly corrupt a fraction of the training and test data.
X_train_corrupt = randomly_mask_values(X_train, frac=corrupt_frac, seed=0).pipe(fillna)
X_test_corrupt = randomly_mask_values(X_test, frac=corrupt_frac, seed=0).pipe(fillna)
X_train_corrupt = randomly_mask_values(X_train, frac=corrupt_frac, random_state=rng).pipe(fillna)
X_test_corrupt = randomly_mask_values(X_test, frac=corrupt_frac, random_state=rng).pipe(fillna)

X_test = X_test[:n_examples]
X_test_corrupt = X_test_corrupt[:n_examples]

# We set `max_samples_leaf=None` so that all leaf node samples are stored.
qrf = RandomForestQuantileRegressor(max_samples_leaf=None, random_state=0)
Expand All @@ -105,7 +119,7 @@ def fillna(df):

df_prox = pd.DataFrame(
{
"prox": [[(j, *p) for j, p in enumerate(proximities[i])] for i in range(n_test_samples)],
"prox": [[(j, *p) for j, p in enumerate(proximities[i])] for i in range(n_examples)],
"index": X_test.index,
}
)
Expand Down
25 changes: 19 additions & 6 deletions examples/plot_quantile_ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@
bounds = [0, 10]


def make_toy_dataset(n_samples, bounds, add_noise=True, random_seed=0):
def make_toy_dataset(n_samples, bounds, random_seed=0):
rng = np.random.RandomState(random_seed)
X_1d = np.linspace(*bounds, num=n_samples)
X = X_1d.reshape(-1, 1)
y = X_1d * np.cos(X_1d) + rng.normal(scale=X_1d / math.e)
return X, y


X, y = make_toy_dataset(n_samples, bounds, add_noise=True, random_seed=0)
X, y = make_toy_dataset(n_samples, bounds, random_seed=0)

params = {"max_samples_leaf": None, "min_samples_leaf": 50, "random_state": 0}
qrf = RandomForestQuantileRegressor(**params).fit(X, y)

y_pred = qrf.predict(X)
y_pred = qrf.predict(X, quantiles=0.5)
y_ranks = qrf.quantile_ranks(X, y)

df = pd.DataFrame(
Expand All @@ -51,6 +51,7 @@ def plot_fit_and_ranks(df):
rank_val = alt.param("rank_val", bind=slider, value=0.05)

base = alt.Chart(df)

points = (
base.transform_calculate(
outlier="abs(datum.y_rank - 0.5) > (0.5 - rank_val / 2) ? 'Yes' : 'No'"
Expand All @@ -60,7 +61,11 @@ def plot_fit_and_ranks(df):
.encode(
x=alt.X("x:Q"),
y=alt.Y("y:Q"),
color=alt.condition(f"datum.outlier == 'Yes'", alt.value("red"), alt.value("#f2a619")),
color=alt.Color(
"outlier:N",
scale=alt.Scale(domain=["Yes", "No"], range=["red", "#f2a619"]),
title="Outliers",
),
tooltip=[
alt.Tooltip("x:Q", format=".3f", title="x"),
alt.Tooltip("y:Q", format=".3f", title="f(x)"),
Expand All @@ -69,11 +74,19 @@ def plot_fit_and_ranks(df):
],
)
)

line_pred = base.mark_line(color="#006aff", size=4).encode(
x=alt.X("x", axis=alt.Axis(title="x")), y=alt.Y("y_pred", axis=alt.Axis(title="f(x)"))
x=alt.X("x:Q", axis=alt.Axis(title="x")),
y=alt.Y("y_pred:Q", axis=alt.Axis(title="f(x)")),
)

dummy_legend = (
base.mark_line(opacity=1)
.encode(opacity=alt.Opacity("model:N", scale=alt.Scale(range=[1, 1]), title="Predictions"))
.transform_calculate(model="'Median'")
)

chart = (points + line_pred).properties(
chart = (dummy_legend + points + line_pred).properties(
height=400, width=650, title="QRF Predictions with Quantile Rank Thresholding"
)

Expand Down

0 comments on commit 024962c

Please sign in to comment.