-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcreate_plots.py
74 lines (57 loc) · 1.83 KB
/
create_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import pandas as pd
import argparse
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
# Order of implementations along the hue dimension of the bar plots.
# 'perceiver-pytorch' is the baseline that relative improvements are
# computed against.
hue_model_order = ['perceiver-pytorch', 'flash-perceiver']

# Apply seaborn's default theme to all figures created afterwards.
sns.set_theme()
def calc_relative_improvement(s, col, ref_impl='perceiver-pytorch'):
    """Express ``col`` of every row as an improvement over a reference row.

    The ratio is ``reference / row``. For lower-is-better metrics such as
    time per iteration or peak memory, a value above 1 means the row beats
    the reference implementation.

    Args:
        s: Rows of one comparison group. Must contain an 'implementation'
            column and ``col``.
        col: Name of the column holding the metric to compare.
        ref_impl: 'implementation' value of the reference row. If several
            rows match, the first one (in ``s`` order) is used.

    Returns:
        A Series aligned with ``s.index`` holding the ratios.

    Raises:
        ValueError: If ``s`` contains no row for ``ref_impl``.
    """
    ref_rows = s.loc[s['implementation'] == ref_impl, col]
    if ref_rows.empty:
        # Without this check, .iloc[0] raises a bare IndexError that does
        # not say which reference or metric was missing.
        raise ValueError(
            f'no {ref_impl!r} row to use as reference for column {col!r}'
        )
    return ref_rows.iloc[0] / s[col]
def create_plot(df, y_col):
    """Draw a grouped bar chart of ``y_col`` for each input sequence length.

    Bars are coloured by implementation in the order of ``hue_model_order``.
    Only the bars in the second bar container receive value labels.

    Args:
        df: Results table with 'input sequence length', 'implementation'
            and ``y_col`` columns.
        y_col: Name of the column plotted on the y axis.

    Returns:
        The matplotlib Figure containing the chart.
    """
    fig, ax = plt.subplots()
    sns.barplot(
        data=df,
        x='input sequence length',
        y=y_col,
        hue='implementation',
        hue_order=hue_model_order,
        width=0.5,
        ax=ax,
    )
    # Stretch the y axis by 20% so the bar labels have room above the bars.
    _, y_top = ax.get_ylim()
    ax.set_ylim(0, y_top * 1.2)
    # NOTE(review): with hue_order fixed, containers[1] is expected to hold
    # the 'flash-perceiver' bars (one container per hue level); the baseline
    # bars are left unlabelled.
    ax.bar_label(ax.containers[1])
    return fig
def main(args):
    """Build the benchmark comparison plots from a results CSV.

    Reads ``args.results_file`` and renames 'model' to 'implementation' and
    'input_size' to 'input sequence length'. For each input sequence length
    it adds 'speedup' (from 'time_per_it') and 'memory usage reduction'
    (from 'peak_memory') columns relative to the reference implementation.
    It then rounds float64 columns to two decimals and saves one bar chart
    per metric as ``benchmark_<metric>.png`` in ``args.output_dir``.

    Args:
        args: Namespace with ``results_file`` and ``output_dir`` attributes.

    Raises:
        ValueError: If the results file contains no rows.
    """
    df = pd.read_csv(args.results_file)
    if df.empty:
        raise ValueError(f'no benchmark results found in {args.results_file}')
    df = df.rename(columns={
        'model': 'implementation',
        'input_size': 'input sequence length',
    }).sort_values('input sequence length')

    for res_col, col in [
        ('speedup', 'time_per_it'),
        ('memory usage reduction', 'peak_memory'),
    ]:
        # Compute each group separately and assign by index label, not by
        # position. The output layout of groupby.apply depends on the pandas
        # version, and with a single group it collapses to a 1xN DataFrame,
        # so positional `.values` assignment was fragile.
        per_group = [
            calc_relative_improvement(group, col)
            for _, group in df.groupby('input sequence length')
        ]
        df[res_col] = pd.concat(per_group)

    for name in df.columns:
        if df[name].dtype == 'float64':
            df[name] = df[name].round(2)

    out_dir = Path(args.output_dir)
    # savefig does not create missing directories, and the default output
    # directory ('figures') may not exist yet.
    out_dir.mkdir(parents=True, exist_ok=True)
    for y_col in ['speedup', 'memory usage reduction']:
        savename = y_col.replace(' ', '_')
        fig = create_plot(df, y_col)
        fig.savefig(out_dir / f'benchmark_{savename}.png', bbox_inches='tight')
if __name__ == '__main__':
    # Command-line entry point. Both options have defaults, so the script can
    # be run without arguments.
    cli = argparse.ArgumentParser()
    for flag, default in (
        ('--results_file', 'benchmark_results.csv'),
        ('--output_dir', 'figures'),
    ):
        cli.add_argument(flag, type=str, default=default)
    main(cli.parse_args())