Skip to content

Commit

Permalink
Begin to add test for speaker recognition
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jan 11, 2024
1 parent 8d3d876 commit ce1c5e4
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 0 deletions.
60 changes: 60 additions & 0 deletions .github/scripts/test-speaker-recognition-python.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env bash
# Download the speaker-recognition test waves and onnx models, then run
# the Python speaker-recognition tests.

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# NOTE: "recongition" is misspelled on purpose -- it matches the actual
# release tag on GitHub. Do not "fix" the spelling or downloads will 404.
release_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models

# Download each model file given as an argument into the current directory.
download_models() {
  local m
  for m in "$@"; do
    wget -q "$release_url/$m"
  done
}

d=/tmp/sr-models
mkdir -p "$d"

pushd "$d"
log "Download test waves"
git clone https://github.com/csukuangfj/sr-data
popd

log "Download wespeaker models"
model_dir=$d/wespeaker
mkdir -p "$model_dir"
pushd "$model_dir"
models=(
en_voxceleb_CAM++.onnx
en_voxceleb_CAM++_LM.onnx
en_voxceleb_resnet152_LM.onnx
en_voxceleb_resnet221_LM.onnx
en_voxceleb_resnet293_LM.onnx
en_voxceleb_resnet34.onnx
en_voxceleb_resnet34_LM.onnx
zh_cnceleb_resnet34.onnx
zh_cnceleb_resnet34_LM.onnx
)
download_models "${models[@]}"
ls -lh
popd

log "Download 3d-speaker models"
model_dir=$d/3dspeaker
mkdir -p "$model_dir"
pushd "$model_dir"
models=(
speech_campplus_sv_en_voxceleb_16k.onnx
speech_campplus_sv_zh-cn_16k-common.onnx
speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx
speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
speech_eres2net_sv_en_voxceleb_16k.onnx
speech_eres2net_sv_zh-cn_16k-common.onnx
)
download_models "${models[@]}"
ls -lh
popd


python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
1 change: 1 addition & 0 deletions .github/workflows/run-python-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ jobs:
- name: Test sherpa-onnx
shell: bash
run: |
.github/scripts/test-speaker-recognition-python.sh
.github/scripts/test-python.sh
- uses: actions/upload-artifact@v3
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ set(py_test_files
test_offline_recognizer.py
test_online_recognizer.py
test_online_transducer_model_config.py
test_speaker_recognition.py
test_text2token.py
)

Expand Down
194 changes: 194 additions & 0 deletions sherpa-onnx/python/tests/test_speaker_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# sherpa-onnx/python/tests/test_speaker_recognition.py
#
# Copyright (c) 2024 Xiaomi Corporation
#
# To run this single test, use
#
# ctest --verbose -R test_speaker_recognition_py

import unittest
import wave
from collections import defaultdict
from pathlib import Path
from typing import Tuple

import numpy as np
import sherpa_onnx

# Directory where the CI script downloads the test waves and onnx models.
d = "/tmp/sr-models"


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a mono, 16-bit PCM wave file.

    Args:
      wave_filename:
        Path to a wave file. It must be single channel with 16-bit
        samples; any sample rate is accepted.
    Returns:
      A tuple containing:
        - a 1-D np.float32 array of samples normalized to [-1, 1]
        - the sample rate of the wave file
    """
    with wave.open(wave_filename) as wf:
        assert wf.getnchannels() == 1, wf.getnchannels()
        # getsampwidth() is in bytes; 2 bytes == 16-bit samples
        assert wf.getsampwidth() == 2, wf.getsampwidth()
        raw = wf.readframes(wf.getnframes())
        rate = wf.getframerate()

    pcm = np.frombuffer(raw, dtype=np.int16)
    # Divide by 2**15 to map the int16 range onto [-1, 1].
    return pcm.astype(np.float32) / 32768, rate


def load_speaker_embedding_model(model_filename):
    """Create a SpeakerEmbeddingExtractor for the given onnx model file.

    Args:
      model_filename:
        Path to the speaker-embedding onnx model.
    Returns:
      A sherpa_onnx.SpeakerEmbeddingExtractor.
    Raises:
      ValueError: if the config fails validation.
    """
    config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=model_filename,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    if not config.validate():
        raise ValueError(f"Invalid config. {config}")

    return sherpa_onnx.SpeakerEmbeddingExtractor(config)


def test_wespeaker_model(model_filename: str):
    """Enroll and verify speakers using a wespeaker embedding model.

    Enrolls each speaker from the averaged embeddings of their enroll
    waves, then verifies and searches test waves against the manager.

    Args:
      model_filename:
        Path to the wespeaker onnx model (str or Path).
    Raises:
      RuntimeError: if enrollment or verification fails.
    """
    model_filename = str(model_filename)
    # English models are skipped -- presumably because the test waves
    # (leijun/fangjun) are Chinese speech; TODO confirm.
    if "en" in model_filename:
        print(f"skip {model_filename}")
        return
    extractor = load_speaker_embedding_model(model_filename)
    filenames = [
        "leijun-sr-1",
        "leijun-sr-2",
        "fangjun-sr-1",
        "fangjun-sr-2",
        "fangjun-sr-3",
    ]
    tmp = defaultdict(list)
    for filename in filenames:
        print(filename)
        # Speaker name is the part before the first "-", e.g. "leijun".
        name = filename.split("-", maxsplit=1)[0]
        # Fix: the format string previously had no placeholder, so every
        # iteration tried to read the same file instead of {filename}.wav.
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/enroll/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        tmp[name].append(embedding)

    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, embedding_list in tmp.items():
        print(name, len(embedding_list))
        # Enroll with the mean of all embeddings of this speaker.
        embedding = sum(embedding_list) / len(embedding_list)
        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "leijun-test-sr-1",
        "leijun-test-sr-2",
        "leijun-test-sr-3",
        "fangjun-test-sr-1",
        "fangjun-test-sr-2",
    ]
    for filename in filenames:
        name = filename.split("-", maxsplit=1)[0]
        # Same placeholder fix as above for the test waves.
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/test/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav")

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)


def test_3dspeaker_model(model_filename: str):
    """Enroll and verify speakers using a 3d-speaker embedding model.

    Enrolls one "a" wave per speaker/language, then verifies the matching
    "b" waves against the registered embeddings.

    Args:
      model_filename:
        Path to the 3d-speaker onnx model (str or Path).
    Raises:
      RuntimeError: if enrollment or verification fails.
    """
    extractor = load_speaker_embedding_model(str(model_filename))
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)

    filenames = [
        "speaker1_a_cn_16k",
        "speaker2_a_cn_16k",
        "speaker1_a_en_16k",
        "speaker2_a_en_16k",
    ]
    for filename in filenames:
        # Drop the trailing "_16k" to form the speaker key, e.g.
        # "speaker1_a_cn".
        name = filename.rsplit("_", maxsplit=1)[0]
        # Fix: the format string previously had no placeholder, so every
        # iteration tried to read the same file instead of {filename}.wav.
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)

        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "speaker1_b_cn_16k",
        "speaker1_b_en_16k",
    ]
    for filename in filenames:
        print(filename)
        name = filename.rsplit("_", maxsplit=1)[0]
        # Map the "b" test wave onto its enrolled "a" counterpart.
        name = name.replace("b_cn", "a_cn")
        name = name.replace("b_en", "a_en")
        print(name)

        # Same placeholder fix as above.
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(
                f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}"
            )

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)


class TestSpeakerRecognition(unittest.TestCase):
    """Run the speaker-recognition checks over every downloaded model."""

    def test_wespeaker_models(self):
        wespeaker_dir = Path(d) / "wespeaker"
        if not wespeaker_dir.is_dir():
            print(f"{wespeaker_dir} does not exist - skip it")
            return
        for model_path in wespeaker_dir.glob("*.onnx"):
            print(model_path)
            test_wespeaker_model(model_path)

    # NOTE(review): method name has a typo ("3dpeaker"); kept as-is so
    # that existing -k/-R filters keep matching.
    def test_3dpeaker_models(self):
        speaker3d_dir = Path(d) / "3dspeaker"
        if not speaker3d_dir.is_dir():
            print(f"{speaker3d_dir} does not exist - skip it")
            return
        for model_path in speaker3d_dir.glob("*.onnx"):
            print(model_path)
            test_3dspeaker_model(model_path)


if __name__ == "__main__":
    # Allow running this file directly:
    #   python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
    unittest.main()

0 comments on commit ce1c5e4

Please sign in to comment.