forked from urvashik/knnlm
-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_wiki.py
54 lines (42 loc) · 1.68 KB
/
plot_wiki.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Three-panel scatter figure built from two pre-computed CSV tables:
#   figures/wiki_dist_correctness.csv -> columns dist_right, locality, correctness
#   figures/wiki_rank.csv             -> columns rank, locality, correctness, dist
# The result is written to figures/wiki.pdf.

# One colour per locality level; the list order lines up with `choices` below.
color = sns.color_palette('colorblind', n_colors=4)

# Legend labels, indexed by the integer code stored in the `locality` column.
choices = ['no locality',
           'same category, different section',
           'same section, different category',
           'same section, same category']

# Rows outside these bounds are left out of the plots.
MIN_NEG_DISTANCE = -15
MAX_RANK = 200


def _add_locality_labels(frame):
    """Return a copy of *frame* with a string 'Locality' column.

    The integer code in ``frame['locality']`` selects the label at the same
    index in ``choices``. Codes outside that range are labelled 'unknown';
    such rows are not drawn because every plot pins ``hue_order=choices``.

    Working on a copy keeps the later column assignments from writing into a
    filtered slice of the frame returned by ``read_csv``, which is what makes
    pandas emit SettingWithCopyWarning.
    """
    labelled = frame.copy()
    conditions = [labelled['locality'] == code for code in range(len(choices))]
    # An explicit string default is required: the implicit integer default 0
    # cannot be combined with a string choicelist on newer NumPy releases.
    labelled['Locality'] = np.select(conditions, choices, default='unknown')
    return labelled


# dist - acc
dist_grouped = pd.read_csv('figures/wiki_dist_correctness.csv')
dist_grouped = _add_locality_labels(
    dist_grouped[dist_grouped['dist_right'] >= MIN_NEG_DISTANCE])
dist_grouped['Accuracy'] = dist_grouped['correctness']
dist_grouped['Neg. Distance'] = dist_grouped['dist_right']

fig, ax = plt.subplots(1, 3, figsize=(13, 4))

# hue_order is fixed on every panel so that a given locality always receives
# the same palette colour; only the last panel carries the (shared) legend.
sns.scatterplot(x='Neg. Distance', y='Accuracy', hue='Locality',
                hue_order=choices, data=dist_grouped, s=11,
                palette=color, ax=ax[0], legend=False)

grouped = pd.read_csv('figures/wiki_rank.csv')
grouped = _add_locality_labels(grouped.loc[grouped['rank'] <= MAX_RANK])
grouped['Rank'] = grouped['rank']
grouped['Accuracy'] = grouped['correctness']
grouped['Neg. Distance'] = grouped['dist']

# rank - acc
sns.scatterplot(x='Rank', y='Accuracy', hue='Locality',
                hue_order=choices, data=grouped, s=8,
                palette=color, ax=ax[1], legend=False)

# rank - dist
sns.scatterplot(x='Rank', y='Neg. Distance', hue='Locality',
                hue_order=choices, data=grouped, s=8,
                palette=color, ax=ax[2])

fig.tight_layout()
fig.savefig('figures/wiki.pdf')