-
Notifications
You must be signed in to change notification settings - Fork 13
/
CKA.py
92 lines (73 loc) · 2.73 KB
/
CKA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# inspired by
# https://github.com/yuanli2333/CKA-Centered-Kernel-Alignment/blob/master/CKA.py
import math
import torch
import numpy as np
class CKA(object):
    """Centered Kernel Alignment (CKA) between two representations, using NumPy.

    ``X`` and ``Y`` are 2-D arrays with one row per example. Both must have the
    same number of rows (their n x n Gram matrices are compared element-wise),
    but their column counts may differ.
    """

    def __init__(self):
        pass

    def centering(self, K):
        """Return the double-centered matrix ``H @ K @ H`` with ``H = I - ones / n``.

        Computed as ``K - column means - row means + grand mean``. This is
        algebraically identical to the explicit ``H K H`` product, but it runs in
        O(n^2) time and allocates no extra n x n matrices; the explicit form
        needs two O(n^3) matrix products. ``K`` is expected to be square.
        Integer and lower-precision float input is promoted to float64.
        """
        K = np.asarray(K)
        K = K.astype(np.promote_types(K.dtype, np.float64), copy=False)
        return (
            K
            - K.mean(axis=0, keepdims=True)
            - K.mean(axis=1, keepdims=True)
            + K.mean()
        )

    def rbf(self, X, sigma=None):
        """Gaussian (RBF) kernel matrix over the rows of ``X``.

        ``K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2))``.

        Args:
            X: 2-D array, one example per row.
            sigma: kernel bandwidth. If None, the median heuristic is used:
                ``sigma^2`` is the median of the nonzero squared pairwise
                distances. If every distance is zero (all rows identical), the
                kernel is all ones for any nonzero bandwidth, so a unit
                bandwidth is used instead of the median of an empty set.

        Returns:
            The n x n kernel matrix.
        """
        GX = np.dot(X, X.T)
        # Squared Euclidean distances: ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
        KX = np.diag(GX) - GX + (np.diag(GX) - GX).T
        if sigma is None:
            nonzero = KX[KX != 0]
            sigma = math.sqrt(np.median(nonzero)) if nonzero.size else 1.0
        # Out-of-place scaling, so integer distances are promoted to float
        # instead of failing an in-place float -> int cast.
        return np.exp(KX * (-0.5 / (sigma * sigma)))

    def kernel_HSIC(self, X, Y, sigma):
        """Unnormalized HSIC between the RBF kernels of ``X`` and ``Y``.

        The usual ``1 / (n - 1)^2`` factor is omitted; it cancels in CKA.
        """
        return np.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)))

    def linear_HSIC(self, X, Y):
        """Unnormalized HSIC between the linear kernels ``X X^T`` and ``Y Y^T``.

        The usual ``1 / (n - 1)^2`` factor is omitted; it cancels in CKA.
        """
        L_X = X @ X.T
        L_Y = Y @ Y.T
        return np.sum(self.centering(L_X) * self.centering(L_Y))

    def linear_CKA(self, X, Y):
        """Linear CKA: ``HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y))``.

        The result is NaN (0 / 0) if either centered Gram matrix is all zeros,
        for example when every row of that input is identical.
        """
        hsic = self.linear_HSIC(X, Y)
        var1 = np.sqrt(self.linear_HSIC(X, X))
        var2 = np.sqrt(self.linear_HSIC(Y, Y))
        return hsic / (var1 * var2)

    def kernel_CKA(self, X, Y, sigma=None):
        """RBF-kernel CKA between ``X`` and ``Y``.

        With ``sigma=None``, each input's kernel gets its own median-heuristic
        bandwidth (see ``rbf``).
        """
        hsic = self.kernel_HSIC(X, Y, sigma)
        var1 = np.sqrt(self.kernel_HSIC(X, X, sigma))
        var2 = np.sqrt(self.kernel_HSIC(Y, Y, sigma))
        return hsic / (var1 * var2)
class CudaCKA(object):
    """Centered Kernel Alignment (CKA) between two representations, using PyTorch.

    Mirrors ``CKA`` for torch tensors. ``X`` and ``Y`` are 2-D tensors with one
    row per example. Both must have the same number of rows, but their column
    counts may differ.
    """

    def __init__(self, device):
        # Kept for backward compatibility. Computations run on the dtype and
        # device of the tensors passed in, so this attribute is not read here.
        self.device = device

    def centering(self, K):
        """Return the double-centered matrix ``H @ K @ H`` with ``H = I - ones / n``.

        Computed as ``K - column means - row means + grand mean``. This is
        algebraically identical to ``H K H`` but runs in O(n^2) time. It also
        keeps ``K``'s own dtype and device, so float64 input works; a
        separately built ``H`` would be float32, and ``torch.matmul`` does not
        promote mismatched dtypes. Integer ``K`` is cast to the default float
        dtype. ``K`` is expected to be square.
        """
        if not (K.is_floating_point() or K.is_complex()):
            K = K.to(torch.get_default_dtype())
        return (
            K
            - K.mean(dim=0, keepdim=True)
            - K.mean(dim=1, keepdim=True)
            + K.mean()
        )

    def rbf(self, X, sigma=None):
        """Gaussian (RBF) kernel matrix over the rows of ``X``.

        ``K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2))``.

        Args:
            X: 2-D tensor, one example per row.
            sigma: kernel bandwidth. If None, the median heuristic is used:
                ``sigma^2`` is the median of the nonzero squared pairwise
                distances. For an even count, the median is the mean of the two
                middle values, as in ``np.median``. If every distance is zero,
                the kernel is all ones for any nonzero bandwidth, so a unit
                bandwidth is used.

        Returns:
            The n x n kernel matrix.
        """
        GX = torch.matmul(X, X.T)
        # Squared Euclidean distances: ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
        KX = torch.diag(GX) - GX + (torch.diag(GX) - GX).T
        if sigma is None:
            nonzero = torch.sort(KX[KX != 0]).values
            count = nonzero.numel()
            if count == 0:
                sigma = 1.0
            else:
                # torch.median returns the lower middle value for an even
                # count; average the two middle values to match np.median.
                mdist = (nonzero[(count - 1) // 2] + nonzero[count // 2]) / 2
                sigma = math.sqrt(float(mdist))
        # Out-of-place scaling, so integer distances are promoted to float
        # instead of failing an in-place float -> int cast.
        return torch.exp(KX * (-0.5 / (sigma * sigma)))

    def kernel_HSIC(self, X, Y, sigma):
        """Unnormalized HSIC between the RBF kernels of ``X`` and ``Y``.

        The usual ``1 / (n - 1)^2`` factor is omitted; it cancels in CKA.
        """
        return torch.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)))

    def linear_HSIC(self, X, Y):
        """Unnormalized HSIC between the linear kernels ``X X^T`` and ``Y Y^T``.

        The usual ``1 / (n - 1)^2`` factor is omitted; it cancels in CKA.
        """
        L_X = torch.matmul(X, X.T)
        L_Y = torch.matmul(Y, Y.T)
        return torch.sum(self.centering(L_X) * self.centering(L_Y))

    def linear_CKA(self, X, Y):
        """Linear CKA: ``HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y))``.

        The result is NaN (0 / 0) if either centered Gram matrix is all zeros,
        for example when every row of that input is identical.
        """
        hsic = self.linear_HSIC(X, Y)
        var1 = torch.sqrt(self.linear_HSIC(X, X))
        var2 = torch.sqrt(self.linear_HSIC(Y, Y))
        return hsic / (var1 * var2)

    def kernel_CKA(self, X, Y, sigma=None):
        """RBF-kernel CKA between ``X`` and ``Y``.

        With ``sigma=None``, each input's kernel gets its own median-heuristic
        bandwidth (see ``rbf``).
        """
        hsic = self.kernel_HSIC(X, Y, sigma)
        var1 = torch.sqrt(self.kernel_HSIC(X, X, sigma))
        var2 = torch.sqrt(self.kernel_HSIC(Y, Y, sigma))
        return hsic / (var1 * var2)