-
Notifications
You must be signed in to change notification settings - Fork 0
/
storage.py
33 lines (24 loc) · 982 Bytes
/
storage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import csv
import numpy as np
def save_statistics(experiment_name, line_to_add):
    """Append one row to the CSV file ``<experiment_name>.csv``.

    The file is created if it does not exist. Nothing is written
    automatically as a header, so the caller is expected to append the
    header row first.

    Args:
        experiment_name: Path of the statistics file without the ``.csv``
            suffix.
        line_to_add: Iterable of field values for the new row. Each value is
            converted to text by ``csv.writer``. Fields containing commas,
            quotes or newlines are quoted.
    """
    # newline='' is what the csv module requires. The writer emits its own
    # '\r\n' row terminator, and without this option text-mode newline
    # translation (e.g. on Windows) rewrites it to '\r\r\n'. That yields
    # blank rows when the file is read back.
    with open("{}.csv".format(experiment_name), 'a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(line_to_add)
def load_statistics(experiment_name):
    """Load ``<experiment_name>.csv`` into a dict of per-column lists.

    The first row is treated as the header. Every following row contributes
    one value to each column. Parsing uses ``csv.reader``, so quoted fields
    (as produced by ``csv.writer``, e.g. values containing commas) round-trip
    correctly. Values are returned as strings exactly as stored; no numeric
    conversion is performed.

    Args:
        experiment_name: Path of the statistics file without the ``.csv``
            suffix.

    Returns:
        dict mapping each header label to the list of that column's values in
        file order. An empty file yields an empty dict. Blank lines are
        skipped.

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    # newline='' lets csv.reader handle '\r\n' terminators and embedded
    # newlines inside quoted fields itself, as the csv docs require.
    with open("{}.csv".format(experiment_name), 'r', newline='') as f:
        reader = csv.reader(f)
        data_labels = next(reader, None)
        if data_labels is None:
            # Empty file: no header, hence no columns.
            return {}
        data_dict = {label: [] for label in data_labels}
        for row in reader:
            # zip() truncates. A short row fills only the leading columns,
            # extra fields beyond the header are dropped, and an empty row
            # (blank line) adds nothing.
            for key, item in zip(data_labels, row):
                data_dict[key].append(item)
    return data_dict
def save_preds(epoch_idx, preds, labels):
    """Dump predictions and labels for one epoch to two text files.

    Writes ``preds_<epoch_idx>.txt`` and then ``labels_<epoch_idx>.txt`` in
    the current working directory, overwriting any existing files. Both are
    written with ``np.savetxt``: one row per line, space-separated, every
    value formatted with ``'%d'``.

    Args:
        epoch_idx: Value interpolated into both file names.
        preds: Array-like accepted by ``np.savetxt`` (1-D or 2-D).
        labels: Array-like accepted by ``np.savetxt`` (1-D or 2-D).
    """
    # NOTE(review): '%d' presumes integer values (e.g. class indices);
    # confirm callers never pass fractional scores here.
    targets = (
        ("preds_{}.txt", preds),
        ("labels_{}.txt", labels),
    )
    for name_template, values in targets:
        out_path = name_template.format(epoch_idx)
        np.savetxt(out_path, values, fmt='%d', delimiter=' ', newline='\n')