-
Notifications
You must be signed in to change notification settings - Fork 508
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add online punctuation and casing prediction model for English langua…
…ge (#1224)
- Loading branch information
1 parent
52830cc
commit 1414e4d
Showing
14 changed files
with
874 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h | ||
// | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ | ||
|
||
namespace sherpa_onnx { | ||
|
||
struct OnlineCNNBiLSTMModelMetaData { | ||
int32_t comma_id; | ||
int32_t period_id; | ||
int32_t quest_id; | ||
|
||
int32_t upper_id; | ||
int32_t cap_id; | ||
int32_t mix_case_id; | ||
|
||
int32_t num_cases; | ||
int32_t num_punctuations; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
// sherpa-onnx/csrc/online-cnn-bilstm-model.cc | ||
// | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "sherpa-onnx/csrc/onnx-utils.h" | ||
#include "sherpa-onnx/csrc/session.h" | ||
#include "sherpa-onnx/csrc/text-utils.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
class OnlineCNNBiLSTMModel::Impl { | ||
public: | ||
explicit Impl(const OnlinePunctuationModelConfig &config) | ||
: config_(config), | ||
env_(ORT_LOGGING_LEVEL_ERROR), | ||
sess_opts_(GetSessionOptions(config)), | ||
allocator_{} { | ||
auto buf = ReadFile(config_.cnn_bilstm); | ||
Init(buf.data(), buf.size()); | ||
} | ||
|
||
#if __ANDROID_API__ >= 9 | ||
Impl(AAssetManager *mgr, const OnlinePunctuationModelConfig &config) | ||
: config_(config), | ||
env_(ORT_LOGGING_LEVEL_ERROR), | ||
sess_opts_(GetSessionOptions(config)), | ||
allocator_{} { | ||
auto buf = ReadFile(mgr, config_.cnn_bilstm); | ||
Init(buf.data(), buf.size()); | ||
} | ||
#endif | ||
|
||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) { | ||
std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)}; | ||
|
||
auto ans = | ||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
output_names_ptr_.data(), output_names_ptr_.size()); | ||
return {std::move(ans[0]), std::move(ans[1])}; | ||
} | ||
|
||
OrtAllocator *Allocator() const { return allocator_; } | ||
|
||
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const { | ||
return meta_data_; | ||
} | ||
|
||
private: | ||
void Init(void *model_data, size_t model_data_length) { | ||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
sess_opts_); | ||
|
||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
|
||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
|
||
// get meta data | ||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
|
||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
|
||
SHERPA_ONNX_READ_META_DATA(meta_data_.comma_id, "COMMA"); | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.period_id, "PERIOD"); | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.quest_id, "QUESTION"); | ||
|
||
// assert here, because we will use the constant value | ||
assert(meta_data_.comma_id == 1); | ||
assert(meta_data_.period_id == 2); | ||
assert(meta_data_.quest_id == 3); | ||
|
||
SHERPA_ONNX_READ_META_DATA(meta_data_.upper_id, "UPPER"); | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.cap_id, "CAP"); | ||
SHERPA_ONNX_READ_META_DATA(meta_data_.mix_case_id, "MIX_CASE"); | ||
|
||
assert(meta_data_.upper_id == 1); | ||
assert(meta_data_.cap_id == 2); | ||
assert(meta_data_.mix_case_id == 3); | ||
|
||
// output shape is (T', num_cases) | ||
meta_data_.num_cases = | ||
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
meta_data_.num_punctuations = | ||
sess_->GetOutputTypeInfo(1).GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
} | ||
|
||
private: | ||
OnlinePunctuationModelConfig config_; | ||
Ort::Env env_; | ||
Ort::SessionOptions sess_opts_; | ||
Ort::AllocatorWithDefaultOptions allocator_; | ||
|
||
std::unique_ptr<Ort::Session> sess_; | ||
|
||
std::vector<std::string> input_names_; | ||
std::vector<const char *> input_names_ptr_; | ||
|
||
std::vector<std::string> output_names_; | ||
std::vector<const char *> output_names_ptr_; | ||
|
||
OnlineCNNBiLSTMModelMetaData meta_data_; | ||
}; | ||
|
||
OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( | ||
const OnlinePunctuationModelConfig &config) | ||
: impl_(std::make_unique<Impl>(config)) {} | ||
|
||
#if __ANDROID_API__ >= 9 | ||
OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( | ||
AAssetManager *mgr, const OnlinePunctuationModelConfig &config) | ||
: impl_(std::make_unique<Impl>(mgr, config)) {} | ||
#endif | ||
|
||
OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default; | ||
|
||
std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids, | ||
Ort::Value valid_ids, | ||
Ort::Value label_lens) const { | ||
return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens)); | ||
} | ||
|
||
OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const { | ||
return impl_->Allocator(); | ||
} | ||
|
||
const OnlineCNNBiLSTMModelMetaData & | ||
OnlineCNNBiLSTMModel::GetModelMetadata() const { | ||
return impl_->GetModelMetadata(); | ||
} | ||
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// sherpa-onnx/csrc/online-cnn-bilstm-model.h | ||
// | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ | ||
#include <memory> | ||
#include <utility> | ||
|
||
#if __ANDROID_API__ >= 9 | ||
#include "android/asset_manager.h" | ||
#include "android/asset_manager_jni.h" | ||
#endif | ||
|
||
#include "onnxruntime_cxx_api.h" // NOLINT | ||
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h" | ||
#include "sherpa-onnx/csrc/online-punctuation-model-config.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
/** This class implements | ||
* https://github.com/frankyoujian/Edge-Punct-Casing/blob/main/onnx_decode_sentence.py | ||
*/ | ||
class OnlineCNNBiLSTMModel { | ||
public: | ||
explicit OnlineCNNBiLSTMModel( | ||
const OnlinePunctuationModelConfig &config); | ||
|
||
#if __ANDROID_API__ >= 9 | ||
OnlineCNNBiLSTMModel(AAssetManager *mgr, | ||
const OnlinePunctuationModelConfig &config); | ||
#endif | ||
|
||
~OnlineCNNBiLSTMModel(); | ||
|
||
/** Run the forward method of the model. | ||
* | ||
* @param token_ids A tensor of shape (N, T) of dtype int32. | ||
* @param valid_ids A tensor of shape (N, T) of dtype int32. | ||
* @param label_lens A tensor of shape (N) of dtype int32. | ||
* | ||
* @return Return a pair of tensors | ||
* - case_logits: A 2-D tensor of shape (T', num_cases). | ||
* - punct_logits: A 2-D tensor of shape (T', num_puncts). | ||
*/ | ||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const; | ||
|
||
/** Return an allocator for allocating memory | ||
*/ | ||
OrtAllocator *Allocator() const; | ||
|
||
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const; | ||
|
||
private: | ||
class Impl; | ||
std::unique_ptr<Impl> impl_; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ |
Oops, something went wrong.