-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
138 lines (111 loc) · 4.75 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import params
import random
import pickle
import matplotlib.pyplot as plt
import numpy as np
from stateless import functional_call
from sklearn.metrics import accuracy_score
# -------------------------------------------------------------------
# Device on which zeros() / ones() allocate their tensors.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# -------------------------------------------------------------------


def zeros(nb):
    """Return a ``torch.long`` tensor of zeros with shape ``nb`` on ``DEVICE``.

    Allocates directly with the target dtype and device. Previously a float
    CPU tensor was built, copied to ``DEVICE`` and then cast, which made
    three tensors for the same result.

    Args:
        nb: size accepted by ``torch.zeros`` (an int or a shape tuple).
    """
    return torch.zeros(nb, dtype=torch.long, device=DEVICE)
# -------------------------------------------------------------------


def ones(nb):
    """Return a ``torch.long`` tensor of ones with shape ``nb`` on ``DEVICE``.

    Allocates directly with the target dtype and device instead of
    allocate-float, copy, cast.

    Args:
        nb: size accepted by ``torch.ones`` (an int or a shape tuple).
    """
    return torch.ones(nb, dtype=torch.long, device=DEVICE)
# -------------------------------------------------------------------
def write_in_file(file, file_directory):
    """Pickle the object ``file`` to the path ``file_directory``.

    The target is opened in binary write mode, so an existing file at that
    path is overwritten. A context manager guarantees the handle is closed
    even if pickling fails part-way; previously it leaked on error.

    Args:
        file: any picklable object (the name is kept for caller compatibility).
        file_directory: destination file path.
    """
    with open(file_directory, "wb") as a_file:
        pickle.dump(file, a_file)
# -------------------------------------------------------------------
def convert_into_int(array_name):
    """Relabel the distinct values of a NumPy array as consecutive integers.

    The sorted unique values of ``array_name`` map to ``0, 1, 2, ...``
    (smallest value -> 0). The codes are also written back into
    ``array_name`` in place (cast to its dtype), as the previous version did,
    and an ``int`` array with the same shape is returned.

    The old implementation relabelled one class at a time in place. A newly
    assigned label could therefore equal a value not yet processed: for
    example ``[-1, 0, 1]`` became ``[2, 2, 2]``. String arrays with a short
    itemsize also truncated labels >= 10. Computing every code at once from
    ``np.unique(..., return_inverse=True)`` avoids both problems.

    Args:
        array_name: ``np.ndarray`` of labels with a sortable dtype.

    Returns:
        ``np.ndarray`` of ints, same shape as ``array_name``.
    """
    _, inverse = np.unique(array_name, return_inverse=True)
    # NumPy 1.x returns a flat inverse; 2.x may return the input's shape.
    # Reshape so both agree with the input.
    codes = inverse.reshape(array_name.shape)
    # Keep the original in-place relabelling for callers that rely on it.
    # The returned value is computed from `codes`, so it is unaffected by
    # dtype truncation on write-back.
    array_name[...] = codes.astype(array_name.dtype)
    return codes.astype('int')
# -------------------------------------------------------------------
class LambdaLayer(torch.nn.Module):
    """Adapt an arbitrary callable into a ``torch.nn.Module``.

    The wrapped callable is applied to the input on every forward pass. This
    lets plain functions be placed anywhere a module is expected.
    """

    def __init__(self, lambd):
        """Remember ``lambd`` (stored on ``self.lambd``) for use in ``forward``."""
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return the result of calling the wrapped callable on ``x``."""
        fn = self.lambd
        return fn(x)
# -------------------------------------------------------------------
def accuracy(pred, y_true):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        pred: score tensor; the predicted class is ``argmax`` over dim 1
            (presumably ``(batch, num_classes)`` logits; confirm with callers).
        y_true: label tensor; flattened before comparison.

    Returns:
        ``sklearn.metrics.accuracy_score`` evaluated on CPU copies of the
        flattened labels and predictions.
    """
    y_pred = pred.argmax(1).reshape(-1).cpu()
    y_true = y_true.reshape(-1).cpu()
    # sklearn's signature is accuracy_score(y_true, y_pred). Plain accuracy is
    # symmetric, so the previously swapped order gave the same number, but the
    # documented order keeps this correct if the metric is ever changed.
    return accuracy_score(y_true, y_pred)
# -------------------------------------------------------------------
def init_data_match_dict(keys, vals, variation):
    """Build a dict of randomly filled placeholder tensors, one entry per key.

    Args:
        keys: iterable of keys for the returned dict.
        vals: size of the first dimension of each placeholder tensor.
        variation: kept for backward compatibility. Both branches of the
            previous ``if variation`` used ``vals`` unchanged, so it has no
            effect.

    Returns:
        ``{key: {}}`` for every key. When ``params.dataset_name`` is
        ``'RotatedMNIST'``, each inner dict also holds ``'data'`` with shape
        ``(vals, params.img_c, params.img_w, params.img_h)`` and ``'label'``
        with shape ``(vals, 1)``, both drawn from ``torch.rand``.
    """
    # The dataset check does not depend on the key, so evaluate it once.
    is_rotated_mnist = params.dataset_name in ['RotatedMNIST']
    data = {}
    for key in keys:
        data[key] = {}
        if is_rotated_mnist:
            data[key]['data'] = torch.rand((vals, params.img_c, params.img_w, params.img_h))
            data[key]['label'] = torch.rand((vals, 1))
    return data
# -------------------------------------------------------------------
def embedding_dist(x1, x2, tau=0.05, xent=False):
    """Compare embeddings via cosine similarity.

    With ``xent=False``, return ``1 - cosine_similarity(x1, x2)`` computed
    along dim 1 (``eps=1e-8``).

    With ``xent=True``, treat each row of ``x1`` as an anchor and every row of
    ``x2`` as a negative match. For each anchor ``i``, return
    ``sum_j exp(cos(x1[i], x2[j]) / tau)``, the denominator term of an
    NT-Xent loss.

    Args:
        x1: anchors; must be rank 2 ``(B, D)`` when ``xent`` is True.
        x2: negatives; must be rank 2 ``(N, D)`` when ``xent`` is True.
        tau: temperature dividing the cosine similarities (``xent`` only).
        xent: select the NT-Xent denominator computation.

    Returns:
        A ``(B,)`` tensor when ``xent`` is True; otherwise ``1 -`` the cosine
        similarity along dim 1.

    Raises:
        ValueError: if ``xent`` is True and either input is not rank 2.
            Previously this only printed a message and then failed further
            down (or broadcast silently).
    """
    if not xent:
        return 1.0 - torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-08)

    # Only the rank-2 case is implemented.
    if x1.dim() != 2 or x2.dim() != 2:
        raise ValueError(
            f"embedding_dist(xent=True) requires rank-2 tensors, "
            f"got shapes {tuple(x1.shape)} and {tuple(x2.shape)}"
        )

    # L2-normalise each row. This divides by max(norm, eps), so all-zero rows
    # stay zero instead of producing NaN.
    eps = 1e-8
    x1 = torch.nn.functional.normalize(x1, p=2, dim=1, eps=eps)
    x2 = torch.nn.functional.normalize(x2, p=2, dim=1, eps=eps)

    # (B, D) @ (D, N) -> (B, N) pairwise cosine similarities of every anchor
    # against every negative, without materialising a (B, N, D) broadcast.
    cos_sim = (x1 @ x2.t()) / tau
    return torch.exp(cos_sim).sum(dim=1)
# -------------------------------------------------------------------
def _plot_acc_curves(test_steps, acc_umaDANN, acc_umaMMD, acc_scratchDANN, acc_scratchMMD, title, save_file):
    """Draw the four method curves on the current pyplot figure and optionally save it.

    Shared implementation of plot_avg_acc / plot_worst_acc, which differ only
    in the title. Each accuracy sequence is plotted against
    ``range(test_steps)``, so it must have ``test_steps`` entries. When
    ``save_file`` is truthy the figure is saved there and then closed.
    NOTE(review): with a falsy ``save_file`` the figure is left open; confirm
    this is intended (e.g. for interactive display).
    """
    steps = range(test_steps)
    curves = (
        (acc_umaDANN, 'tab:blue', "UMA-DANN"),
        (acc_umaMMD, 'tab:green', "UMA-MMD"),
        (acc_scratchDANN, 'tab:orange', "Scratch-DANN"),
        (acc_scratchMMD, 'tab:purple', "Scratch-MMD"),
    )
    for values, color, label in curves:
        plt.plot(steps, values, c=color, label=label)
    plt.xlabel("Steps")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.title(title, fontsize=12)
    if save_file:
        plt.savefig(save_file)
        plt.close()


def plot_avg_acc(test_steps, avg_acc_umaDANN, avg_acc_umaMMD, avg_acc_scratchDANN, avg_acc_scratchMMD, save_file):
    """Plot per-step average accuracy of UMA-DANN, UMA-MMD, Scratch-DANN and Scratch-MMD.

    Args:
        test_steps: number of steps; x-axis is ``range(test_steps)``.
        avg_acc_*: per-step accuracy sequences, one per method.
        save_file: path to save the figure to; falsy to skip saving.
    """
    _plot_acc_curves(test_steps, avg_acc_umaDANN, avg_acc_umaMMD,
                     avg_acc_scratchDANN, avg_acc_scratchMMD,
                     "Average accuracy", save_file)
# -------------------------------------------------------------------


def plot_worst_acc(test_steps, worst_acc_umaDANN, worst_acc_umaMMD, worst_acc_scratchDANN, worst_acc_scratchMMD, save_file):
    """Plot per-step worst accuracy of UMA-DANN, UMA-MMD, Scratch-DANN and Scratch-MMD.

    Args:
        test_steps: number of steps; x-axis is ``range(test_steps)``.
        worst_acc_*: per-step accuracy sequences, one per method.
        save_file: path to save the figure to; falsy to skip saving.
    """
    _plot_acc_curves(test_steps, worst_acc_umaDANN, worst_acc_umaMMD,
                     worst_acc_scratchDANN, worst_acc_scratchMMD,
                     "Worst accuracy", save_file)