Skip to content

Commit

Permalink
Add Kotlin API for Matcha-TTS models. (#1668)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Dec 31, 2024
1 parent 0a43e9c commit 3422b93
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 9 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/jni.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ jobs:
cd ./kotlin-api-examples
./run.sh
- uses: actions/upload-artifact@v4
with:
name: tts-files-${{ matrix.os }}
path: kotlin-api-examples/test-*.wav
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,4 @@ sherpa-onnx-moonshine-tiny-en-int8
sherpa-onnx-moonshine-base-en-int8
harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE
harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md
matcha-icefall-zh-baker
10 changes: 10 additions & 0 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ function testTts() {
rm vits-piper-en_US-amy-low.tar.bz2
fi

if [ ! -f ./matcha-icefall-zh-baker/model-steps-3.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
tar xvf matcha-icefall-zh-baker.tar.bz2
rm matcha-icefall-zh-baker.tar.bz2
fi

if [ ! -f ./hifigan_v2.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx
fi

out_filename=test_tts.jar
kotlinc-jvm -include-runtime -d $out_filename \
test_tts.kt \
Expand Down
29 changes: 27 additions & 2 deletions kotlin-api-examples/test_tts.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
package com.k2fsa.sherpa.onnx

// Entry point of the TTS example: invokes the test functions below in order.
// NOTE(review): this listing is a rendered diff without +/- markers. The
// `testTts()` call looks like the pre-change line that `testVits()` and
// `testMatcha()` replace -- confirm against the real file before relying on it.
fun main() {
testTts()
testVits()
testMatcha()
}

fun testTts() {
/**
 * Runs the Matcha-TTS example: builds an [OfflineTtsConfig] that points at the
 * `matcha-icefall-zh-baker` model files and the `hifigan_v2.onnx` vocoder,
 * synthesizes a fixed Chinese sentence through [OfflineTts.generateWithCallback]
 * and writes the result to `test-zh.wav`.
 *
 * All paths are relative to the current working directory, so the model
 * archive and the vocoder must already have been downloaded there.
 */
fun testMatcha() {
    // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
    // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
    val config = OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            matcha = OfflineTtsMatchaModelConfig(
                acousticModel = "./matcha-icefall-zh-baker/model-steps-3.onnx",
                vocoder = "./hifigan_v2.onnx",
                tokens = "./matcha-icefall-zh-baker/tokens.txt",
                lexicon = "./matcha-icefall-zh-baker/lexicon.txt",
                dictDir = "./matcha-icefall-zh-baker/dict",
            ),
            numThreads = 1,
            debug = true,
        ),
        // Comma-separated list of rule FST files (phone, date, number).
        ruleFsts = "./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst",
    )

    val tts = OfflineTts(config = config)
    try {
        val audio = tts.generateWithCallback(
            text = "某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。",
            callback = ::callback,
        )
        audio.save(filename = "test-zh.wav")
        println("Saved to test-zh.wav")
    } finally {
        // Release in `finally` so the TTS object is not leaked when
        // generation or saving throws.
        tts.release()
    }
}

fun testVits() {
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
var config = OfflineTtsConfig(
Expand Down
12 changes: 8 additions & 4 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
auto p = new SherpaOnnxOnlinePunctuation;
try {
sherpa_onnx::OnlinePunctuationConfig punctuation_config;
punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, "");
punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
punctuation_config.model.cnn_bilstm =
SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
punctuation_config.model.bpe_vocab =
SHERPA_ONNX_OR(config->model.bpe_vocab, "");
punctuation_config.model.num_threads =
SHERPA_ONNX_OR(config->model.num_threads, 1);
punctuation_config.model.debug = config->model.debug;
punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
punctuation_config.model.provider =
SHERPA_ONNX_OR(config->model.provider, "cpu");

p->impl =
std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
SherpaOnnxOnlinePunctuationModelConfig model;
} SherpaOnnxOnlinePunctuationConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation;
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation
SherpaOnnxOnlinePunctuation;

// Create an online punctuation processor. The user has to invoke
// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
// to avoid memory leak
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *
SherpaOnnxCreateOnlinePunctuation(
const SherpaOnnxOnlinePunctuationConfig *config);

// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/jieba-lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class JiebaLexicon::Impl {

this_sentence.insert(this_sentence.end(), ids.begin(), ids.end());

if (w == "" || w == "" || w == "" || w == "") {
if (IsPunct(w)) {
ans.emplace_back(std::move(this_sentence));
this_sentence = {};
}
Expand Down
49 changes: 49 additions & 0 deletions sherpa-onnx/jni/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
jobject model = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model);

// vits
fid = env->GetFieldID(model_config_cls, "vits",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
jobject vits = env->GetObjectField(model, fid);
Expand Down Expand Up @@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(vits_cls, "lengthScale", "F");
ans.model.vits.length_scale = env->GetFloatField(vits, fid);

// matcha
fid = env->GetFieldID(model_config_cls, "matcha",
"Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;");
jobject matcha = env->GetObjectField(model, fid);
jclass matcha_cls = env->GetObjectClass(matcha);

fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.acoustic_model = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.vocoder = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.lexicon = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.tokens = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.data_dir = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(matcha, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.matcha.dict_dir = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(matcha_cls, "noiseScale", "F");
ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid);

fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);

fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid);

Expand Down
12 changes: 12 additions & 0 deletions sherpa-onnx/kotlin-api/Tts.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig(
var lengthScale: Float = 1.0f,
)

/**
 * Configuration for a Matcha-TTS model.
 *
 * The JNI layer (`sherpa-onnx/jni/offline-tts.cc`) looks these properties up
 * by name and JVM type (`Ljava/lang/String;` / `F`), so renaming a property or
 * changing its type requires a matching change on the native side.
 *
 * @property acousticModel Path to the acoustic model file
 *   (e.g. `./matcha-icefall-zh-baker/model-steps-3.onnx`).
 * @property vocoder Path to the vocoder model file (e.g. `./hifigan_v2.onnx`).
 * @property lexicon Path to the lexicon file (e.g. `lexicon.txt` from the model archive).
 * @property tokens Path to the tokens file (e.g. `tokens.txt` from the model archive).
 * @property dataDir Forwarded to the native `data_dir` field.
 *   NOTE(review): presumably a data directory used by some models -- confirm.
 * @property dictDir Path to the dictionary directory (e.g. `./matcha-icefall-zh-baker/dict`).
 * @property noiseScale Forwarded to the native `noise_scale` field.
 * @property lengthScale Forwarded to the native `length_scale` field.
 *   NOTE(review): presumably controls speech speed -- confirm.
 */
data class OfflineTtsMatchaModelConfig(
var acousticModel: String = "",
var vocoder: String = "",
var lexicon: String = "",
var tokens: String = "",
var dataDir: String = "",
var dictDir: String = "",
var noiseScale: Float = 1.0f,
var lengthScale: Float = 1.0f,
)

data class OfflineTtsModelConfig(
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
Expand Down

0 comments on commit 3422b93

Please sign in to comment.