Skip to content

Commit

Permalink
Kotlin API for streaming zipformer2 CTC
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Dec 21, 2023
1 parent 047a81f commit 198a65c
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@ data class OnlineParaformerModelConfig(
var decoder: String = "",
)

data class OnlineZipformer2CtcModelConfig(
var model: String = "",
)

data class OnlineModelConfig(
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(),
var tokens: String,
var numThreads: Int = 1,
var debug: Boolean = false,
Expand Down
51 changes: 35 additions & 16 deletions kotlin-api-examples/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ fun callback(samples: FloatArray): Unit {

fun main() {
testTts()
testAsr()
testAsr("transducer")
testAsr("zipformer2-ctc")
}

fun testTts() {
Expand All @@ -30,25 +31,43 @@ fun testTts() {
audio.save(filename="test-en.wav")
}

fun testAsr() {
fun testAsr(type: String) {
var featConfig = FeatureConfig(
sampleRate = 16000,
featureDim = 80,
)

// please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// to dowload pre-trained models
var modelConfig = OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
numThreads = 1,
debug = false,
)
var waveFilename: String
var modelConfig: OnlineModelConfig = when (type) {
"transducer" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav"
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// to dowload pre-trained models
OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
numThreads = 1,
debug = false,
)
}
"zipformer2-ctc" -> {
waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav"
OnlineModelConfig(
zipformer2Ctc = OnlineZipformer2CtcModelConfig(
model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx",
),
tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt",
numThreads = 1,
debug = false,
)
}
else -> throw IllegalArgumentException(type)
}

var endpointConfig = EndpointConfig()

Expand All @@ -69,7 +88,7 @@ fun testAsr() {
)

var objArray = WaveReader.readWaveFromFile(
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
filename = waveFilename,
)
var samples: FloatArray = objArray[0] as FloatArray
var sampleRate: Int = objArray[1] as Int
Expand Down
6 changes: 6 additions & 0 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
fi

if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi

if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
tar xf vits-piper-en_US-amy-low.tar.bz2
Expand Down
22 changes: 17 additions & 5 deletions sherpa-onnx/jni/jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,22 +262,34 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config);
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);

fid = env->GetFieldID(paraformer_config_config_cls, "encoder",
"Ljava/lang/String;");
fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.encoder = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(paraformer_config_config_cls, "decoder",
"Ljava/lang/String;");
fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.decoder = p;
env->ReleaseStringUTFChars(s, p);

// streaming zipformer2 CTC
fid =
env->GetFieldID(model_config_cls, "zipformer2Ctc",
"Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);
jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);

fid =
env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.zipformer2_ctc.model = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
Expand Down

0 comments on commit 198a65c

Please sign in to comment.