-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmake_figures.py
executable file
·353 lines (291 loc) · 10.7 KB
/
make_figures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
import os
from test import TRAIN_IDS
import numpy as np
from fire import Fire
from matplotlib import pyplot as plt
from matplotlib.patches import Patch
from scipy.io import loadmat, savemat
def show_example(
    sos_map,
    pred_field,
    true_field,
):
    """Draw the reference field, the prediction and their percent difference.

    Creates a new 1x3 matplotlib figure (8 x 2 inches, 300 dpi):
      * left:   real part of ``true_field`` with ``sos_map`` drawn as an inset,
      * middle: real part of ``pred_field``,
      * right:  ``|true_field - pred_field|`` as a percentage of
                ``max(|true_field|)``.

    Args:
        sos_map: sound-speed map shown in the inset with a fixed colour range
            [1, 2] (presumably a 2D array — confirm with callers).
        pred_field: predicted field (may be complex; only the real part is shown).
        true_field: reference field; also used to normalise the error map.

    Returns nothing; the figure is left as the current pyplot figure.
    """
    # Fixed symmetric colour limit for both field panels (an earlier choice was
    # max(|true_field|) / 2).
    field_limit = 2
    fig, (ax_true, ax_pred, ax_diff) = plt.subplots(1, 3, figsize=(8, 2.0), dpi=300)

    def _field_panel(ax, field, title):
        # Real part on a diverging map, no axes, titled, with its own colorbar.
        image = ax.imshow(
            np.real(field), vmin=-field_limit, vmax=field_limit, cmap="seismic"
        )
        ax.axis("off")
        ax.set_title(title)
        fig.colorbar(image, ax=ax)

    _field_panel(ax_true, true_field, "Reference")

    # Sound-speed inset placed in figure coordinates over the left panel.
    inset = fig.add_axes([0.152, 0.61, 0.25, 0.25])
    inset.imshow(sos_map, vmin=1, vmax=2, cmap="inferno")
    inset.axis("off")

    _field_panel(ax_pred, pred_field, "Prediction")

    # Absolute difference expressed in percent of the reference peak magnitude.
    peak = np.max(np.abs(true_field))
    percent_diff = 100 * np.abs(true_field - pred_field) / peak
    diff_image = ax_diff.imshow(percent_diff, cmap="inferno")
    ax_diff.axis("off")
    ax_diff.set_title("Difference %")
    fig.colorbar(diff_image, ax=ax_diff)
def load_data(name, is_born=False):
    """Load a ``.mat`` results file from the ``results/`` directory.

    Args:
        name: for Born-series runs (``is_born=True``) the file stem itself;
            otherwise a key into ``TRAIN_IDS`` whose value is the file stem.
        is_born: selects which of the two naming schemes above applies.

    Returns:
        The dict returned by ``scipy.io.loadmat``. For non-Born runs the
        ``"prediction"`` entry is divided by 10.
    """
    if is_born:
        return loadmat(f"results/{name}.mat")

    data = loadmat(f"results/{TRAIN_IDS[name]}.mat")
    # TODO: Remove. Presumably undoes a x10 scaling factor used in the training
    # loss — confirm against the training code.
    data["prediction"] = data["prediction"] / 10.0
    return data
def get_single_sample(data, example):
    """Return ``(sound_speed, prediction, true_field)`` for one example.

    Each array is indexed as ``[example, ..., 0]``, i.e. the ``example``-th
    entry of the first axis and index 0 of the last axis.
    """
    fields = ("sound_speed", "prediction", "true_field")
    return tuple(data[key][example, ..., 0] for key in fields)
def make_example_figure(data, example):
    """Draw the reference/prediction/difference figure for one example of ``data``."""
    show_example(*get_single_sample(data, example))
def compute_example_error(pred, true, loss_kind):
    """Return the error of ``pred`` against ``true`` as a percentage.

    The absolute pointwise error is normalised by ``max(|true|)``.

    Args:
        pred: predicted values.
        true: reference values.
        loss_kind: only ``"l_infty"`` (maximum normalised error) is supported.

    Raises:
        ValueError: if ``loss_kind`` is not ``"l_infty"``.
    """
    relative_error = np.abs(pred - true) / np.max(np.abs(true))
    if loss_kind != "l_infty":
        raise ValueError(f"Unknown loss kind {loss_kind}")
    return 100 * np.amax(relative_error)
def errors_for_model(name, loss_kind):
    """Return a list with one error value per example of the run ``name``.

    Runs whose name contains ``"born"`` are loaded as Born-series results.
    The last axis (channel index 0) is dropped before the per-example error is
    computed with ``compute_example_error``.
    """
    data = load_data(name, "born" in name)
    # Keep only channel 0 of the trailing axis.
    predictions = data["prediction"][..., 0]
    targets = data["true_field"][..., 0]
    return [
        compute_example_error(prediction, targets[index], loss_kind)
        for index, prediction in enumerate(predictions)
    ]
def _styled_boxplot(errors, position, color):
    """Draw one white-filled box of ``errors`` at x=``position`` outlined in ``color``."""
    plt.boxplot(
        errors,
        patch_artist=True,
        positions=[position],
        widths=0.75,
        boxprops=dict(facecolor="white", color=color),
        medianprops=dict(color=color),
        whiskerprops=dict(color=color),
        capprops=dict(color=color),
        flierprops=dict(color=color, marker="."),
    )


def make_iterations_error_figure(loss_kind):
    """Box-plot per-example errors versus number of iterations.

    Three model families are drawn side by side on one axis:
      * learned models (black), shifted 0.5 left of their nominal position,
      * Born-series (CBS) runs (red), shifted 0.5 right,
      * the linear 2-channel model (orange), shifted 1.5 right.

    Args:
        loss_kind: error measure passed to ``errors_for_model``; it must also be
            a key of the y-label table (currently only ``"l_infty"``).
    """
    plt.figure(figsize=(5, 3))
    # Values are nominal x positions; ticks below relabel them as iteration counts.
    bno_models = {
        "6_stages": 2,
        "base": 5.5,
        "24_stages": 10,
    }
    cbs_models = {
        "born_series_6": 2,
        "born_series_12": 5.5,
        "born_series_24": 10,
    }
    linear_2_channels_models = {"2_channels": 5.5}

    # (models, x offset, colour) — drawn in this order.
    groups = [
        (bno_models, -0.5, "black"),
        (linear_2_channels_models, 1.5, "orange"),
        (cbs_models, 0.5, "red"),
    ]
    for models, offset, color in groups:
        for name, num_stages in models.items():
            print(name)
            errors = errors_for_model(name, loss_kind)
            _styled_boxplot(errors, num_stages + offset, color)

    titles = {"l_infty": "Maximum error %"}
    plt.ylabel(titles[loss_kind])
    plt.xticks([2, 6, 10], [6, 12, 24])
    plt.xlabel("Iterations")

    # The red boxes are the Born-series (CBS) runs, so they get their own label
    # rather than repeating the learned model's "LBS".
    legend_elements = [
        Patch(edgecolor="black", label="LBS", facecolor="white"),
        Patch(edgecolor="red", label="CBS", facecolor="white"),
        Patch(edgecolor="orange", label="Linear LBS 2 ch.", facecolor="white"),
    ]
    plt.legend(handles=legend_elements, fontsize=8)
    plt.grid(axis="y", which="both")
def show_iterations(example):
    """Show one example's predicted field for several models in a 2x3 grid.

    Top row: learned models ("6_stages", "base", "24_stages").
    Bottom row: Born-series runs with 6, 12 and 24 iterations.
    Only the real part is drawn, with a fixed colour range [-2, 2].
    """
    vmax = 2
    rows = [
        (["6_stages", "base", "24_stages"], False),
        (["born_series_6", "born_series_12", "born_series_24"], True),
    ]
    _, grid = plt.subplots(2, 3, figsize=(9, 6))
    for row, (model_names, is_born) in enumerate(rows):
        for col, model_name in enumerate(model_names):
            _, prediction, _ = get_single_sample(load_data(model_name, is_born), example)
            panel = grid[row, col]
            panel.imshow(prediction.real, vmin=-vmax, vmax=vmax, cmap="seismic")
            panel.axis("off")
    plt.tight_layout()
def make_cbs_errors(recompute_data):
# This function loads the CBS results for every number of iterations and builds up a
# 2D array with axis [iteration, example] containing the error value
# It then saves the array to a file
# If the file already exists, it loads the data from the file if recompute_data is False
matfile = "results/error_data_for_pareto.mat"
if os.path.exists(matfile) and not recompute_data:
data = loadmat(matfile)
return data["errors"]
# Generate data
# Load all results starting with "born_series_", the remaining of the
# name is the number of iterations
results = [f for f in os.listdir("results") if f.startswith("born_series_")]
results = sorted(results, key=lambda x: int(x.split("_")[2][:-4]))
errors = []
for result in results:
print(result)
# remove the ".mat" extension
result = result[:-4]
error_value = errors_for_model(result, "l_infty")
errors.append(error_value)
errors = np.asarray(errors)
savemat(matfile, {"errors": errors})
return errors
def show_pareto(results, recompute_data=False):
    """Plot learned-model maximum error against an equivalent CBS speed-up factor.

    For each learned model ("6_stages", "base", "24_stages" with 6/12/24
    stages) and each example, the CBS row whose error is closest to the
    model's error is selected. Its row index divided by the stage count is the
    "speed up factor". Per-example points are scattered in a light colour.
    A median point with 5th-95th percentile error bars is drawn per model, and
    the medians are joined with a dashed line.

    Args:
        results: not used by this function; accepted so ``main`` can pass it.
        recompute_data: forwarded to ``make_cbs_errors``.

    Raises:
        ValueError: if the CBS error matrix is not 2D or covers fewer examples
            than a learned model.
    """
    plt.figure(figsize=(7, 4))
    # [iteration, example] matrix of CBS maximum errors (%).
    cbs_errors = np.asarray(make_cbs_errors(recompute_data))

    runs = ["6_stages", "base", "24_stages"]
    num_iterations = [6, 12, 24]
    colors = ["black", "darkred", "darkgreen"]
    light_colors = ["#999999", "#ff9999", "#99ff99"]
    x_plot = []
    y_plot = []
    for run, num_iters, col, light_col in zip(
        runs, num_iterations, colors, light_colors
    ):
        print(run)
        bno_errors = np.asarray(errors_for_model(run, "l_infty"))
        num_examples = bno_errors.shape[0]
        if cbs_errors.ndim != 2 or cbs_errors.shape[1] < num_examples:
            raise ValueError(
                f"CBS error matrix of shape {cbs_errors.shape} does not cover "
                f"the {num_examples} examples of {run}"
            )

        # For every example, pick the CBS row with the error closest to the
        # learned model's error (vectorised over examples).
        # TODO: restrict to rows whose error is not larger than the model's.
        # NOTE(review): the row index is used directly as the CBS iteration
        # count, which assumes make_cbs_errors yields one row per iteration
        # count starting at 0 — confirm against the files in results/.
        distance = np.abs(cbs_errors[:, :num_examples] - bno_errors[np.newaxis, :])
        num_cbs_iterations = np.argmin(distance, axis=0)

        # Sort both series by the learned model's error.
        order = np.argsort(bno_errors)
        bno_errors = bno_errors[order]
        speed_up = num_cbs_iterations[order] / num_iters

        # Median point with 5th-95th percentile asymmetric error bars.
        x = np.median(bno_errors)
        y = np.median(speed_up)
        xerr = np.array(
            [[x - np.percentile(bno_errors, 5)], [np.percentile(bno_errors, 95) - x]]
        )
        yerr = np.array(
            [[y - np.percentile(speed_up, 5)], [np.percentile(speed_up, 95) - y]]
        )
        plt.scatter(bno_errors, speed_up, marker=".", color=light_col)
        plt.errorbar(
            x,
            y,
            xerr=xerr,
            yerr=yerr,
            fmt="o",
            label=f"{num_iters} stages",
            color=col,
            capsize=5,
        )
        print(f"Mean speed up factor: {np.mean(speed_up)} for {run}")
        x_plot.append(x)
        y_plot.append(y)

    plt.ylabel("Speed up factor")
    plt.xlabel("Maximum error %")
    plt.xscale("log")
    # Dashed line connecting the per-model median points.
    plt.plot(x_plot, y_plot, color="black", linestyle="--")
    # Plain (non-scientific) tick labels on the log-scaled x axis.
    plt.xticks(
        [1, 1.2, 1.5, 2, 3, 4, 5, 6, 10, 15, 20],
        [1, 1.2, 1.5, 2, 3, 4, 5, 6, 10, 15, 20],
    )
    plt.legend()
def main(
    results: str = "test",
    figure: str = "example",
    loss_kind: str = "l_infty",
    example: int = 0,
    save_fig: bool = True,
    recompute_data: bool = False,
):
    """Build one figure and optionally save it as EPS and PNG under ``figures/``.

    Args:
        results: run name for the "example" figure (names containing "born"
            are loaded as Born-series results). It is also part of the output
            file name.
        figure: one of "example", "iterations_error", "show_iterations",
            "show_pareto".
        loss_kind: error measure for "iterations_error".
        example: example index for "example" and "show_iterations".
        save_fig: when True, save to ``figures/{results}_{figure}_{example}``
            as ``.eps`` and as ``.png`` (300 dpi).
        recompute_data: forwarded to "show_pareto" to rebuild cached CBS errors.

    Raises:
        ValueError: if ``figure`` is not one of the supported names.
    """
    if figure == "example":
        data = load_data(results, "born" in results)
        make_example_figure(data, example)
    elif figure == "iterations_error":
        make_iterations_error_figure(loss_kind)
    elif figure == "show_iterations":
        show_iterations(example)
    elif figure == "show_pareto":
        show_pareto(results, recompute_data)
    else:
        raise ValueError(f"Unknown figure {figure}")

    if save_fig:
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs("figures", exist_ok=True)
        stem = f"figures/{results}_{figure}_{example}"
        plt.savefig(f"{stem}.eps", bbox_inches="tight")
        plt.savefig(f"{stem}.png", bbox_inches="tight", dpi=300)
def make_all_figures():
    """Placeholder for generating every figure in one call; currently does nothing."""
    return None
# Script entry point: expose `main` as the command-line interface via python-fire.
if __name__ == "__main__":
    Fire(main)