-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathoptimize_n_hours.py
34 lines (24 loc) · 944 Bytes
/
optimize_n_hours.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
from mlgw_bns import *
if __name__ == "__main__":
    # Parse the command line before touching the dataset, so that `--help`
    # or a malformed invocation exits immediately instead of only after the
    # (potentially slow) dataset loading/generation below.
    parser = argparse.ArgumentParser(description="Optimize the hyperparameters")
    # Hours to pass to `HyperparameterOptimization.optimize_and_save`.
    parser.add_argument("hours", metavar="h", type=float)
    # Optional integer forwarded as the third positional argument of
    # `Model.generate` before optimizing; omitted -> no extra generation.
    parser.add_argument("-g", "--generate", metavar="gen", default=None, type=int)
    args = parser.parse_args()

    m = Model("optimization_dataset")
    # Keep the try body minimal: only a missing saved dataset should trigger
    # regeneration; errors from the constructor itself must propagate.
    try:
        m.load()
    except FileNotFoundError:
        # No saved dataset yet: generate one and persist it for later runs.
        # NOTE(review): the meaning of the positional arguments to
        # `Model.generate` is defined in mlgw_bns — confirm there.
        m.generate(512, 1 << 15, None)
        m.save()

    ho = HyperparameterOptimization(m)
    n_hours_before = ho.total_training_time().total_seconds() / 3600
    # `.2f` (two decimals); the original `2f` meant "width 2, 6 decimals".
    print(f"Optimized for {n_hours_before:.2f} hours so far")

    # `is not None` so that an explicit `-g 0` is still honoured.
    if args.generate is not None:
        m.generate(None, None, args.generate)
        m.save()

    ho.optimize_and_save(args.hours)
    n_hours_after = ho.total_training_time().total_seconds() / 3600
    print(f"Optimized for {n_hours_after - n_hours_before:.2f} more hours")