-
Notifications
You must be signed in to change notification settings - Fork 508
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Begin to add test for speaker recognition
- Loading branch information
1 parent
8d3d876
commit ce1c5e4
Showing
4 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#!/usr/bin/env bash
# Download speaker-recognition test waves and onnx models, then run the
# Python speaker-recognition tests.
#
# Layout produced:
#   /tmp/sr-models/sr-data     - test waves (git clone)
#   /tmp/sr-models/wespeaker   - wespeaker onnx models
#   /tmp/sr-models/3dspeaker   - 3d-speaker onnx models

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# download_models <dest-dir> <model-file>...
# Fetches each named model from the sherpa-onnx release into <dest-dir>.
download_models() {
  local dest=$1
  shift
  mkdir -p "$dest"
  pushd "$dest"
  local m
  for m in "$@"; do
    # NOTE: "speaker-recongition-models" is spelled exactly as the upstream
    # release tag; do not "fix" the typo or the URLs will 404.
    wget -q "https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m"
  done
  ls -lh
  popd
}

d=/tmp/sr-models
mkdir -p "$d"

pushd "$d"
log "Download test waves"
git clone https://github.com/csukuangfj/sr-data
popd

log "Download wespeaker models"
download_models "$d/wespeaker" \
  en_voxceleb_CAM++.onnx \
  en_voxceleb_CAM++_LM.onnx \
  en_voxceleb_resnet152_LM.onnx \
  en_voxceleb_resnet221_LM.onnx \
  en_voxceleb_resnet293_LM.onnx \
  en_voxceleb_resnet34.onnx \
  en_voxceleb_resnet34_LM.onnx \
  zh_cnceleb_resnet34.onnx \
  zh_cnceleb_resnet34_LM.onnx

log "Download 3d-speaker models"
download_models "$d/3dspeaker" \
  speech_campplus_sv_en_voxceleb_16k.onnx \
  speech_campplus_sv_zh-cn_16k-common.onnx \
  speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx \
  speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
  speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx \
  speech_eres2net_sv_en_voxceleb_16k.onnx \
  speech_eres2net_sv_zh-cn_16k-common.onnx

python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
# sherpa-onnx/python/tests/test_speaker_recognition.py | ||
# | ||
# Copyright (c) 2024 Xiaomi Corporation | ||
# | ||
# To run this single test, use | ||
# | ||
# ctest --verbose -R test_speaker_recognition_py | ||
|
||
import unittest | ||
import wave | ||
from collections import defaultdict | ||
from pathlib import Path | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
import sherpa_onnx | ||
|
||
# Root directory holding the downloaded models and test waves; populated by
# the companion download script before this test file is run.
d = "/tmp/sr-models"
|
||
|
||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a single-channel, 16-bit PCM wave file.

    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample
        should be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
         normalized to the range [-1, 1].
       - sample rate of the wave file
    """
    with wave.open(wave_filename) as f:
        assert f.getnchannels() == 1, f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # sample width in bytes
        raw = f.readframes(f.getnframes())
        sample_rate = f.getframerate()

    # int16 PCM -> float32 in [-1, 1)
    pcm = np.frombuffer(raw, dtype=np.int16)
    return pcm.astype(np.float32) / 32768, sample_rate
|
||
|
||
def load_speaker_embedding_model(model_filename):
    """Build a CPU speaker-embedding extractor for the given onnx model.

    Args:
      model_filename:
        Path to the onnx speaker-embedding model.
    Returns:
      A validated ``sherpa_onnx.SpeakerEmbeddingExtractor``.
    Raises:
      ValueError: if the constructed config fails validation.
    """
    cfg = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=model_filename,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    if not cfg.validate():
        raise ValueError(f"Invalid config. {cfg}")

    return sherpa_onnx.SpeakerEmbeddingExtractor(cfg)
|
||
|
||
def test_wespeaker_model(model_filename: str):
    """Enroll and verify speakers using a wespeaker onnx model.

    Averages several enrollment embeddings per speaker, registers them in a
    SpeakerEmbeddingManager, then checks that each test wave both verifies
    against and searches back to the correct speaker.

    Args:
      model_filename:
        Path to the wespeaker onnx model (str or Path).
    Raises:
      RuntimeError: if registration, verification, or search fails.
    """
    model_filename = str(model_filename)
    # The enrollment/test waves are Chinese speech, so English models are
    # skipped. NOTE(review): a substring match also skips any path merely
    # containing "en" -- confirm model names keep the "en_"/"zh_" prefix.
    if "en" in model_filename:
        print(f"skip {model_filename}")
        return
    extractor = load_speaker_embedding_model(model_filename)
    filenames = [
        "leijun-sr-1",
        "leijun-sr-2",
        "fangjun-sr-1",
        "fangjun-sr-2",
        "fangjun-sr-3",
    ]
    tmp = defaultdict(list)
    for filename in filenames:
        print(filename)
        # Speaker name is the part before the first "-", e.g. "leijun".
        name = filename.split("-", maxsplit=1)[0]
        # Fix: interpolate the wave filename into the path (the f-string
        # previously contained no placeholder, so every iteration read the
        # same nonexistent file).
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/enroll/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = np.array(extractor.compute(stream))
        tmp[name].append(embedding)

    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, embedding_list in tmp.items():
        print(name, len(embedding_list))
        # Register the mean embedding over all enrollment utterances.
        embedding = sum(embedding_list) / len(embedding_list)
        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "leijun-test-sr-1",
        "leijun-test-sr-2",
        "leijun-test-sr-3",
        "fangjun-test-sr-1",
        "fangjun-test-sr-2",
    ]
    for filename in filenames:
        name = filename.split("-", maxsplit=1)[0]
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/test/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = np.array(extractor.compute(stream))
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav")

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
|
||
|
||
def test_3dspeaker_model(model_filename: str):
    """Enroll and verify speakers using a 3d-speaker onnx model.

    Registers one "a" utterance per speaker/language, then checks that the
    corresponding "b" utterance verifies against and searches back to the
    registered entry.

    Args:
      model_filename:
        Path to the 3d-speaker onnx model (str or Path).
    Raises:
      RuntimeError: if registration or verification fails.
    """
    extractor = load_speaker_embedding_model(str(model_filename))
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)

    filenames = [
        "speaker1_a_cn_16k",
        "speaker2_a_cn_16k",
        "speaker1_a_en_16k",
        "speaker2_a_en_16k",
    ]
    for filename in filenames:
        # Drop the trailing "_16k" to form the registered speaker name,
        # e.g. "speaker1_a_cn".
        name = filename.rsplit("_", maxsplit=1)[0]
        # Fix: interpolate the wave filename into the path (the f-string
        # previously contained no placeholder).
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = np.array(extractor.compute(stream))

        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "speaker1_b_cn_16k",
        "speaker1_b_en_16k",
    ]
    for filename in filenames:
        print(filename)
        # The "b" utterance must match the registered "a" entry of the same
        # speaker/language, so map b_* back to a_* before verifying.
        name = filename.rsplit("_", maxsplit=1)[0]
        name = name.replace("b_cn", "a_cn")
        name = name.replace("b_en", "a_en")
        print(name)

        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = np.array(extractor.compute(stream))
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(
                f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}"
            )

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
|
||
|
||
class TestSpeakerRecognition(unittest.TestCase):
    """Run every downloaded speaker-embedding model against the test waves.

    Each test silently skips when the corresponding model directory under
    /tmp/sr-models does not exist (e.g. on machines that did not run the
    download script).
    """

    def test_wespeaker_models(self):
        """Test all wespeaker *.onnx models in /tmp/sr-models/wespeaker."""
        model_dir = Path(d) / "wespeaker"
        if not model_dir.is_dir():
            print(f"{model_dir} does not exist - skip it")
            return
        for filename in model_dir.glob("*.onnx"):
            print(filename)
            test_wespeaker_model(filename)

    # Renamed from the misspelled "test_3dpeaker_models"; unittest discovers
    # any "test_" method, so the rename is backward-compatible.
    def test_3dspeaker_models(self):
        """Test all 3d-speaker *.onnx models in /tmp/sr-models/3dspeaker."""
        model_dir = Path(d) / "3dspeaker"
        if not model_dir.is_dir():
            print(f"{model_dir} does not exist - skip it")
            return
        for filename in model_dir.glob("*.onnx"):
            print(filename)
            test_3dspeaker_model(filename)
|
||
|
||
# Allow running this test file directly (the download script invokes it this
# way) in addition to discovery via ctest/pytest.
if __name__ == "__main__":
    unittest.main()