Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Aug 17, 2024
1 parent 1e8add1 commit b40f858
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 27 deletions.
14 changes: 9 additions & 5 deletions examples/plot_proximity_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
n_test_samples = 25
noise_std = 0.1

pixel_dim = (8, 8) # pixel dimensions (width and height)
pixel_scale = 100 # scale multiplier for combining clean and noisy values

# Load the Digits dataset.
X, y = datasets.load_digits(return_X_y=True, as_frame=True)

Expand Down Expand Up @@ -85,7 +88,7 @@ def extract_floats(combined_df, scale=100):
)

df = (
combine_floats(X_test, X_test_noisy) # combine to reduce transmitted data
combine_floats(X_test, X_test_noisy, scale=pixel_scale) # combine to reduce transmitted data
.join(y_test)
.reset_index()
.join(df_prox)
Expand All @@ -103,7 +106,7 @@ def extract_floats(combined_df, scale=100):

# Create a data frame for looking up training proximities.
df_lookup = (
combine_floats(X_train, X_train_noisy) # combine to reduce transmitted data
combine_floats(X_train, X_train_noisy, scale=pixel_scale) # combine to reduce transmitted data
.assign(**{"index": np.arange(len(X_train))})
.join(y_train)
)
Expand All @@ -112,14 +115,15 @@ def extract_floats(combined_df, scale=100):
def plot_digits_proximities(
df,
df_lookup,
pixel_dim=(8, 8),
pixel_scale=100,
n_prox=25,
n_prox_per_row=5,
subplot_spacing=10,
height=225,
width=225,
):
pixel_scale = 100
pixel_cols = [f"pixel_{y:01}_{x:01}" for y in range(8) for x in range(8)]
pixel_cols = [f"pixel_{y:01}_{x:01}" for y in range(pixel_dim[1]) for x in range(pixel_dim[0])]
pixel_x = "split(datum.pixel, '_')[2]"
pixel_y = "split(datum.pixel, '_')[1]"

Expand Down Expand Up @@ -227,5 +231,5 @@ def plot_digits_proximities(
return chart


chart = plot_digits_proximities(df, df_lookup)
chart = plot_digits_proximities(df, df_lookup, pixel_dim=pixel_dim, pixel_scale=pixel_scale)
chart
11 changes: 5 additions & 6 deletions examples/plot_quantile_extrapolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@

random_state = np.random.RandomState(0)
n_samples = 500
bounds = [0, 15]
extrap_frac = 0.25
bounds = [0, 15]
func = lambda x: x * np.sin(x)
func_str = "f(x) = x sin(x)"

quantiles = [0.025, 0.975, 0.5]
qrf_params = {"max_samples_leaf": None, "min_samples_leaf": 4, "random_state": random_state}

Expand Down Expand Up @@ -403,8 +402,8 @@ def get_coverage_xtr(bounds_list, train_indices, test_indices, y_train, level, *
)


def plot_qrf_vs_xtrapolation_comparison(df):
def plot_extrapolations(df, title="", legend=False, x_domain=None, y_domain=None):
def plot_qrf_vs_xtrapolation_comparison(df, func_str):
def plot_extrapolations(df, title="", legend=False, func_str="", x_domain=None, y_domain=None):
x_scale = None
if x_domain is not None:
x_scale = alt.Scale(domain=x_domain, nice=False, padding=0)
Expand Down Expand Up @@ -517,7 +516,7 @@ def plot_extrapolations(df, title="", legend=False, x_domain=None, y_domain=None
chart = chart.properties(title=title, height=200, width=300)
return chart

kwargs = {"x_domain": [0, 15], "y_domain": [-15, 20]}
kwargs = {"func_str": func_str, "x_domain": [0, 15], "y_domain": [-15, 20]}
xtra_mapper = {"bb_mid": "y_pred", "bb_low": "y_pred_low", "bb_upp": "y_pred_upp"}

chart1 = alt.layer(
Expand Down Expand Up @@ -555,5 +554,5 @@ def plot_extrapolations(df, title="", legend=False, x_domain=None, y_domain=None
return chart


chart = plot_qrf_vs_xtrapolation_comparison(df)
chart = plot_qrf_vs_xtrapolation_comparison(df, func_str)
chart
13 changes: 4 additions & 9 deletions examples/plot_quantile_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,16 @@
{
"signal": lambda x: np.log1p(x + 1),
"noise": lambda x: np.log1p(x) * random_state.uniform(size=len(x)),
"legend": {"0": "#f2a619"}, # plot legend value and color
},
{
"signal": lambda x: np.log1p(np.sqrt(x)),
"noise": lambda x: np.log1p(x / 2) * random_state.uniform(size=len(x)),
"legend": {"1": "#006aff"}, # plot legend value and color
},
]

legend = {
"0": "#f2a619",
"1": "#006aff",
}
legend = {k: v for f in funcs for k, v in f["legend"].items()}


def make_func_Xy(funcs, bounds, n_samples):
Expand All @@ -48,10 +47,6 @@ def make_func_Xy(funcs, bounds, n_samples):
return np.atleast_2d(x).T, y


def format_frac(fraction):
    """Format a number as a compact string with up to three significant digits.

    Parameters
    ----------
    fraction : int or float
        Value to format (e.g., a quantile such as 0.025).

    Returns
    -------
    str
        The value rendered with the ``.3g`` format specification.
    """
    # The ``g`` presentation type already drops insignificant trailing zeros
    # and a dangling decimal point. Stripping "0" characters afterwards would
    # corrupt integer-like or exponent output (10 -> "1", 1e-10 -> "1e-1").
    return f"{fraction:.3g}"


# Create the dataset with multiple target variables.
X, y = make_func_Xy(funcs, bounds, n_samples)

Expand All @@ -70,7 +65,7 @@ def format_frac(fraction):
"y_true": np.concatenate([f["signal"](X.squeeze()) for f in funcs]),
"y_pred": np.concatenate([y_pred[:, i, len(quantiles) // 2] for i in range(len(funcs))]),
"target": np.concatenate([[str(i)] * len(X) for i in range(len(funcs))]),
**{f"q_{format_frac(q_i)}": y_i.ravel() for q_i, y_i in zip(quantiles, y_pred.T)},
**{f"q_{q_i:.3g}": y_i.ravel() for q_i, y_i in zip(quantiles, y_pred.T)},
}
)

Expand Down
9 changes: 2 additions & 7 deletions examples/plot_quantile_vs_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from quantile_forest import RandomForestQuantileRegressor

random_state = np.random.RandomState(0)
n_samples = 5000
quantiles = np.linspace(0, 1, num=101, endpoint=True).round(2).tolist()

# Create right-skewed dataset.
n_samples = 5000
a, loc, scale = 7, -1, 1
skewnorm_rv = sp.stats.skewnorm(a, loc, scale)
skewnorm_rv.random_state = random_state
Expand All @@ -48,16 +48,11 @@
"QRF (Median)": "#006aff",
}


def format_frac(fraction):
    """Format a number as a compact string with up to three significant digits.

    Parameters
    ----------
    fraction : int or float
        Value to format (e.g., a quantile such as 0.25).

    Returns
    -------
    str
        The value rendered with the ``.3g`` format specification.
    """
    # ``.3g`` never emits insignificant trailing zeros or a bare trailing ".",
    # so no post-processing is needed. Stripping "0" from the result would
    # damage values such as 10 ("1") or 1e-10 ("1e-1").
    return f"{fraction:.3g}"


df = pd.DataFrame(
{
"actual": y_test,
"rf": y_pred_rf,
**{f"qrf_{format_frac(q_i)}": y_i.ravel() for q_i, y_i in zip(quantiles, y_pred_qrf.T)},
**{f"qrf_{q_i:.3g}": y_i.ravel() for q_i, y_i in zip(quantiles, y_pred_qrf.T)},
}
)

Expand Down

0 comments on commit b40f858

Please sign in to comment.