Skip to content

Commit

Permalink
Re-implement LM rescore for online transducer (k2-fsa#1231)
Browse files Browse the repository at this point in the history
Co-authored-by: Martins Kronis <[email protected]>
  • Loading branch information
SilverSulfide and Martins Kronis authored Sep 6, 2024
1 parent 4323bd2 commit 0b706b4
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 31 deletions.
7 changes: 6 additions & 1 deletion sherpa-onnx/csrc/hypothesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,13 @@ struct Hypothesis {
// LM log prob if any.
double lm_log_prob = 0;

// the nn lm score for next token given the current ys
// the nn lm score for next token given the current ys,
// when using shallow fusion
CopyableOrtValue nn_lm_scores;

// cur scored tokens by RNN LM, when rescoring
int32_t cur_scored_pos = 0;

// the nn lm states
std::vector<CopyableOrtValue> nn_lm_states;

Expand Down
5 changes: 4 additions & 1 deletion sherpa-onnx/csrc/online-lm-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ void OnlineLMConfig::Register(ParseOptions *po) {
"Number of threads to run the neural network of LM model");
po->Register("lm-provider", &lm_provider,
"Specify a provider to LM model use: cpu, cuda, coreml");
po->Register("lm-shallow-fusion", &shallow_fusion,
"Boolean whether to use shallow fusion or rescore.");
}

bool OnlineLMConfig::Validate() const {
Expand All @@ -34,7 +36,8 @@ std::string OnlineLMConfig::ToString() const {

os << "OnlineLMConfig(";
os << "model=\"" << model << "\", ";
os << "scale=" << scale << ")";
os << "scale=" << scale << ", ";
os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")";

return os.str();
}
Expand Down
7 changes: 5 additions & 2 deletions sherpa-onnx/csrc/online-lm-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ struct OnlineLMConfig {
float scale = 0.5;
int32_t lm_num_threads = 1;
std::string lm_provider = "cpu";
// enable shallow fusion
bool shallow_fusion = true;

OnlineLMConfig() = default;

OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
const std::string &lm_provider)
const std::string &lm_provider, bool shallow_fusion)
: model(model),
scale(scale),
lm_num_threads(lm_num_threads),
lm_provider(lm_provider) {}
lm_provider(lm_provider),
shallow_fusion(shallow_fusion) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
24 changes: 19 additions & 5 deletions sherpa-onnx/csrc/online-lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,41 @@ class OnlineLM {

static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);

virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;
// init states for classic rescore
virtual std::vector<Ort::Value> GetInitStates() = 0;

/** ScoreToken a batch of sentences.
// init states for shallow fusion
virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() = 0;

/** ScoreToken a batch of sentences (shallow fusion).
*
* @param x A 2-D tensor of shape (N, 1) with data type int64.
* @param states It contains the states for the LM model
* @return Return a pair containingo
* @return Return a pair containing
* - log_prob of NN LM
* - updated states
*
*/
virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) = 0;

/** This function updates lm_lob_prob and nn_lm_scores of hyp
/** This function updates hyp.lm_log_prob of hyps (classic rescore).
*
* @param scale LM score
* @param context_size Context size of the transducer decoder model
* @param hyps It is changed in-place.
*
*/
virtual void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) = 0;

/** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion).
*
* @param scale LM score
* @param hyps It is changed in-place.
*
*/
virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0;
virtual void ComputeLMScoreSF(float scale, Hypothesis *hyp) = 0;
};

} // namespace sherpa_onnx
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {

decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty,
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
config_.blank_penalty,
config_.temperature_scale);

} else if (config.decoding_method == "greedy_search") {
Expand Down Expand Up @@ -156,7 +157,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {

decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty,
config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
config_.blank_penalty,
config_.temperature_scale);

} else if (config.decoding_method == "greedy_search") {
Expand Down
90 changes: 83 additions & 7 deletions sherpa-onnx/csrc/online-rnn-lm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <string>
#include <utility>
#include <vector>
#include <algorithm>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
Expand All @@ -27,9 +28,10 @@ class OnlineRnnLM::Impl {
Init(config);
}

void ComputeLMScore(float scale, Hypothesis *hyp) {
// shallow fusion scoring function
void ComputeLMScoreSF(float scale, Hypothesis *hyp) {
if (hyp->nn_lm_states.empty()) {
auto init_states = GetInitStates();
auto init_states = GetInitStatesSF();
hyp->nn_lm_scores.value = std::move(init_states.first);
hyp->nn_lm_states = Convert(std::move(init_states.second));
}
Expand All @@ -49,6 +51,52 @@ class OnlineRnnLM::Impl {
hyp->nn_lm_states = Convert(std::move(lm_out.second));
}

// classic rescore function
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
Ort::AllocatorWithDefaultOptions allocator;

for (auto &hyp : *hyps) {
for (auto &h_m : hyp) {
auto &h = h_m.second;
auto &ys = h.ys;
const int32_t token_num_in_chunk =
ys.size() - context_size - h.cur_scored_pos - 1;

if (token_num_in_chunk < 1) {
continue;
}

if (h.nn_lm_states.empty()) {
h.nn_lm_states = Convert(GetInitStates());
}

if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};

Ort::Value x = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size());
int64_t *p_x = x.GetTensorMutableData<int64_t>();
std::copy(ys.begin() + context_size + h.cur_scored_pos,
ys.end() - 1, p_x);

// streaming forward by NN LM
auto out = ScoreToken(std::move(x),
Convert(std::move(h.nn_lm_states)));

// update NN LM score in hyp
const float *p_nll = out.first.GetTensorData<float>();
h.lm_log_prob = -scale * (*p_nll);

// update NN LM states in hyp
h.nn_lm_states = Convert(std::move(out.second));

h.cur_scored_pos += token_num_in_chunk;
}
}
}
}

std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) {
std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]),
Expand All @@ -66,7 +114,8 @@ class OnlineRnnLM::Impl {
return {std::move(out[0]), std::move(next_states)};
}

std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() {
// get init states for shallow fusion
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() {
std::vector<Ort::Value> ans;
ans.reserve(init_states_.size());
for (auto &s : init_states_) {
Expand All @@ -75,6 +124,18 @@ class OnlineRnnLM::Impl {
return {View(&init_scores_.value), std::move(ans)};
}

// get init states for classic rescore
std::vector<Ort::Value> GetInitStates() const {
std::vector<Ort::Value> ans;
ans.reserve(init_states_.size());

for (const auto &s : init_states_) {
ans.emplace_back(Clone(allocator_, &s));
}

return ans;
}

private:
void Init(const OnlineLMConfig &config) {
auto buf = ReadFile(config_.model);
Expand Down Expand Up @@ -116,7 +177,8 @@ class OnlineRnnLM::Impl {
states.push_back(std::move(c));
auto pair = ScoreToken(std::move(x), std::move(states));

init_scores_.value = std::move(pair.first);
init_scores_.value = std::move(pair.first); // only used during
// shallow fusion
init_states_ = std::move(pair.second);
}

Expand Down Expand Up @@ -147,17 +209,31 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)

OnlineRnnLM::~OnlineRnnLM() = default;

std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() {
// classic rescore state init
std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
return impl_->GetInitStates();
}

// shallow fusion state init
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStatesSF() {
return impl_->GetInitStatesSF();
}

std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) {
return impl_->ScoreToken(std::move(x), std::move(states));
}

void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) {
return impl_->ComputeLMScore(scale, hyp);
// classic rescore scores
void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
return impl_->ComputeLMScore(scale, context_size, hyps);
}

// shallow fusion scores
void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) {
return impl_->ComputeLMScoreSF(scale, hyp);
}


} // namespace sherpa_onnx
24 changes: 19 additions & 5 deletions sherpa-onnx/csrc/online-rnn-lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,41 @@ class OnlineRnnLM : public OnlineLM {

explicit OnlineRnnLM(const OnlineLMConfig &config);

std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;
// init scores for classic rescore
std::vector<Ort::Value> GetInitStates() override;

/** ScoreToken a batch of sentences.
// init scores for shallow fusion
std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStatesSF() override;

/** ScoreToken a batch of sentences (shallow fusion).
*
* @param x A 2-D tensor of shape (N, L) with data type int64.
* @param states It contains the states for the LM model
* @return Return a pair containingo
* @return Return a pair containing
* - log_prob of NN LM
* - updated states
*
*/
std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
Ort::Value x, std::vector<Ort::Value> states) override;

/** This function updates lm_lob_prob and nn_lm_scores of hyp
/** This function updates hyp.lm_lob_prob of hyps (classic rescore).
*
* @param scale LM score
* @param context_size Context size of the transducer decoder model
* @param hyps It is changed in-place.
*
*/
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) override;

/** This function updates lm_lob_prob and nn_lm_scores of hyp (shallow fusion).
*
* @param scale LM score
* @param hyps It is changed in-place.
*
*/
void ComputeLMScore(float scale, Hypothesis *hyp) override;
void ComputeLMScoreSF(float scale, Hypothesis *hyp) override;

private:
class Impl;
Expand Down
28 changes: 23 additions & 5 deletions sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(

// add log_prob of each hypothesis to p_logprob before taking top_k
for (int32_t i = 0; i != num_hyps; ++i) {
float log_prob = prev[i].log_prob + prev[i].lm_log_prob;
float log_prob = prev[i].log_prob;
if (lm_ && shallow_fusion_) {
log_prob += prev[i].lm_log_prob;
}

for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
*p_logprob += log_prob;
}
Expand Down Expand Up @@ -192,22 +196,31 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
if (lm_) {
lm_->ComputeLMScore(lm_scale_, &new_hyp);
if (lm_ && shallow_fusion_) {
lm_->ComputeLMScoreSF(lm_scale_, &new_hyp);
}
} else {
++new_hyp.num_trailing_blanks;
}
new_hyp.log_prob = p_logprob[k] + context_score -
if (lm_ && shallow_fusion_) {
new_hyp.log_prob = p_logprob[k] + context_score -
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
} else {
new_hyp.log_prob = p_logprob[k] + context_score; // rescore or no LM
// previous token
// score is ignored
}

// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
float y_prob = logit_with_temperature[start * vocab_size + k];
new_hyp.ys_probs.push_back(y_prob);

if (lm_) { // export only when LM is used
if (lm_ && shallow_fusion_) { // export only if
// LM shallow fusion is used
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;

if (lm_scale_ != 0.0) {
lm_prob /= lm_scale_; // remove lm-scale
}
Expand All @@ -227,6 +240,11 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
} // for (int32_t b = 0; b != batch_size; ++b)
} // for (int32_t t = 0; t != num_frames; ++t)

// classic lm rescore
if (lm_ && !shallow_fusion_) {
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
}

for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);
Expand Down
Loading

0 comments on commit 0b706b4

Please sign in to comment.