diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index f5fb5cc8e..672e3d17a 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -58,6 +58,7 @@ def get_binaries(): "sherpa-onnx-offline-tts", "sherpa-onnx-offline-tts-play", "sherpa-onnx-offline-websocket-server", + "sherpa-onnx-online-punctuation", "sherpa-onnx-online-websocket-client", "sherpa-onnx-online-websocket-server", "sherpa-onnx-vad-microphone", diff --git a/sherpa-onnx/csrc/online-cnn-bilstm-model.cc b/sherpa-onnx/csrc/online-cnn-bilstm-model.cc index 739cf83f7..ce8da377e 100644 --- a/sherpa-onnx/csrc/online-cnn-bilstm-model.cc +++ b/sherpa-onnx/csrc/online-cnn-bilstm-model.cc @@ -35,8 +35,11 @@ class OnlineCNNBiLSTMModel::Impl { } #endif - std::pair Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) { - std::array inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)}; + std::pair Forward(Ort::Value token_ids, + Ort::Value valid_ids, + Ort::Value label_lens) { + std::array inputs = { + std::move(token_ids), std::move(valid_ids), std::move(label_lens)}; auto ans = sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), @@ -117,18 +120,18 @@ OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default; -std::pair OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids, - Ort::Value valid_ids, - Ort::Value label_lens) const { - return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens)); +std::pair OnlineCNNBiLSTMModel::Forward( + Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const { + return impl_->Forward(std::move(token_ids), std::move(valid_ids), + std::move(label_lens)); } OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const { return impl_->Allocator(); } -const OnlineCNNBiLSTMModelMetaData & -OnlineCNNBiLSTMModel::GetModelMetadata() const { +const OnlineCNNBiLSTMModelMetaData &OnlineCNNBiLSTMModel::GetModelMetadata() + const { return impl_->GetModelMetadata(); } diff --git a/sherpa-onnx/csrc/online-cnn-bilstm-model.h b/sherpa-onnx/csrc/online-cnn-bilstm-model.h index aa0ca2d34..25886107a 100644 --- a/sherpa-onnx/csrc/online-cnn-bilstm-model.h +++ b/sherpa-onnx/csrc/online-cnn-bilstm-model.h @@ -23,12 +23,11 @@ namespace sherpa_onnx { */ class OnlineCNNBiLSTMModel { public: - explicit OnlineCNNBiLSTMModel( - const OnlinePunctuationModelConfig &config); + explicit OnlineCNNBiLSTMModel(const OnlinePunctuationModelConfig &config); #if __ANDROID_API__ >= 9 OnlineCNNBiLSTMModel(AAssetManager *mgr, - const OnlinePunctuationModelConfig &config); + const OnlinePunctuationModelConfig &config); #endif ~OnlineCNNBiLSTMModel(); @@ -43,7 +42,9 @@ class OnlineCNNBiLSTMModel { * - case_logits: A 2-D tensor of shape (T', num_cases). * - punct_logits: A 2-D tensor of shape (T', num_puncts). */ - std::pair Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const; + std::pair Forward(Ort::Value token_ids, + Ort::Value valid_ids, + Ort::Value label_lens) const; /** Return an allocator for allocating memory */ diff --git a/sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h b/sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h index aca25bb00..e586ccd07 100644 --- a/sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h +++ b/sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h @@ -7,27 +7,28 @@ #include +#include #include #include #include #include -#include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif +#include // NOLINT + #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/math.h" +#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h" #include "sherpa-onnx/csrc/online-cnn-bilstm-model.h" #include "sherpa-onnx/csrc/online-punctuation-impl.h" #include "sherpa-onnx/csrc/online-punctuation.h" -#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h" -#include "sherpa-onnx/csrc/text-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/text-utils.h" #include "ssentencepiece/csrc/ssentencepiece.h" -#include // NOLINT namespace sherpa_onnx { @@ -35,25 +36,24 @@ static const int32_t kMaxSeqLen = 200; class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { public: - explicit OnlinePunctuationCNNBiLSTMImpl( - const OnlinePunctuationConfig &config) + explicit OnlinePunctuationCNNBiLSTMImpl(const OnlinePunctuationConfig &config) : config_(config), model_(config.model) { - if (!config_.model.bpe_vocab.empty()) { - bpe_encoder_ = std::make_unique( - config_.model.bpe_vocab); - } - } + if (!config_.model.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model.bpe_vocab); + } + } #if __ANDROID_API__ >= 9 OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr, - const OnlinePunctuationConfig &config) + const OnlinePunctuationConfig &config) : config_(config), model_(mgr, config.model) { - if (!config_.model.bpe_vocab.empty()) { - auto buf = ReadFile(mgr, config_.model.bpe_vocab); - std::istringstream iss(std::string(buf.begin(), buf.end())); - bpe_encoder_ = std::make_unique(iss); - } - } + if (!config_.model.bpe_vocab.empty()) { + auto buf = ReadFile(mgr, config_.model.bpe_vocab); + std::istringstream iss(std::string(buf.begin(), buf.end())); + bpe_encoder_ = std::make_unique(iss); + } + } #endif std::string AddPunctuationWithCase(const std::string &text) const override { @@ -61,9 +61,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { return {}; } - std::vector tokens_list; // N * kMaxSeqLen - std::vector valids_list; // N * kMaxSeqLen - std::vector label_len_list; // N + std::vector tokens_list; // N * kMaxSeqLen + std::vector valids_list; // N * kMaxSeqLen + std::vector label_len_list; // N EncodeSentences(text, tokens_list, valids_list, label_len_list); @@ -75,34 +75,43 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { int32_t n = label_len_list.size(); std::array token_ids_shape = {n, kMaxSeqLen}; - Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(), - token_ids_shape.data(), token_ids_shape.size()); + Ort::Value token_ids = Ort::Value::CreateTensor( + memory_info, tokens_list.data(), tokens_list.size(), + token_ids_shape.data(), token_ids_shape.size()); std::array valid_ids_shape = {n, kMaxSeqLen}; - Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(), - valid_ids_shape.data(), valid_ids_shape.size()); + Ort::Value valid_ids = Ort::Value::CreateTensor( + memory_info, valids_list.data(), valids_list.size(), + valid_ids_shape.data(), valid_ids_shape.size()); std::array label_len_shape = {n}; - Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(), - label_len_shape.data(), label_len_shape.size()); + Ort::Value label_len = Ort::Value::CreateTensor( + memory_info, label_len_list.data(), label_len_list.size(), + label_len_shape.data(), label_len_shape.size()); - auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len)); + auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), + std::move(label_len)); std::vector case_pred; std::vector punct_pred; - const float* active_case_logits = pair.first.GetTensorData(); - const float* active_punct_logits = pair.second.GetTensorData(); - std::vector case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape(); + const float *active_case_logits = pair.first.GetTensorData(); + const float *active_punct_logits = pair.second.GetTensorData(); + std::vector case_logits_shape = + pair.first.GetTensorTypeAndShapeInfo().GetShape(); for (int32_t i = 0; i < case_logits_shape[0]; ++i) { - const float* p_cur_case = active_case_logits + i * meta_data.num_cases; + const float *p_cur_case = active_case_logits + i * meta_data.num_cases; auto index_case = static_cast(std::distance( - p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); + p_cur_case, + std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); case_pred.push_back(index_case); - const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations; + const float *p_cur_punct = + active_punct_logits + i * meta_data.num_punctuations; auto index_punct = static_cast(std::distance( - p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations))); + p_cur_punct, + std::max_element(p_cur_punct, + p_cur_punct + meta_data.num_punctuations))); punct_pred.push_back(index_punct); } @@ -112,60 +121,60 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { } private: - void EncodeSentences(const std::string& text, - std::vector& tokens_list, - std::vector& valids_list, - std::vector& label_len_list) const { + void EncodeSentences(const std::string &text, + std::vector &tokens_list, // NOLINT + std::vector &valids_list, // NOLINT + std::vector &label_len_list) const { // NOLINT std::vector tokens; std::vector valids; int32_t label_len = 0; - tokens.push_back(1); // hardcode 1 now, 1 - + tokens.push_back(1); // hardcode 1 now, 1 - valids.push_back(1); std::stringstream ss(text); std::string word; while (ss >> word) { - std::vector word_tokens; - bpe_encoder_->Encode(word, &word_tokens); + std::vector word_tokens; + bpe_encoder_->Encode(word, &word_tokens); - int32_t seq_len = tokens.size() + word_tokens.size(); - if (seq_len > kMaxSeqLen - 1) { - tokens.push_back(2); // hardcode 2 now, 2 - - valids.push_back(1); + int32_t seq_len = tokens.size() + word_tokens.size(); + if (seq_len > kMaxSeqLen - 1) { + tokens.push_back(2); // hardcode 2 now, 2 - + valids.push_back(1); - label_len = std::count(valids.begin(), valids.end(), 1); + label_len = std::count(valids.begin(), valids.end(), 1); - if (tokens.size() < kMaxSeqLen) { - tokens.resize(kMaxSeqLen, 0); - valids.resize(kMaxSeqLen, 0); - } + if (tokens.size() < kMaxSeqLen) { + tokens.resize(kMaxSeqLen, 0); + valids.resize(kMaxSeqLen, 0); + } - assert(tokens.size() == kMaxSeqLen); - assert(valids.size() == kMaxSeqLen); + assert(tokens.size() == kMaxSeqLen); + assert(valids.size() == kMaxSeqLen); - tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); - valids_list.insert(valids_list.end(), valids.begin(), valids.end()); - label_len_list.push_back(label_len); + tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); + valids_list.insert(valids_list.end(), valids.begin(), valids.end()); + label_len_list.push_back(label_len); - std::vector().swap(tokens); - std::vector().swap(valids); - label_len = 0; - tokens.push_back(1); // hardcode 1 now, 1 - - valids.push_back(1); - } + std::vector().swap(tokens); + std::vector().swap(valids); + label_len = 0; + tokens.push_back(1); // hardcode 1 now, 1 - + valids.push_back(1); + } - tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end()); - valids.push_back(1); // only the first sub word is valid - int32_t remaining_size = static_cast(word_tokens.size()) - 1; - if (remaining_size > 0) { - int32_t valids_cur_size = static_cast(valids.size()); - valids.resize(valids_cur_size + remaining_size, 0); - } + tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end()); + valids.push_back(1); // only the first sub word is valid + int32_t remaining_size = static_cast(word_tokens.size()) - 1; + if (remaining_size > 0) { + int32_t valids_cur_size = static_cast(valids.size()); + valids.resize(valids_cur_size + remaining_size, 0); + } } if (tokens.size() > 0) { - tokens.push_back(2); // hardcode 2 now, 2 - + tokens.push_back(2); // hardcode 2 now, 2 - valids.push_back(1); label_len = std::count(valids.begin(), valids.end(), 1); @@ -176,17 +185,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { } assert(tokens.size() == kMaxSeqLen); - assert(valids.size() == kMaxSeqLen); + assert(valids.size() == kMaxSeqLen); tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); valids_list.insert(valids_list.end(), valids.begin(), valids.end()); label_len_list.push_back(label_len); - } + } } - std::string DecodeSentences(const std::string& raw_text, - const std::vector& case_pred, - const std::vector& punct_pred) const { + std::string DecodeSentences(const std::string &raw_text, + const std::vector &case_pred, + const std::vector &punct_pred) const { std::string result_text; std::istringstream iss(raw_text); std::vector words; @@ -203,28 +212,29 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { std::string prefix = ((i != 0) ? " " : ""); result_text += prefix; switch (case_pred[i]) { - case 1: // upper + case 1: // upper { - std::transform(words[i].begin(), words[i].end(), words[i].begin(), [](auto c){ return std::toupper(c); }); + std::transform(words[i].begin(), words[i].end(), words[i].begin(), + [](auto c) { return std::toupper(c); }); result_text += words[i]; break; } - case 2: // cap + case 2: // cap { words[i][0] = std::toupper(words[i][0]); result_text += words[i]; break; } - case 3: // mix case + case 3: // mix case { - // TODO: - // Need to add a map containing supported mix case words so that we can fetch the predicted word from the map - // e.g. mcdonald's -> McDonald's + // TODO(frankyoujian): + // Need to add a map containing supported mix case words so that we + // can fetch the predicted word from the map e.g. mcdonald's -> + // McDonald's result_text += words[i]; break; } - default: - { + default: { result_text += words[i]; break; } @@ -232,17 +242,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { std::string suffix; switch (punct_pred[i]) { - case 1: // comma + case 1: // comma { suffix = ","; break; } - case 2: // period + case 2: // period { suffix = "."; break; } - case 3: // question + case 3: // question { suffix = "?"; break; @@ -252,9 +262,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { } result_text += suffix; - } + } - return result_text; + return result_text; } private: diff --git a/sherpa-onnx/csrc/online-punctuation-impl.cc b/sherpa-onnx/csrc/online-punctuation-impl.cc index 2ff0050b8..ebdbc8487 100644 --- a/sherpa-onnx/csrc/online-punctuation-impl.cc +++ b/sherpa-onnx/csrc/online-punctuation-impl.cc @@ -20,7 +20,9 @@ std::unique_ptr OnlinePunctuationImpl::Create( return std::make_unique(config); } - SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer"); + SHERPA_ONNX_LOGE( + "Please specify a punctuation model and bpe vocab! Return a null " + "pointer"); return nullptr; } @@ -31,7 +33,9 @@ std::unique_ptr OnlinePunctuationImpl::Create( return std::make_unique(mgr, config); } - SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer"); + SHERPA_ONNX_LOGE( + "Please specify a punctuation model and bpe vocab! Return a null " + "pointer"); return nullptr; } #endif diff --git a/sherpa-onnx/csrc/online-punctuation-model-config.cc b/sherpa-onnx/csrc/online-punctuation-model-config.cc index 8c8b2a309..5dab600fd 100644 --- a/sherpa-onnx/csrc/online-punctuation-model-config.cc +++ b/sherpa-onnx/csrc/online-punctuation-model-config.cc @@ -13,8 +13,7 @@ void OnlinePunctuationModelConfig::Register(ParseOptions *po) { po->Register("cnn-bilstm", &cnn_bilstm, "Path to the light-weight CNN-BiLSTM model"); - po->Register("bpe-vocab", &bpe_vocab, - "Path to the bpe vocab file"); + po->Register("bpe-vocab", &bpe_vocab, "Path to the bpe vocab file"); po->Register("num-threads", &num_threads, "Number of threads to run the neural network"); @@ -33,8 +32,7 @@ bool OnlinePunctuationModelConfig::Validate() const { } if (!FileExists(cnn_bilstm)) { - SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", - cnn_bilstm.c_str()); + SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", cnn_bilstm.c_str()); return false; } @@ -44,8 +42,7 @@ bool OnlinePunctuationModelConfig::Validate() const { } if (!FileExists(bpe_vocab)) { - SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", - bpe_vocab.c_str()); + SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", bpe_vocab.c_str()); return false; } diff --git a/sherpa-onnx/csrc/online-punctuation-model-config.h b/sherpa-onnx/csrc/online-punctuation-model-config.h index 2ee2c7c34..b2d26ce9d 100644 --- a/sherpa-onnx/csrc/online-punctuation-model-config.h +++ b/sherpa-onnx/csrc/online-punctuation-model-config.h @@ -22,9 +22,9 @@ struct OnlinePunctuationModelConfig { OnlinePunctuationModelConfig() = default; OnlinePunctuationModelConfig(const std::string &cnn_bilstm, - const std::string &bpe_vocab, - int32_t num_threads, bool debug, - const std::string &provider) + const std::string &bpe_vocab, + int32_t num_threads, bool debug, + const std::string &provider) : cnn_bilstm(cnn_bilstm), bpe_vocab(bpe_vocab), num_threads(num_threads), diff --git a/sherpa-onnx/csrc/online-punctuation.cc b/sherpa-onnx/csrc/online-punctuation.cc index 754870a3e..6435b1c4c 100644 --- a/sherpa-onnx/csrc/online-punctuation.cc +++ b/sherpa-onnx/csrc/online-punctuation.cc @@ -14,9 +14,7 @@ namespace sherpa_onnx { -void OnlinePunctuationConfig::Register(ParseOptions *po) { - model.Register(po); -} +void OnlinePunctuationConfig::Register(ParseOptions *po) { model.Register(po); } bool OnlinePunctuationConfig::Validate() const { if (!model.Validate()) { @@ -40,13 +38,14 @@ OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config) #if __ANDROID_API__ >= 9 OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr, - const OnlinePunctuationConfig &config) + const OnlinePunctuationConfig &config) : impl_(OnlinePunctuationImpl::Create(mgr, config)) {} #endif OnlinePunctuation::~OnlinePunctuation() = default; -std::string OnlinePunctuation::AddPunctuationWithCase(const std::string &text) const { +std::string OnlinePunctuation::AddPunctuationWithCase( + const std::string &text) const { return impl_->AddPunctuationWithCase(text); } diff --git a/sherpa-onnx/csrc/online-punctuation.h b/sherpa-onnx/csrc/online-punctuation.h index a70d336c6..f5d2a1440 100644 --- a/sherpa-onnx/csrc/online-punctuation.h +++ b/sherpa-onnx/csrc/online-punctuation.h @@ -40,8 +40,7 @@ class OnlinePunctuation { explicit OnlinePunctuation(const OnlinePunctuationConfig &config); #if __ANDROID_API__ >= 9 - OnlinePunctuation(AAssetManager *mgr, - const OnlinePunctuationConfig &config); + OnlinePunctuation(AAssetManager *mgr, const OnlinePunctuationConfig &config); #endif ~OnlinePunctuation(); diff --git a/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc b/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc index 11f21c362..aef469e5d 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc @@ -3,9 +3,9 @@ // Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) #include -#include #include // NOLINT +#include #include "sherpa-onnx/csrc/online-punctuation.h" #include "sherpa-onnx/csrc/parse-options.h" @@ -57,7 +57,7 @@ The output text should look like below: std::string text = po.GetArg(1); std::string text_with_punct_case = punct.AddPunctuationWithCase(text); - + const auto end = std::chrono::steady_clock::now(); fprintf(stderr, "Done\n");