-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot-figure-1.py
83 lines (67 loc) · 2.25 KB
/
plot-figure-1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import glob
import altair as alt
import pandas as pd
from src.data.faces import Faces
from src.data.jester import Jester
from src.data.movielens import MovieLens
from src.user.properties import get_user_similarity
def get_user_level_evaluation(datasets):
    """Collect per-user test metrics joined with per-user rating properties.

    For every entry in *datasets*, test results are read from all files
    matching ``outputs/data={name}*/test.parquet``, the numeric columns are
    averaged per (user, model, dataset) across those files, and the result is
    merged on ``user`` with the frame returned by ``get_user_similarity`` for
    the loaded rating data. Datasets without any result file are reported on
    stdout and skipped.

    Args:
        datasets: Mapping from dataset name (used in the glob pattern) to an
            object whose ``load()`` returns the rating data handed to
            ``get_user_similarity``.

    Returns:
        One DataFrame with the per-dataset results concatenated.

    Raises:
        ValueError: If no test results were found for any dataset.
    """
    test_dfs = []
    for name, dataset in datasets.items():
        files = glob.glob(f"outputs/data={name}*/test.parquet")
        if not files:
            print(f"No test results found for dataset: {name}")
            continue

        # Load rating dataset and compute taste similarity and dispersion
        df = dataset.load()
        user_df = get_user_similarity(df)

        # Load test results, average over multiple repetitions
        test_df = pd.concat([pd.read_parquet(f) for f in files])
        test_df = (
            test_df.groupby(["user", "model", "dataset"])
            .mean(numeric_only=True)
            .reset_index()
        )
        test_df = test_df.merge(user_df, on="user")
        test_dfs.append(test_df)

    if not test_dfs:
        # pd.concat([]) would raise an opaque "No objects to concatenate";
        # keep the exception type but say what is actually missing.
        raise ValueError(
            f"No test results found for any of the datasets: {list(datasets)}"
        )
    return pd.concat(test_dfs)
def plot(output_path="figures/figure-1.html"):
    """Render figure 1 and save it as an interactive HTML chart.

    Builds a grid of scatter plots (one column per dataset, one row per
    model) of taste dispersion against mean taste similarity, with points
    coloured by FCP, from the frame returned by
    ``get_user_level_evaluation``.

    Args:
        output_path: Destination file for the chart. Its parent directory is
            created if it does not exist yet.
    """
    # Local import keeps this block self-contained.
    from pathlib import Path

    datasets = {
        "faces": Faces(),
        "jester": Jester(),
        "movielens": MovieLens(),
    }
    df = get_user_level_evaluation(datasets)

    # Drop outliers for plotting
    df = df[df.taste_dispersion > 0.05]

    # Plot figure 1
    chart = alt.Chart(df, width=200, height=200).mark_circle().encode(
        column=alt.Column("dataset", header=alt.Header(title="", labelFontSize=12)),
        row=alt.Row("model", header=alt.Header(title="", labelFontSize=12)),
        x=alt.X(
            "mean_taste_similarity",
            title="Mean taste similarity",
        ),
        y=alt.Y(
            "taste_dispersion",
            title="Taste dispersion",
            scale=alt.Scale(zero=False),
        ),
        color=alt.Color(
            "fcp",
            title="FCP",
            scale=alt.Scale(
                domain=[0.4, 1.0],
                scheme=alt.SchemeParams("inferno", extent=(0.25, 1.0)),
            ),
        ),
        tooltip=["user", "fcp", "nDCG", "rmse"],
    ).resolve_scale(
        y="independent",
    ).configure_axis(
        titleFontSize=12,
        titleFontWeight="normal",
    ).interactive()

    # Make sure the target directory exists so that saving does not fail on
    # a fresh checkout where it has not been created yet.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    chart.save(str(output_path))
# Script entry point: build and save figure 1 only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    plot()