-
Notifications
You must be signed in to change notification settings - Fork 0
/
attention_entropy_visualisation_utils.py
45 lines (35 loc) · 1.82 KB
/
attention_entropy_visualisation_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from typing import List
import numpy as np
import matplotlib.pyplot as plt
def draw_entropy_head_plot(axis: plt.Axes, neighbourhood_entropy: np.ndarray, uniform_entropy: np.ndarray, title: str,
                           num_bins: int = 30):
    """Overlay histograms of attention entropies and uniform-distribution entropies on one axes.

    Args:
        axis: Matplotlib axes to draw on.
        neighbourhood_entropy: Entropy values of the attention distributions, passed straight to
            ``Axes.hist`` (presumably one value per neighbourhood — confirm against caller).
        uniform_entropy: Entropy values of the corresponding uniform distributions, drawn as the
            reference histogram.
        title: Title set on ``axis``.
        num_bins: Number of histogram bins used for both series (defaults to 30, the previous
            hard-coded value).
    """
    cmap = plt.cm.PuBuGn
    color_attention = cmap(0.7)  # RGBA tuple, darker shade for the attention series
    color_uniform = cmap(0.4)    # RGBA tuple, lighter shade for the uniform reference

    # Uniform reference is drawn first with full-width bars; the attention histogram is drawn on
    # top with narrower bars (rwidth=0.7) so both remain visible.
    # Labels are attached to the artists directly: a label-only ``legend([...])`` call would bind
    # the labels to the first child artists (individual bars of the first histogram), giving both
    # legend entries the same colour.
    axis.hist(uniform_entropy, num_bins, color=color_uniform, alpha=0.7, rwidth=1.0,
              label='uniform distribution')
    axis.hist(neighbourhood_entropy, num_bins, color=color_attention, alpha=0.7, rwidth=0.7,
              label='attention distribution')

    axis.set_xlabel('entropy bin')
    axis.set_ylabel('# of neighborhoods')
    axis.legend()
    axis.set_title(title)
def draw_entropy_heads_plot(neighbourhood_entropy_per_head: List[np.ndarray],
                            uniform_entropy_per_head: List[np.ndarray], layer: int, subplots: List[int]):
    """Draw one entropy-histogram subplot per attention head of a layer and return the figure.

    Args:
        neighbourhood_entropy_per_head: Per-head entropy arrays of the attention distributions;
            element ``i`` is drawn in the ``i``-th grid cell (row-major order).
        uniform_entropy_per_head: Per-head entropy arrays of the uniform reference distributions,
            indexed in lockstep with ``neighbourhood_entropy_per_head``.
        layer: Layer index, used only in the subplot titles and the figure title.
        subplots: ``(rows, cols)`` shape of the subplot grid.

    Returns:
        The created ``matplotlib.figure.Figure``.

    Grid cells beyond the number of heads in ``neighbourhood_entropy_per_head`` are hidden;
    heads beyond ``rows * cols`` are not drawn.
    """
    rows, cols = subplots
    # squeeze=False keeps ``axs`` a 2-D array for every grid shape (1x1, 1xN, Nx1, NxM).
    # With the default squeezing, a 1x1 grid returns a bare Axes and an Nx1 grid returns a
    # 1-D array, both of which broke the previous indexing.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    num_heads = len(neighbourhood_entropy_per_head)

    # ``axs.flat`` walks the grid row by row, matching the original row-outer/col-inner order.
    for current_head, axis in enumerate(axs.flat):
        if current_head >= num_heads:
            # More grid cells than heads: hide the empty axes instead of raising IndexError.
            axis.set_visible(False)
            continue
        draw_entropy_head_plot(axis, neighbourhood_entropy_per_head[current_head],
                               uniform_entropy_per_head[current_head],
                               f'attention head={current_head}, layer={layer}')

    fig.suptitle(f'attention distribution entropy in layer={layer}')
    fig.subplots_adjust(top=0.9)  # leave room for the suptitle above the subplot grid
    fig.set_size_inches(19.5, 5.75)
    return fig