From bead1f71b68f2c2200bccfde96180eac6e04de8c Mon Sep 17 00:00:00 2001 From: TomasLiu Date: Fri, 22 Nov 2024 22:21:57 +0800 Subject: [PATCH] add style and template --- BUILD.gn | 2 ++ README.md | 4 +++ manifest.json | 19 +++++++++++++- src/main.cc | 57 +++++++++++++++++++++++++++++++++++++---- src/tmpl.h | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/tts.cc | 68 ++++++++++++++++++++++++++++++++----------------- src/tts.h | 8 +++--- 7 files changed, 195 insertions(+), 33 deletions(-) create mode 100644 src/tmpl.h diff --git a/BUILD.gn b/BUILD.gn index 6a99501..abbd934 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -26,11 +26,13 @@ ten_package("azure_tts") { "include", # The build flags for in-app building + "//ten_packages/system/nlohmann_json/include", "//ten_packages/system/ten_runtime/include", "//ten_packages/system/azure_speech_sdk/include/microsoft/c_api", "//ten_packages/system/azure_speech_sdk/include/microsoft/cxx_api", # The build flags for standalone building. + ".ten/app/ten_packages/system/nlohmann_json/include", ".ten/app/ten_packages/system/ten_runtime/include", ".ten/app/ten_packages/system/azure_speech_sdk/include/microsoft/c_api", ".ten/app/ten_packages/system/azure_speech_sdk/include/microsoft/cxx_api", diff --git a/README.md b/README.md index c04e524..9d7a350 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,10 @@ TEN extension of [azure Text to speech service](https://learn.microsoft.com/en-u | `azure_subscription_key` | `string` | `""` | Azure Speech service subscription key | | `azure_subscription_region` | `string` | `""` | Azure Speech service subscription region | | `azure_synthesis_voice_name` | `string` | `""` | e.g., `en-US-AdamMultilingualNeural`, check more available voices in [languages and voices support](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts) | +| `prosody` | `string` | `""` | Azure Speech prosody | +| `language` | `string` | `""` | Azure Speech language | +| `role` | `string` | `""` | Azure Speech role | +| `style` | `string` | `""` | Azure Speech style | ## Development diff --git a/manifest.json b/manifest.json index fef1988..a934f84 100644 --- a/manifest.json +++ b/manifest.json @@ -1,7 +1,7 @@ { "type": "extension", "name": "azure_tts", - "version": "0.6.0", + "version": "0.6.1", "dependencies": [ { "type": "system", @@ -12,6 +12,11 @@ "type": "system", "name": "azure_speech_sdk", "version": "1.38.0" + }, + { + "type": "system", + "name": "nlohmann_json", + "version": "=3.11.2" } ], "api": { @@ -24,6 +29,18 @@ }, "azure_synthesis_voice_name": { "type": "string" + }, + "style": { + "type": "string" + }, + "prosody": { + "type": "string" + }, + "language": { + "type": "string" + }, + "role": { + "type": "string" } }, "data_in": [ diff --git a/src/main.cc b/src/main.cc index aaec9fd..aff78d0 100644 --- a/src/main.cc +++ b/src/main.cc @@ -10,14 +10,37 @@ #include #include #include +#include #include "log.h" #include "ten_runtime/binding/cpp/ten.h" #include "ten_utils/macro/check.h" #include "tts.h" +#include "tmpl.h" namespace azure_tts_extension { +std::string trimString(const std::string& input) { + std::string result = input; + std::string::size_type pos; + + // Remove all occurrences of "\n" + while ((pos = result.find("\\n")) != std::string::npos) { + result.erase(pos, 2); + } + + // Remove all occurrences of "\r" + while ((pos = result.find("\\r")) != std::string::npos) { + result.erase(pos, 2); + } + + // Remove all occurrences of "\t" + while ((pos = result.find("\\t")) != std::string::npos) { + result.erase(pos, 2); + } + return result; +} + class azure_tts_extension_t : public ten::extension_t { public: explicit azure_tts_extension_t(const std::string &name) : extension_t(name) {} @@ -34,14 +57,19 @@ class azure_tts_extension_t : public ten::extension_t { // read properties auto key = ten.get_property_string("azure_subscription_key"); auto region = ten.get_property_string("azure_subscription_region"); - auto voice_name = ten.get_property_string("azure_synthesis_voice_name"); - if (key.empty() || region.empty() || voice_name.empty()) { + voice_ = ten.get_property_string("azure_synthesis_voice_name"); + if (key.empty() || region.empty() || voice_.empty()) { AZURE_TTS_LOGE( "azure_subscription_key, azure_subscription_region, azure_synthesis_voice_name should not be empty, start " "failed"); return; } + style_ = ten.get_property_string("style"); + prosody_ = ten.get_property_string("prosody"); + language_ = ten.get_property_string("language"); + role_ = ten.get_property_string("role"); + ten_proxy_ = std::unique_ptr(ten::ten_env_proxy_t::create(ten)); TEN_ASSERT(ten_proxy_ != nullptr, "ten_proxy should not be nullptr"); @@ -83,7 +111,7 @@ class azure_tts_extension_t : public ten::extension_t { azure_tts_ = std::make_unique( key, region, - voice_name, + voice_, Microsoft::CognitiveServices::Speech::SpeechSynthesisOutputFormat::Raw16Khz16BitMonoPcm, pcm_frame_size, std::move(pcm_callback)); @@ -132,10 +160,22 @@ class azure_tts_extension_t : public ten::extension_t { AZURE_TTS_LOGD("input text is empty, ignored"); return; } - AZURE_TTS_LOGI("input text: [%s]", text.c_str()); + text = trimString(text); // push received text to tts queue for synthesis - azure_tts_->Push(text); + if (!prosody_.empty() || !language_.empty() || !role_.empty() || !style_.empty()) { + MsttsTemplate tmpl; + auto ssml_text = tmpl.replace(json{{"role", role_}, + {"voice", voice_}, + {"language", language_}, + {"style", style_}, + {"prosody", prosody_}, + {"text", text}}); + AZURE_TTS_LOGI("input ssml text: [%s]", ssml_text.c_str()); + azure_tts_->Push(ssml_text, true); + } else { + azure_tts_->Push(text, false); + } } // on_stop will be called when the extension is stopping. @@ -157,8 +197,15 @@ class azure_tts_extension_t : public ten::extension_t { std::unique_ptr azure_tts_; + std::string voice_; + std::string prosody_; + std::string language_; + std::string role_; + std::string style_; + const std::string kCmdNameFlush{"flush"}; const std::string kDataFieldText{"text"}; + const std::string kDataFieldSSML{"ssml"}; }; TEN_CPP_REGISTER_ADDON_AS_EXTENSION(azure_tts, azure_tts_extension_t); diff --git a/src/tmpl.h b/src/tmpl.h new file mode 100644 index 0000000..7994b64 --- /dev/null +++ b/src/tmpl.h @@ -0,0 +1,70 @@ +#include +#include +#include +#include + +namespace azure_tts_extension { + +using json = nlohmann::json; + +class MsttsTemplate { + public: + MsttsTemplate() = default; + ~MsttsTemplate() = default; + + std::string replace(const json& params) const { + std::string result = templateStr_; + if (params.contains("prosody")) { + auto value = params["prosody"].get(); + if (!value.empty()) { + result = templateProsodyStr_; + } + } + replacePlaceholder(result, "lang", params, "xml:lang=\"{lang}\""); + replacePlaceholder(result, "voice", params, "name=\"{voice}\""); + replacePlaceholder(result, "style", params, "style=\"{style}\""); + replacePlaceholder(result, "role", params, "role=\"{role}\""); + replacePlaceholder(result, "prosody", params, "{prosody}"); + replacePlaceholder(result, "text", params, "{text}"); + return result; + } + + private: + std::string templateProsodyStr_ = R"( + + + + + {text} + + + + + )"; + + std::string templateStr_ = R"( + + + + {text} + + + + )"; + + void replacePlaceholder(std::string& result, const std::string& placeholder, const json& params, const std::string& templateStr) const { + std::string value = ""; + if (params.contains(placeholder)) { + value = params[placeholder].get(); + } + std::string tempStr = templateStr; + if (value.empty()) { + result = std::regex_replace(result, std::regex("\\{" + placeholder + "\\}"), ""); + } else { + tempStr = std::regex_replace(tempStr, std::regex("\\{" + placeholder + "\\}"), value); + result = std::regex_replace(result, std::regex("\\{" + placeholder + "\\}"), tempStr); + } + } +}; + +} // namespace azure_tts_extension \ No newline at end of file diff --git a/src/tts.cc b/src/tts.cc index 6da7554..b363721 100644 --- a/src/tts.cc +++ b/src/tts.cc @@ -1,4 +1,3 @@ - #include "tts.h" #include @@ -44,7 +43,7 @@ bool AzureTTS::Start() { tasks_.pop(); } - SpeechText(task->text, task->ts); + SpeechText(task->text, task->ts, task->ssml); } AZURE_TTS_LOGI("tts_thread stopped"); @@ -67,16 +66,17 @@ bool AzureTTS::Stop() { return true; } -void AzureTTS::Push(const std::string& text) noexcept { +void AzureTTS::Push(const std::string& text, bool ssml) noexcept { auto ts = time_since_epoch_in_us(); { std::unique_lock lock(tasks_mutex_); - tasks_.emplace(std::make_unique(text, ts)); + tasks_.emplace(std::make_unique(text, ts, ssml)); tasks_cv_.notify_one(); - AZURE_TTS_LOGD("task pushed for text: [%s], text_recv_ts: %" PRId64 ", queue size %d", + AZURE_TTS_LOGD("task pushed for text: [%s], ssml: %d, text_recv_ts: %" PRId64 ", queue size %d", text.c_str(), + ssml, ts, int(tasks_.size())); } @@ -94,9 +94,9 @@ void AzureTTS::Flush() noexcept { } } -void AzureTTS::SpeechText(const std::string& text, int64_t text_recv_ts) { +void AzureTTS::SpeechText(const std::string& text, int64_t text_recv_ts, bool ssml) { auto start_time = time_since_epoch_in_us(); - AZURE_TTS_LOGD("task starting for text: [%s], text_recv_ts: %" PRId64, text.c_str(), text_recv_ts); + AZURE_TTS_LOGD("task starting for text: [%s], ssml: %d text_recv_ts: %" PRId64, text.c_str(), ssml, text_recv_ts); if (text_recv_ts < outdate_ts_.load()) { AZURE_TTS_LOGI("task discard for text: [%s], text_recv_ts: %" PRId64 ", outdate_ts: %" PRId64, @@ -108,24 +108,46 @@ void AzureTTS::SpeechText(const std::string& text, int64_t text_recv_ts) { using namespace Microsoft::CognitiveServices; + std::shared_ptr result; // async mode - auto result = speech_synthesizer_->StartSpeakingTextAsync(text).get(); - if (result->Reason == Speech::ResultReason::Canceled) { - auto cancellation = Speech::SpeechSynthesisCancellationDetails::FromResult(result); - AZURE_TTS_LOGW("task canceled for text: [%s], text_recv_ts: %" PRId64 ", reason: %d", - text.c_str(), - text_recv_ts, - (int)cancellation->Reason); - - if (cancellation->Reason == Speech::CancellationReason::Error) { - AZURE_TTS_LOGW("task canceled on error for text: [%s], text_recv_ts: %" PRId64 - ", errorcode: %d, details: %s, did you update the subscription info?", - text.c_str(), - text_recv_ts, - (int)cancellation->ErrorCode, - cancellation->ErrorDetails.c_str()); + if (ssml) { + result = speech_synthesizer_->StartSpeakingSsmlAsync(text).get(); + if (result->Reason == Speech::ResultReason::Canceled) { + auto cancellation = Speech::SpeechSynthesisCancellationDetails::FromResult(result); + AZURE_TTS_LOGW("task canceled for ssml: [%s], text_recv_ts: %" PRId64 ", reason: %d", + text.c_str(), + text_recv_ts, + (int)cancellation->Reason); + + if (cancellation->Reason == Speech::CancellationReason::Error) { + AZURE_TTS_LOGW("task canceled on error for ssml: [%s], text_recv_ts: %" PRId64 + ", errorcode: %d, details: %s, did you update the subscription info?", + text.c_str(), + text_recv_ts, + (int)cancellation->ErrorCode, + cancellation->ErrorDetails.c_str()); + } + return; + } + } else { + result = speech_synthesizer_->StartSpeakingTextAsync(text).get(); + if (result->Reason == Speech::ResultReason::Canceled) { + auto cancellation = Speech::SpeechSynthesisCancellationDetails::FromResult(result); + AZURE_TTS_LOGW("task canceled for text: [%s], text_recv_ts: %" PRId64 ", reason: %d", + text.c_str(), + text_recv_ts, + (int)cancellation->Reason); + + if (cancellation->Reason == Speech::CancellationReason::Error) { + AZURE_TTS_LOGW("task canceled on error for text: [%s], text_recv_ts: %" PRId64 + ", errorcode: %d, details: %s, did you update the subscription info?", + text.c_str(), + text_recv_ts, + (int)cancellation->ErrorCode, + cancellation->ErrorDetails.c_str()); + } + return; } - return; } auto audioDataStream = Speech::AudioDataStream::FromResult(result); diff --git a/src/tts.h b/src/tts.h index 7ab191a..1b1106a 100644 --- a/src/tts.h +++ b/src/tts.h @@ -1,4 +1,3 @@ - #include #include @@ -40,10 +39,10 @@ class AzureTTS { void Flush() noexcept; - void Push(const std::string &text) noexcept; + void Push(const std::string &text, bool ssml) noexcept; private: - void SpeechText(const std::string &text, int64_t text_recv_ts); + void SpeechText(const std::string &text, int64_t text_recv_ts, bool ssml); int64_t time_since_epoch_in_us() const; @@ -59,10 +58,11 @@ class AzureTTS { std::atomic_int64_t outdate_ts_{0}; // for flushing struct Task { - Task(const std::string &t, int64_t ts) : ts(ts), text(t) {} + Task(const std::string &t, int64_t ts, bool ssml) : ts(ts), text(t), ssml(ssml) {} int64_t ts{0}; std::string text; + bool ssml; }; std::queue> tasks_;