# evaluate_DC.py
# Forked from ShuoYang-1998/Few_Shot_Distribution_Calibration.
import pickle
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
# Whether PyTorch reports an available CUDA device.
# NOTE(review): this flag is not referenced anywhere else in the visible code.
use_gpu = torch.cuda.is_available()
def distribution_calibration(query, base_means, base_cov, k, alpha=0.21):
    """Calibrate a Gaussian for one feature vector from its nearest base classes.

    The ``k`` base classes whose means are closest (Euclidean distance) to
    ``query`` are selected. The calibrated mean is the average of those ``k``
    base means together with ``query`` itself; the calibrated covariance is
    the average of the ``k`` selected base covariances with ``alpha`` added
    to every entry.

    Args:
        query: 1-D NumPy feature vector.
        base_means: Sequence (or 2-D array) of per-class mean vectors, each
            the same length as ``query``.
        base_cov: Sequence (or 3-D array) of per-class covariance matrices,
            aligned index-for-index with ``base_means``.
        k: Number of nearest base classes to use. Values larger than the
            number of base classes are clamped so that every class is used.
        alpha: Constant added element-wise to the averaged covariance.

    Returns:
        Tuple ``(calibrated_mean, calibrated_cov)``.

    Raises:
        ValueError: If ``base_means`` is empty or ``k`` is less than 1.
    """
    means = np.asarray(base_means)
    n_base = len(means)
    if n_base == 0:
        raise ValueError("base_means must contain at least one class")
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))

    # One vectorized norm instead of a Python-level loop over base classes.
    dist = np.linalg.norm(means - query, axis=1)

    if k >= n_base:
        # np.argpartition rejects kth == len(dist); here every base class
        # is a neighbour anyway.
        index = np.arange(n_base)
    else:
        index = np.argpartition(dist, k)[:k]

    neighbours = np.concatenate([means[index], query[np.newaxis, :]])
    calibrated_mean = np.mean(neighbours, axis=0)

    # Gather only the selected covariances; converting the whole base_cov
    # collection to an array on every call copies every matrix needlessly.
    selected_cov = np.array([base_cov[j] for j in index])
    calibrated_cov = np.mean(selected_cov, axis=0) + alpha
    return calibrated_mean, calibrated_cov
if __name__ == '__main__':
    # Evaluate few-shot classification accuracy: for every sampled task,
    # augment the support set with features drawn from calibrated Gaussians,
    # fit a logistic-regression classifier and score it on the query set.

    # ---- data loading
    dataset = 'miniImagenet'
    n_shot = 1
    n_ways = 5
    n_queries = 15
    n_runs = 10000
    n_lsamples = n_ways * n_shot      # labelled (support) samples per task
    n_usamples = n_ways * n_queries   # unlabelled (query) samples per task
    n_samples = n_lsamples + n_usamples

    import FSLTask  # project-local task sampler
    cfg = {'shot': n_shot, 'ways': n_ways, 'queries': n_queries}
    FSLTask.loadDataSet(dataset)
    FSLTask.setRandomStates(cfg)
    ndatas = FSLTask.GenerateRunSet(end=n_runs, cfg=cfg)
    # NOTE(review): assumes GenerateRunSet returns a 4-D tensor whose axes 1
    # and 2 are (way, sample) -- confirm against FSLTask.
    ndatas = ndatas.permute(0, 2, 1, 3).reshape(n_runs, n_samples, -1)
    # Labels cycle 0..n_ways-1 along each task. The last expand dimension
    # must be n_ways (it was hard-coded to 5, which breaks other settings).
    labels = (
        torch.arange(n_ways)
        .view(1, 1, n_ways)
        .expand(n_runs, n_shot + n_queries, n_ways)
        .clone()
        .view(n_runs, n_samples)
    )

    # ---- Base class statistics
    base_means = []
    base_cov = []
    base_features_path = "./checkpoints/%s/WideResNet28_10_S2M2_R/last/base_features.plk"%dataset
    with open(base_features_path, 'rb') as f:
        # NOTE: pickle.load can execute arbitrary code; only load feature
        # files that come from a trusted source.
        data = pickle.load(f)
    for class_features in data.values():
        feature = np.array(class_features)
        base_means.append(np.mean(feature, axis=0))
        base_cov.append(np.cov(feature.T))

    # ---- classification for each task
    acc_list = []
    beta = 0.5                      # exponent of Tukey's power transform
    num_sampled = int(750 / n_shot)  # features drawn per support sample
    print('Start classification for %d tasks...'%(n_runs))
    for run in tqdm(range(n_runs)):
        support_data = ndatas[run][:n_lsamples].numpy()
        support_label = labels[run][:n_lsamples].numpy()
        query_data = ndatas[run][n_lsamples:].numpy()
        query_label = labels[run][n_lsamples:].numpy()

        # ---- Tukey's transform
        # NOTE(review): a fractional power yields NaN for negative inputs;
        # presumably the features are non-negative -- confirm.
        support_data = np.power(support_data, beta)
        query_data = np.power(query_data, beta)

        # ---- distribution calibration and feature sampling
        sampled_data = []
        for support_feature in support_data:
            mean, cov = distribution_calibration(support_feature, base_means, base_cov, k=2)
            sampled_data.append(np.random.multivariate_normal(mean=mean, cov=cov, size=num_sampled))
        # Each entry is (num_sampled, dim); concatenating along axis 0 yields
        # the (n_lsamples * num_sampled, dim) matrix directly.
        sampled_data = np.concatenate(sampled_data)
        # Every support label repeated once per sample drawn for it.
        sampled_label = np.repeat(support_label, num_sampled)
        X_aug = np.concatenate([support_data, sampled_data])
        Y_aug = np.concatenate([support_label, sampled_label])

        # ---- train classifier
        classifier = LogisticRegression(max_iter=1000).fit(X=X_aug, y=Y_aug)
        predicts = classifier.predict(query_data)
        acc = np.mean(predicts == query_label)
        acc_list.append(acc)
    print('%s %d way %d shot ACC : %f'%(dataset,n_ways,n_shot,float(np.mean(acc_list))))