Skip to content

Commit

Permalink
use metric functions from espnet
Browse files Browse the repository at this point in the history
  • Loading branch information
Pingchuan committed Sep 9, 2022
1 parent 53e940e commit 14b3378
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def benchmark_inference(lipreader, data_dir, landmarks_dir, lines, dst_dir=""):
if groundtruth is not None:
print(f"ref: {groundtruth}")
wer.update( get_wer(output, groundtruth), len(groundtruth.split()))
cer.update( get_cer(output, groundtruth), len(groundtruth))
cer.update( get_cer(output, groundtruth), len(groundtruth.replace(" ", "")))
print(
f"progress: {idx+1}/{len(lines)}\tcur WER: {wer.val*100:.1f}\t"
f"cur CER: {cer.val*100:.1f}\t"
Expand Down
7 changes: 5 additions & 2 deletions metrics/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

# This code refers https://github.com/espnet/espnet/blob/24c3676a8d4c2e60d2726e9bcd9bdbed740610e0/espnet/nets/e2e_asr_common.py#L213-L249

import numpy as np

def get_wer(s, ref):
return get_er(s.split(" "), ref.split(" "))
return get_er(s.split(), ref.split())

def get_cer(s, ref):
return get_er(list(s), list(ref))
return get_er(s.replace(" ", ""), ref.replace(" ", ""))

def get_er(s, ref):
"""
Expand Down

0 comments on commit 14b3378

Please sign in to comment.