"""
A minimal example of double descent using ridge regression trained
on tanh random features. The dataset is a linearly separable
classification problem with 2 features and 2 classes.
"""
# %matplotlib inline
# %config InlineBackend.figure_format = "retina"
import matplotlib.pyplot as plt
import pandas as pd
from numpy import argmax, eye, linalg, mean, random, tanh, unique
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split as split
from tqdm import trange
seed = 42
N = 6 # Number of training/test samples
P_max = N * 5 # Maximum number of random features
P_step = max(P_max // 50, 1) # Step size for number of random features
d = 2 # Number of features
num_trials = 10000 # Number of trials to average over
x, y = make_classification(
    n_samples=N * 2,
    n_features=d,
    n_informative=d,
    n_redundant=0,
    flip_y=0,
    class_sep=1.5,
    random_state=seed,
)
y = eye(unique(y).shape[0])[y] # One-hot encode the labels
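# (The eye-indexing trick picks identity rows by label, e.g. eye(2)[[0, 1, 1]]
#  gives [[1, 0], [0, 1], [0, 1]].)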
x_tr, x_te, y_tr, y_te = split(x, y, test_size=0.5, random_state=seed)
# Visualize the dataset
def visualize(x_tr, y_tr, x_te, y_te):
    """Scatter-plot the train and test points, colored by class."""
    plt.figure(figsize=(3.2, 3.2), constrained_layout=True)
    plt.scatter(
        x_tr[y_tr[:, 0] == 1, 0],
        x_tr[y_tr[:, 0] == 1, 1],
        label="Train Class 0",
        c="tab:orange",
        marker="x",
    )
    plt.scatter(
        x_tr[y_tr[:, 1] == 1, 0],
        x_tr[y_tr[:, 1] == 1, 1],
        label="Train Class 1",
        c="tab:blue",
        marker="x",
    )
    plt.scatter(
        x_te[y_te[:, 0] == 1, 0],
        x_te[y_te[:, 0] == 1, 1],
        label="Test Class 0",
        c="tab:orange",
        marker="o",
    )
    plt.scatter(
        x_te[y_te[:, 1] == 1, 0],
        x_te[y_te[:, 1] == 1, 1],
        label="Test Class 1",
        c="tab:blue",
        marker="o",
    )
    plt.legend()
    plt.savefig("dataset.png", bbox_inches="tight", dpi=300)
    plt.show()

def cond_number(x):
    """Compute the condition number of a matrix (ratio of extreme singular values)."""
    _, s, _ = linalg.svd(x)
    return s[0] / (s[-1] + 1e-8)  # Guard against division by zero
# Generate a random matrix with iid Gaussian entries
random_matrix = lambda d, P: random.normal(loc=0, scale=1 / d**0.5, size=(d, P))
# Compute the random tanh features
random_features = lambda x, W0: tanh(x @ W0)
# Compute the mean squared error
mse_loss = lambda y, y_pred: mean((y - y_pred) ** 2)
# Compute the ridge regression solution
ridge = lambda X, y, a=1e-8: linalg.inv(X.T @ X + a * eye(X.shape[1])) @ X.T @ y
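# A numerically steadier alternative (a sketch, not used by the script): solve the
# regularized normal equations instead of forming the explicit inverse. For P > N
# and a -> 0, both approach the minimum-norm least-squares solution linalg.pinv(X) @ y.
ridge_solve = lambda X, y, a=1e-8: linalg.solve(X.T @ X + a * eye(X.shape[1]), X.T @ y)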

def make_double_descent(x_tr, y_tr, x_te, y_te, P_max, P_step, num_trials):
    """Sweep the number of random features P and average the metrics over trials."""
    output = []
    for _ in trange(num_trials):
        W0_ = random_matrix(d, P_max)
        for p in range(P_max, 0, -P_step):
            W0 = W0_[:, :p]  # Reuse part of the random matrix for efficiency
            # Compute the random features
            z_tr, z_te = random_features(x_tr, W0), random_features(x_te, W0)
            W1 = ridge(z_tr, y_tr)  # Compute the ridge regression solution
            tr_pred, te_pred = z_tr @ W1, z_te @ W1  # Compute the predictions
            output.append(
                {
                    "p_over_n": p / N,
                    "tr_acc": mean(argmax(tr_pred, 1) == argmax(y_tr, 1)),
                    "te_acc": mean(argmax(te_pred, 1) == argmax(y_te, 1)),
                    "tr_loss": mse_loss(y_tr, tr_pred),
                    "te_loss": mse_loss(y_te, te_pred),
                    "cond": cond_number(z_tr),
                    "W1_norm": linalg.norm(W1, ord=2),
                }
            )
    return pd.DataFrame(output).groupby("p_over_n").mean().reset_index()

def plot(output):
    _, axes = plt.subplots(1, 4, figsize=(12, 3), sharex=True, constrained_layout=True)
    # Plot error
    axes[0].plot(output["p_over_n"], 1 - output["tr_acc"], label="Train")
    axes[0].plot(output["p_over_n"], 1 - output["te_acc"], label="Test")
    axes[0].axvline(x=1, color="r", linestyle="--")
    axes[0].set(xlabel=r"$P/N$", ylabel="Error Rate")
    axes[0].legend()
    # Plot loss
    axes[1].plot(output["p_over_n"], output["tr_loss"], label="Train")
    axes[1].plot(output["p_over_n"], output["te_loss"], label="Test")
    axes[1].axvline(x=1, color="r", linestyle="--")
    axes[1].set(xlabel=r"$P/N$", ylabel="MSE Loss", yscale="log")
    axes[1].legend()
    # Plot condition number
    axes[2].plot(output["p_over_n"], output["cond"])
    axes[2].axvline(x=1, color="r", linestyle="--")
    axes[2].set(xlabel=r"$P/N$", ylabel=r"$\sigma_{\max}/\sigma_{\min}$", yscale="log")
    # Plot norm of W1
    axes[3].plot(output["p_over_n"], output["W1_norm"])
    axes[3].axvline(x=1, color="r", linestyle="--")
    axes[3].set(xlabel=r"$P/N$", ylabel=r"$\|W_1\|_2$", yscale="log")
    plt.savefig("min_double_descent.png", bbox_inches="tight", dpi=300)
    plt.show()
visualize(x_tr, y_tr, x_te, y_te)
plot(make_double_descent(x_tr, y_tr, x_te, y_te, P_max, P_step, num_trials))