-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy path plotter.py
111 lines (90 loc) · 3.48 KB
/
plotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ plotter.py ]
# Synopsis [ code used to generate plots ]
# Author [ Ting-Wei Liu (Andi611) ]
# Copyright [ Copyleft(c), NTUEE, NTU, Taiwan ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import os
import csv
import argparse
import numpy as np
import matplotlib.pyplot as plt
# Set a 20pt default font size for every figure this script creates.
plt.rcParams['font.size'] = 20
###########
# LAMBDAS #
###########
def to_str(x):
    """Return a new list holding ``str(i)`` for every element ``i`` of *x*."""
    return [str(i) for i in x]


def norm(x):
    """Linearly rescale *x* so its minimum maps to 0 and its maximum to 1.

    Args:
        x: array-like of numbers.

    Returns:
        The result of ``np.interp`` (an ndarray for array-like input).
    """
    return np.interp(x, (np.amin(x), np.amax(x)), (0, +1))
##################
# CONFIGURATIONS #
##################
def get_config(argv=None):
    """Parse the plotter's command-line options.

    Args:
        argv: optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, matching the original behavior.

    Returns:
        argparse.Namespace with ``result_dir`` (str) and the boolean mode
        flags ``all``, ``tradeoff`` and ``encoding``.
    """
    parser = argparse.ArgumentParser(description='plotter arguments')
    parser.add_argument('--result_dir', type=str, default='./result/plots/', help='directory to save plots')
    mode_args = parser.add_argument_group('mode')
    mode_args.add_argument('--all', action='store_true', help='plot all curve')
    mode_args.add_argument('--tradeoff', action='store_true', help='plot trade-off curve')
    mode_args.add_argument('--encoding', action='store_true', help='plot encoding curve')
    return parser.parse_args(argv)
##################
# PLOT TRADE OFF #
##################
def plot_tradeoff(wer, br, dim, name):
    """Draw the error-rate vs. bit-rate trade-off curve and save it to *name*.

    Args:
        wer: error rates, one per point; plotted on the y-axis (labelled 'CER').
        br: bit rates matching *wer*; plotted on an inverted x-axis.
        dim: one label per point; each point is annotated with its value, and
            values greater than 16 use a different text offset.
        name: output path handed to ``plt.savefig``.
    """
    # Fixed reference points drawn next to our curve.
    base_x, base_y = 71.98, 1.000
    cont_x = 138.45

    def arrow_props():
        # Build a new dict per annotation rather than sharing one object.
        return dict(arrowstyle='->', connectionstyle="arc3,rad=0.5")

    fig = plt.figure(figsize=(12, 5))
    plt.xlabel('Bit Rate')
    plt.ylabel('CER')
    plt.gca().invert_xaxis()
    # Rotates/aligns the x tick labels; the helper is named for dates, but the
    # x-axis here holds bit rates.
    fig.autofmt_xdate()

    plt.plot(br, wer, linestyle=':', marker='o', color='m')  # Ours
    plt.plot([base_x], [base_y], linestyle=':', marker='o', color='r')  # Baseline
    plt.plot([cont_x, cont_x], [0.036, 0.040], linestyle=':', marker='o', color='b')  # Continues

    for x, y, d in zip(br, wer, dim):
        offset = (+25, -10) if d > 16 else (+13, -30)
        plt.annotate(d, xy=(x - 1, y), xycoords='data', xytext=offset,
                     textcoords='offset points', fontsize=20,
                     arrowprops=arrow_props())

    # Labels for the two reference series, anchored 1 unit left of the point.
    reference_labels = (
        ('baseline', (base_x - 1, base_y), (-15, -30)),
        ('continues', (cont_x - 1, 0.038), (+40, -5)),
    )
    for text, point, offset in reference_labels:
        plt.annotate(text, xy=point, xycoords='data', xytext=offset,
                     textcoords='offset points', fontsize=20,
                     arrowprops=arrow_props())

    plt.tight_layout()
    plt.savefig(name)
    plt.close()
##################
# PLOT ENCODING  #
##################
def plot_encoding(wer, br, dim, name):
    """Plot min-max normalised error rate and bit rate against embedding size.

    Both series are rescaled with ``norm`` so they share one y-axis. The sizes
    in *dim* are converted to strings with ``to_str``, so they appear as
    categorical ticks in the order given. The figure is saved to *name*.
    """
    sizes = to_str(dim)
    # (values, line colour, legend label) for each curve, in drawing order.
    series = ((wer, 'm', 'cer'), (br, 'c', 'br'))

    plt.figure(figsize=(9, 5))
    plt.ylabel('Linear Interpolated WER and Bit Rate')
    plt.xlabel('Embedding Size')
    for values, colour, tag in series:
        plt.plot(sizes, norm(values), linestyle=':', marker='o', color=colour, label=tag)
    plt.legend(loc='center right')
    plt.tight_layout()
    plt.savefig(name)
    plt.close()
########
# MAIN #
########
def main():
    """Parse options, then draw whichever plots were requested.

    The result lists below are hard-coded parallel lists: element i of
    ``wer``, ``br`` and ``dim`` belong together (they are zipped by the
    plotting functions). Plots are written into ``--result_dir``.
    """
    import sys  # local import: only needed for the usage warning below

    args = get_config()
    if not (args.all or args.tradeoff or args.encoding):
        # Without a mode flag nothing would be drawn; say so instead of
        # exiting silently (and don't create an empty output directory).
        print('Nothing to plot: pass --all, --tradeoff or --encoding.', file=sys.stderr)
        return
    os.makedirs(args.result_dir, exist_ok=True)
    wer = [0.196, 0.313, 0.430, 0.629, 0.717, 0.797, 0.887, 0.998, 0.998, 1.000, 1.000]
    br = [138.54, 138.45, 135.45, 138.45, 138.35, 134.80, 105.96, 61.79, 55.97, 48.78, 41.32]
    dim = [1024, 512, 256, 128, 64, 32, 16, 8, 7, 6, 5]
    if args.all or args.tradeoff:
        plot_tradeoff(wer, br, dim, os.path.join(args.result_dir, 'tradeoff.png'))
    if args.all or args.encoding:
        plot_encoding(wer, br, dim, os.path.join(args.result_dir, 'encoding.png'))


if __name__ == '__main__':
    main()