forked from andrey-popov/syst-smoothing
-
Notifications
You must be signed in to change notification settings - Fork 0
/
analyse_cv_sgn.py
executable file
·101 lines (67 loc) · 2.66 KB
/
analyse_cv_sgn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python
"""Analyses output files produced by cross_validation_sgn.py.
For each template and each variation find the bandwidth that gives the
smallest CV error. Plots mean CV errors for requested cases.
"""
import argparse
import json
import math
import os
import sys
import re
import numpy as np
import matplotlib as mpl
mpl.use("agg")
from matplotlib import pyplot as plt
def parse_file(path):
    """Parse the CSV output of a single cross-validation job.

    The first line of the file is treated as a header and skipped.  Every
    remaining non-blank line must consist of exactly six comma-separated
    fields.  Fields 0 and 1 hold the template and variation names; fields
    2 and 4 are converted to floats.

    Args:
        path: Path to the output file of one job.

    Returns:
        Tuple ``(template, variation, errors)``.  ``template`` and
        ``variation`` are read from the first data line, since they are
        expected to be the same for the whole file.  ``errors`` is a NumPy
        array of shape ``(n_data_lines, 2)`` holding fields 2 and 4 of each
        line; the script uses column 0 as the bandwidth and column 1 as the
        CV error.

    Raises:
        RuntimeError: If the file is empty, has no data lines after the
            header, or a data line does not have exactly six fields.
    """
    with open(path) as f:
        lines = f.readlines()

    if not lines:
        raise RuntimeError(f"File {path} is empty.")

    # Drop the CSV header; ignore blank lines so that a stray empty line
    # does not trigger a spurious parse error.
    data_lines = [line for line in lines[1:] if line.strip()]
    if not data_lines:
        raise RuntimeError(f"File {path} contains no data lines.")

    split_lines = []
    for line in data_lines:
        fields = line.rstrip("\r\n").split(",")
        if len(fields) != 6:
            raise RuntimeError(
                f"In file {path}, failed to parse line {line.rstrip()!r}."
            )
        split_lines.append(fields)

    # Template and variation are the same for one file; take them from the
    # first data line.
    template, variation = split_lines[0][0], split_lines[0][1]
    errors = np.asarray(
        [[float(fields[2]), float(fields[4])] for fields in split_lines]
    )
    return template, variation, errors
def main(argv=None):
    """Choose the bandwidth with the smallest CV error for every job output.

    Every regular file in the input directory is parsed with ``parse_file``.
    The row with the smallest CV error (column 1) is selected and its
    bandwidth (column 0) is recorded.  The results are sorted and written
    as CSV to the output file.  When ``--plot`` is given, the CV error is
    also plotted against the bandwidth, one PDF per file, in ``--fig-dir``.

    Args:
        argv: Command-line arguments to parse; ``None`` means
            ``sys.argv[1:]``.
    """
    arg_parser = argparse.ArgumentParser(epilog=__doc__)
    arg_parser.add_argument("inputs", help="Directory with outputs of individual jobs")
    arg_parser.add_argument(
        "-o", "--output", default="bandwidths.csv",
        help="Name for output file with chosen bandwidths"
    )
    arg_parser.add_argument("-p", "--plot", action="store_true", help="Make plots")
    arg_parser.add_argument(
        "--fig-dir", default="fig/CV",
        help="Directory for produced figures"
    )
    args = arg_parser.parse_args(argv)

    # Only touch the figure directory when plots were actually requested.
    if args.plot:
        os.makedirs(args.fig_dir, exist_ok=True)

    # Find optimal bandwidths
    optimal_bandwidths = []
    for file_name in os.listdir(args.inputs):
        path = os.path.join(args.inputs, file_name)
        if not os.path.isfile(path):
            continue

        template, variation, errors = parse_file(path)
        # Bandwidth (column 0) of the row with the smallest CV error (column 1).
        bandwidth = errors[np.argmin(errors[:, 1]), 0]
        optimal_bandwidths.append((template, variation, bandwidth))

        if args.plot:
            fig, ax = plt.subplots()
            ax.plot(errors[:, 0], errors[:, 1])
            ax.set_xlabel("Bandwidth")
            ax.set_ylabel("CV error")
            fig.savefig(os.path.join(args.fig_dir, f"{template}_{variation}.pdf"))
            # Close explicitly so figures do not accumulate across files.
            plt.close(fig)

    # os.listdir order is arbitrary; sort for a reproducible output file.
    optimal_bandwidths.sort()
    with open(args.output, "w") as out_file:
        out_file.write("#Template,Variation,h_\n")
        for template, variation, bandwidth in optimal_bandwidths:
            out_file.write("{},{},{:g}\n".format(template, variation, bandwidth))


if __name__ == "__main__":
    main()