From 6422966a7fdd8a35c38add141d9c382dce9bc318 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 5 Aug 2024 14:06:21 +0800 Subject: [PATCH] Support passing TTS callback in Swift API (#1218) --- swift-api-examples/SherpaOnnx.swift | 19 ++++++++++++++ swift-api-examples/tts.swift | 39 ++++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index 6e1900eca..dc80a9650 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -757,6 +757,14 @@ class SherpaOnnxGeneratedAudioWrapper { } } +typealias TtsCallbackWithArg = ( + @convention(c) ( + UnsafePointer?, // const float* samples + Int32, // int32_t n + UnsafeMutableRawPointer? // void *arg + ) -> Int32 +)? + class SherpaOnnxOfflineTtsWrapper { /// A pointer to the underlying counterpart in C let tts: OpaquePointer! @@ -780,6 +788,17 @@ class SherpaOnnxOfflineTtsWrapper { return SherpaOnnxGeneratedAudioWrapper(audio: audio) } + + func generateWithCallbackWithArg( + text: String, callback: TtsCallbackWithArg, arg: UnsafeMutableRawPointer, sid: Int = 0, + speed: Float = 1.0 + ) -> SherpaOnnxGeneratedAudioWrapper { + let audio: UnsafePointer? = + SherpaOnnxOfflineTtsGenerateWithCallbackWithArg( + tts, toCPointer(text), Int32(sid), speed, callback, arg) + + return SherpaOnnxGeneratedAudioWrapper(audio: audio) + } } // spoken language identification diff --git a/swift-api-examples/tts.swift b/swift-api-examples/tts.swift index 61c68bce3..fc8cc787c 100644 --- a/swift-api-examples/tts.swift +++ b/swift-api-examples/tts.swift @@ -1,3 +1,9 @@ +class MyClass { + func playSamples(samples: [Float]) { + print("Play \(samples.count) samples") + } +} + func run() { let model = "./vits-piper-en_US-amy-low/en_US-amy-low.onnx" let tokens = "./vits-piper-en_US-amy-low/tokens.txt" @@ -11,6 +17,27 @@ func run() { let modelConfig = sherpaOnnxOfflineTtsModelConfig(vits: vits) var ttsConfig = sherpaOnnxOfflineTtsConfig(model: modelConfig) + let myClass = MyClass() + + // We use Unretained here so myClass must be kept alive as the callback is invoked + // + // See also + // https://medium.com/codex/swift-c-callback-interoperability-6d57da6c8ee6 + let arg = Unmanaged.passUnretained(myClass).toOpaque() + + let callback: TtsCallbackWithArg = { samples, n, arg in + let o = Unmanaged.fromOpaque(arg!).takeUnretainedValue() + var savedSamples: [Float] = [] + for index in 0..