Skip to content

Commit

Permalink
Update examples
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Feb 22, 2024
1 parent e7caefd commit 721165f
Show file tree
Hide file tree
Showing 11 changed files with 21 additions and 22 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from datetime import datetime

sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath(".."))

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
Expand Down
2 changes: 1 addition & 1 deletion docs/sphinxext/gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from docutils.statemachine import ViewList
from sphinx.util.nodes import nested_parse_with_titles

from quantile_forest.tests.examples import iter_examples
from examples import iter_examples

from .utils import create_generic_image, create_thumbnail, get_docstring_and_rest, prev_this_next

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_train_Xy(X, y, min_idx, max_idx):
return X_train, y_train


def get_test_X(X):
def get_test_X(X, bounds):
n_samples = len(X)
X_test = np.atleast_2d(np.linspace(*bounds, n_samples)).T
return X_test
Expand All @@ -56,7 +56,7 @@ def get_test_X(X):
# Based on the extrapolation bounds, get the training and test data.
# Training data excludes extrapolated regions; test data includes them.
X_train, y_train = get_train_Xy(X, y, extrap_min_idx, extrap_max_idx)
X_test = get_test_X(X)
X_test = get_test_X(X, bounds)

qrf = RandomForestQuantileRegressor(max_samples_leaf=None, min_samples_leaf=10, random_state=0)
qrf.fit(np.expand_dims(X_train, axis=-1), y_train)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,6 @@

from quantile_forest import RandomForestQuantileRegressor

interpolations = {
"Linear": "#006aff",
"Lower": "#ffd237",
"Higher": "#0d4599",
"Midpoint": "#f2a619",
"Nearest": "#a6e5ff",
}

# legend = {"Actual": "#000000"} | interpolations
legend = {"Actual": "#000000"}
legend.update(interpolations)

# Create toy dataset.
X = np.array([[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1]])
y = np.array([-2, -1, 0, 1, 2])
Expand All @@ -41,6 +29,18 @@
)
est.fit(X, y)

interpolations = {
"Linear": "#006aff",
"Lower": "#ffd237",
"Higher": "#0d4599",
"Midpoint": "#f2a619",
"Nearest": "#a6e5ff",
}

# legend = {"Actual": "#000000"} | interpolations
legend = {"Actual": "#000000"}
legend.update(interpolations)

# Initialize data with actual values.
data = {
"method": ["Actual"] * len(y),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def plot_calibration(df):
"y_pred:Q",
axis=alt.Axis(format="$,d"),
scale=alt.Scale(domain=domain, nice=False),
title="Fitted Values (Conditional Mean)",
title="Fitted Values (conditional median)",
),
y=alt.Y(
"y_true:Q",
Expand Down Expand Up @@ -152,7 +152,7 @@ def plot_intervals(df):
y=alt.Y(
"y_true:Q",
axis=alt.Axis(format="$,d"),
title="Observed Values and Prediction Intervals",
title="Observed Values and Prediction Intervals (centered)",
),
color=alt.value("#f2a619"),
tooltip=tooltip,
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ def plot_prediction_histograms(df, legend):
.transform_calculate(calculate=f"round({alt.datum['rf']} * 10) / 10", as_="RF (Mean)")
.transform_calculate(calculate=f"round({alt.datum['qrf']} * 10) / 10", as_="QRF (Median)")
.transform_fold(["Actual", "RF (Mean)", "QRF (Median)"], as_=["label", "value"])
.transform_joinaggregate(total="count(*)", groupby=["label"])
.transform_calculate(pct="1 / datum.total")
.mark_bar()
.encode(
x=alt.X(
Expand All @@ -74,15 +72,15 @@ def plot_prediction_histograms(df, legend):
labelAngle=0,
labelExpr="datum.value % 0.5 == 0 ? datum.value : null",
),
title="Value",
title="Actual and Predicted Target Values",
),
y=alt.Y("sum(pct):Q", axis=alt.Axis(format=".0%", title="Percentage")),
y=alt.Y("count():Q", axis=alt.Axis(format=",d", title="Counts")),
color=color,
xOffset=alt.XOffset("label:N"),
tooltip=[
alt.Tooltip("label:N", title="Label"),
alt.Tooltip("value:O", title="Value (binned)"),
alt.Tooltip("sum(pct):Q", format=".0%", title="Percentage"),
alt.Tooltip("count():Q", format=",d", title="Counts"),
],
)
.add_params(click)
Expand Down
File renamed without changes.

0 comments on commit 721165f

Please sign in to comment.