Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sangeet2020 committed May 23, 2024
1 parent 7800cc0 commit d47bf6f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 21 deletions.
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
OnlineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OnlineTransducerNeMoModel> model_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;

int32_t batch_size_ = 1;
};
Expand Down
29 changes: 17 additions & 12 deletions sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
return {std::move(decoder_input), std::move(decoder_input_length)};
}

// Constructs a greedy-search decoder bound to `model` (borrowed, not owned)
// and records the blank penalty. It then asks the model for its initial
// decoder states and keeps them so they can be carried across chunks.
OnlineTransducerGreedySearchNeMoDecoder::
    OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model,
                                            float blank_penalty)
    : model_(model), blank_penalty_(blank_penalty) {
  // The argument 1 is presumably the batch size -- TODO confirm against
  // OnlineTransducerNeMoModel::GetDecoderInitStates.
  decoder_states_ = model_->GetDecoderInitStates(1);
}

static OnlineTransducerDecoderResult DecodeOne(
// OnlineTransducerGreedySearchNeMoDecoder::OnlineTransducerGreedySearchNeMoDecoder(
// OnlineTransducerNeMoModel *model, float blank_penalty)
// : model_(model), blank_penalty_(blank_penalty) {
// // Initialize decoder state
// auto init_states = model_->GetDecoderInitStates(1);
// decoder_states_ = std::move(init_states);
// }

std::pair<OnlineTransducerDecoderResult, std::vector<Ort::Value>> DecodeOne(
const float *encoder_out, int32_t num_rows, int32_t num_cols,
OnlineTransducerNeMoModel *model, float blank_penalty,
std::vector<Ort::Value>& decoder_states) {
Expand Down Expand Up @@ -97,12 +97,15 @@ static OnlineTransducerDecoderResult DecodeOne(
// Update the decoder states for the next chunk
decoder_states = std::move(decoder_output_pair.second);

return result;
return {result, decoder_states};
}

std::vector<OnlineTransducerDecoderResult>
OnlineTransducerGreedySearchNeMoDecoder::Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length,
Ort::Value encoder_out,
Ort::Value encoder_out_length,
std::vector<Ort::Value> decoder_states,
std::vector<OnlineTransducerDecoderResult> *results,
OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) {

auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
Expand All @@ -119,7 +122,9 @@ OnlineTransducerGreedySearchNeMoDecoder::Decode(
const float *this_p = p + dim1 * dim2 * i;
int32_t this_len = p_length[i];

ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states_);
auto decode_result_pair = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states);
ans[i] = decode_result_pair.first;
decoder_states = std::move(decode_result_pair.second); // Update decoder states for next chunk
}

return ans;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ class OnlineTransducerGreedySearchNeMoDecoder : public OnlineTransducerDecoder {
float blank_penalty);

std::vector<OnlineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length,
Ort::Value encoder_out,
Ort::Value encoder_out_length,
std::vector<Ort::Value> decoder_states,
std::vector<OnlineTransducerDecoderResult> *results,
OnlineStream **ss = nullptr, int32_t n = 0);

private:
Expand All @@ -27,7 +30,7 @@ class OnlineTransducerGreedySearchNeMoDecoder : public OnlineTransducerDecoder {

OnlineTransducerNeMoModel *model_; // Not owned
float blank_penalty_;
std::vector<Ort::Value> decoder_states_; // Decoder states to be maintained across chunks
// std::vector<Ort::Value> decoder_states_; // Decoder states to be maintained across chunks
};

} // namespace sherpa_onnx
Expand Down
6 changes: 0 additions & 6 deletions sherpa-onnx/csrc/online-transducer-nemo-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,6 @@ class OnlineTransducerNeMoModel {
std::vector<std::vector<Ort::Value>> UnStackStates(
const std::vector<Ort::Value> &states) const;

// A list of 3 tensors:
// - cache_last_channel
// - cache_last_time
// - cache_last_channel_len
std::vector<Ort::Value> GetInitStates() const;

/** Run the encoder.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
Expand Down

0 comments on commit d47bf6f

Please sign in to comment.