From 14b33789c40f931860994eafdef18d05136ef8ef Mon Sep 17 00:00:00 2001 From: Pingchuan Date: Fri, 9 Sep 2022 15:45:16 +0100 Subject: [PATCH] use metric functions from espnet --- main.py | 2 +- metrics/measures.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 94a27e0..8227cfb 100644 --- a/main.py +++ b/main.py @@ -119,7 +119,7 @@ def benchmark_inference(lipreader, data_dir, landmarks_dir, lines, dst_dir=""): if groundtruth is not None: print(f"ref: {groundtruth}") wer.update( get_wer(output, groundtruth), len(groundtruth.split())) - cer.update( get_cer(output, groundtruth), len(groundtruth)) + cer.update( get_cer(output, groundtruth), len(groundtruth.replace(" ", ""))) print( f"progress: {idx+1}/{len(lines)}\tcur WER: {wer.val*100:.1f}\t" f"cur CER: {cer.val*100:.1f}\t" diff --git a/metrics/measures.py b/metrics/measures.py index dda4beb..a28e362 100644 --- a/metrics/measures.py +++ b/metrics/measures.py @@ -3,13 +3,16 @@ # Copyright 2021 Imperial College London (Pingchuan Ma) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +# This code refers https://github.com/espnet/espnet/blob/24c3676a8d4c2e60d2726e9bcd9bdbed740610e0/espnet/nets/e2e_asr_common.py#L213-L249 + import numpy as np def get_wer(s, ref): - return get_er(s.split(" "), ref.split(" ")) + return get_er(s.split(), ref.split()) def get_cer(s, ref): - return get_er(list(s), list(ref)) + return get_er(s.replace(" ", ""), ref.replace(" ", "")) def get_er(s, ref): """