From c741fe85295a863f97bf48141f4615586722265a Mon Sep 17 00:00:00 2001 From: 916BGAI <916772719@qq.com> Date: Mon, 25 Nov 2024 11:19:19 +0800 Subject: [PATCH] optimize speech and add skip frame api --- components/nn/include/maix_nn_speech.hpp | 288 ++++++++++++----------- examples/nn_speech/main/src/main.cpp | 2 +- 2 files changed, 153 insertions(+), 137 deletions(-) diff --git a/components/nn/include/maix_nn_speech.hpp b/components/nn/include/maix_nn_speech.hpp index af0ffeff..68ecfaf5 100644 --- a/components/nn/include/maix_nn_speech.hpp +++ b/components/nn/include/maix_nn_speech.hpp @@ -15,20 +15,20 @@ #include #include -static std::function, int)> _raw_callback; -static std::function _digit_callback; -static std::function, int)> _kws_callback; -static std::function, int)> _lvcsr_callback; - -#ifdef PLATFORM_MAIXCAM namespace maix::nn { +static bool _is_skip_frames {false}; +static std::function>, int)> raw_callback; +static std::function digit_callback; +static std::function, int)> kws_callback; +static std::function, int)> lvcsr_callback; + /** * @brief speech device * @maixpy maix.nn.SpeechDevice */ -enum SpeechDevice { +enum class SpeechDevice { DEVICE_NONE = -1, DEVICE_PCM, DEVICE_MIC, @@ -39,7 +39,7 @@ enum SpeechDevice { * @brief speech decoder type * @maixpy maix.nn.SpeechDecoder */ -enum SpeechDecoder { +enum class SpeechDecoder { DECODER_RAW = 1, DECODER_DIG = 2, DECODER_LVCSR = 4, @@ -47,6 +47,7 @@ enum SpeechDecoder { DECODER_ALL = 65535, }; +#ifdef PLATFORM_MAIXCAM /** * Speech * @maixpy maix.nn.Speech @@ -77,7 +78,7 @@ enum SpeechDecoder { ~Speech() { - if (_dev_type != DEVICE_NONE) { + if (_dev_type != SpeechDevice::DEVICE_NONE) { this->deinit(); } @@ -111,13 +112,13 @@ enum SpeechDecoder { { if (_extra_info["model_type"] != "speech") { - log::error("model_type not match, expect 'speech', but got '%s'", _extra_info["model_type"].c_str()); + log::error("model_type not match, expect 'speech', but got '%s'.", _extra_info["model_type"].c_str()); return err::ERR_ARGS; } } else { - log::error("model_type key not found"); + log::error("model_type key not found."); return err::ERR_ARGS; } if (_extra_info.find("mean") != _extra_info.end()) @@ -132,14 +133,14 @@ enum SpeechDecoder { } catch (std::exception &e) { - log::error("mean value error, should float"); + log::error("mean value error, should float."); return err::ERR_ARGS; } } } else { - log::error("mean key not found"); + log::error("mean key not found."); return err::ERR_ARGS; } if (_extra_info.find("scale") != _extra_info.end()) @@ -154,14 +155,14 @@ enum SpeechDecoder { } catch (std::exception &e) { - log::error("scale value error, should float"); + log::error("scale value error, should float."); return err::ERR_ARGS; } } } else { - log::error("scale key not found"); + log::error("scale key not found."); return err::ERR_ARGS; } _inputs = _model->inputs_info(); @@ -179,31 +180,43 @@ enum SpeechDecoder { * @return err::Err type, if init success, return err::ERR_NONE * @maixpy maix.nn.Speech.init */ - err::Err init(nn::SpeechDevice dev_type, const string &device_name) + err::Err init(nn::SpeechDevice dev_type, const string &device_name = "") { + string _device_name = device_name; + if (_model_path == "") { - log::error("please load am model first\n"); + log::error("please load am model first."); throw err::Exception(err::ERR_NOT_IMPL); } am_args_t am_args = {(char*)_model_path.c_str(), 192, 6, 6, CN_PNYTONE, 1}; - if (this->dev_type() != DEVICE_NONE) { - log::error("device has been initialized, please use maix.nn.Speech.devive to reset devive\n"); + if (this->dev_type() != SpeechDevice::DEVICE_NONE) { + log::error("device has been initialized, please use maix.nn.Speech.devive to reset devive."); return err::ERR_RUNTIME; } - if(dev_type > 2) { - log::error("not support device %d\n", dev_type); + if((int)dev_type > 2) { + log::error("not support device %d.", dev_type); throw err::Exception(err::ERR_NOT_IMPL); + } else if (dev_type == SpeechDevice::DEVICE_MIC && device_name == "") { + _dev_type = dev_type; + _device_name = "hw:0,0"; + } else if (dev_type == SpeechDevice::DEVICE_PCM && device_name == "") { + log::error("please enter the correct path to the PCM file."); + return err::ERR_ARGS; + } else if (dev_type == SpeechDevice::DEVICE_WAV && device_name == "") { + log::error("please enter the correct path to the WAV file."); + return err::ERR_ARGS; } else { _dev_type = dev_type; + _device_name = device_name; } - int ret = ms_asr_init(_dev_type, (char*)device_name.c_str(), &am_args, 0); + int ret = ms_asr_init((int)_dev_type, (char*)_device_name.c_str(), &am_args, 0); if(ret) { - log::error("asr init error!\n"); - _dev_type = DEVICE_NONE; + log::error("asr init error!"); + _dev_type = SpeechDevice::DEVICE_NONE; return err::ERR_NOT_IMPL; } @@ -221,42 +234,23 @@ enum SpeechDecoder { */ err::Err devive(nn::SpeechDevice dev_type, const string &device_name) { - if(dev_type > 2) { - log::error("not support device %d\n", dev_type); + if((int)dev_type > 2) { + log::error("not support device %d.", dev_type); throw err::Exception(err::ERR_NOT_IMPL); } else { _dev_type = dev_type; } - int ret = ms_asr_set_dev(_dev_type, (char*)device_name.c_str()); + int ret = ms_asr_set_dev((int)_dev_type, (char*)device_name.c_str()); if(ret) { log::error("set devive error!\n"); - _dev_type = DEVICE_NONE; + _dev_type = SpeechDevice::DEVICE_NONE; return err::ERR_NOT_IMPL; } return err::ERR_NONE; } - /** - * Deinit the ASR library. - * @maixpy maix.nn.Speech.deinit - */ - void deinit() - { - _dev_type = DEVICE_NONE; - _decoder_raw = false; - _decoder_dig = false; - _decoder_lvcsr = false; - _decoder_kws = false; - _raw_callback = nullptr; - _digit_callback = nullptr; - _kws_callback = nullptr; - _lvcsr_callback = nullptr; - ms_asr_deinit(); - sys::register_default_signal_handle(); - } - /** * Deinit the decoder. * @param decoder decoder type want to deinit @@ -266,7 +260,7 @@ enum SpeechDecoder { */ void dec_deinit(nn::SpeechDecoder decoder) { - ms_asr_decoder_cfg(decoder, NULL , NULL, 0); + ms_asr_decoder_cfg((int)decoder, NULL , NULL, 0); switch (decoder) { case nn::SpeechDecoder::DECODER_RAW: @@ -288,7 +282,7 @@ enum SpeechDecoder { _decoder_kws = false; break; default: - log::error("not support decoder %d\n", decoder); + log::error("not support decoder %d.", decoder); throw err::Exception(err::ERR_NOT_IMPL); } } @@ -299,12 +293,17 @@ enum SpeechDecoder { * @return err::Err type, if init success, return err::ERR_NONE * @maixpy maix.nn.Speech.raw */ - err::Err raw(std::function, int)> callback) + err::Err raw(std::function>, int)> callback) { + if (this->dev_type() == SpeechDevice::DEVICE_NONE) { + log::error("please init a type of audio device first."); + return err::ERR_NOT_INIT; + } + _raw_callback = callback; - int ret = ms_asr_decoder_cfg(nn::SpeechDecoder::DECODER_RAW, raw_callback_wrapper , NULL, 0); + int ret = ms_asr_decoder_cfg((int)nn::SpeechDecoder::DECODER_RAW, raw_callback_wrapper , NULL, 0); if (ret != 0) { - log::error("raw decoder init error"); + log::error("raw decoder init error."); return err::ERR_RUNTIME; } else { _decoder_raw = true; @@ -328,14 +327,19 @@ enum SpeechDecoder { */ err::Err digit(int blank, std::function callback) { + if (this->dev_type() == SpeechDevice::DEVICE_NONE) { + log::error("please init a type of audio device first."); + return err::ERR_NOT_INIT; + } + size_t decoder_args[10]; decoder_args[0] = blank; _digit_callback = callback; - int ret = ms_asr_decoder_cfg(nn::SpeechDecoder::DECODER_DIG, digit_callback_wrapper, &decoder_args, 1); + int ret = ms_asr_decoder_cfg((int)nn::SpeechDecoder::DECODER_DIG, digit_callback_wrapper, &decoder_args, 1); if (ret != 0) { - log::error("Digit decoder init error"); + log::error("digit decoder init error."); return err::ERR_RUNTIME; } else { _decoder_dig = true; @@ -363,8 +367,13 @@ enum SpeechDecoder { */ err::Err kws(std::vector kw_tbl, std::vector kw_gate, std::function, int)> callback, bool auto_similar = true) { + if (this->dev_type() == SpeechDevice::DEVICE_NONE) { + log::error("please init a type of audio device first."); + return err::ERR_NOT_INIT; + } + if (kw_tbl.size() != kw_gate.size()) { - log::error("kw_tbl num must equal to kw_gate num"); + log::error("kw_tbl num must equal to kw_gate num."); return err::ERR_ARGS; } @@ -384,7 +393,7 @@ enum SpeechDecoder { decoder_args[2] = kw_tbl.size(); decoder_args[3] = auto_similar; _kws_callback = callback; - int ret = ms_asr_decoder_cfg(nn::SpeechDecoder::DECODER_KWS, kws_callback_wrapper, &decoder_args, 3); + int ret = ms_asr_decoder_cfg((int)nn::SpeechDecoder::DECODER_KWS, kws_callback_wrapper, &decoder_args, 3); delete[] _kw_gate; for (size_t i = 0; i < kw_tbl.size(); ++i) { @@ -393,7 +402,7 @@ enum SpeechDecoder { delete[] _kw_tbl; if (ret != 0) { - log::error("kws decoder init error"); + log::error("kws decoder init error."); return err::ERR_RUNTIME; } else { _decoder_kws = true; @@ -430,6 +439,11 @@ enum SpeechDecoder { std::function, int)> callback, float beam = 8, float bg_prob = 10, float scale = 0.5, bool mmap = false) { + if (this->dev_type() == SpeechDevice::DEVICE_NONE) { + log::error("please init a type of audio device first."); + return err::ERR_NOT_INIT; + } + size_t decoder_args[10]; decoder_args[0] = (size_t)sfst_name.c_str(); decoder_args[1] = (size_t)sym_name.c_str(); @@ -441,9 +455,9 @@ enum SpeechDecoder { decoder_args[7] = mmap; _lvcsr_callback = callback; - int ret = ms_asr_decoder_cfg(nn::SpeechDecoder::DECODER_LVCSR, lvcsr_callback_wrapper, &decoder_args, 8); + int ret = ms_asr_decoder_cfg((int)nn::SpeechDecoder::DECODER_LVCSR, lvcsr_callback_wrapper, &decoder_args, 8); if (ret != 0) { - log::error("lvcsr decoder init error"); + log::error("lvcsr decoder init error."); return err::ERR_RUNTIME; } else { _decoder_lvcsr = true; @@ -467,7 +481,24 @@ enum SpeechDecoder { */ int run(int frame) { + if (!(raw() || digit() || kws() || lvcsr())) { + log::error("please init at least one decoder before running."); + return 0; + } + + raw_callback = this->_raw_callback; + digit_callback = this->_digit_callback; + kws_callback = this->_kws_callback; + lvcsr_callback = this->_lvcsr_callback; + int frames = ms_asr_run(frame); + + // Set it to nullptr, otherwise MaixPy cannot exit properly. + raw_callback = nullptr; + digit_callback = nullptr; + kws_callback = nullptr; + lvcsr_callback = nullptr; + return frames; } @@ -490,19 +521,6 @@ enum SpeechDecoder { return ms_asr_get_frame_time(); } - /** - * Get the acoustic model dictionary. - * @return std::pair type, return the dictionary and length. - * @maixpy maix.nn.Speech.vocab - */ - std::pair vocab() - { - char* dummy; - int vocab_cnt; - ms_asr_get_am_vocab(&dummy, &vocab_cnt); - return {dummy, vocab_cnt}; - } - /** * Manually register mute words, and each pinyin can register up to 10 homophones, * please note that using this interface to register homophones will overwrite, @@ -515,7 +533,7 @@ enum SpeechDecoder { err::Err similar(const string &pny, std::vector similar_pnys) { if (this->kws() != true) { - log::error("please init kws decoder first"); + log::error("please init kws decoder first."); return err::ERR_RUNTIME; } @@ -533,13 +551,25 @@ enum SpeechDecoder { delete[] _similar_pnys; if (ret != 0) { - log::error("set similar pny error"); + log::error("set similar pny error."); return err::ERR_RUNTIME; } else { return err::ERR_NONE; } } + /** + * Run some frames and drop, this can be used to avoid + * incorrect recognition results when switching decoders. + * @param num number of frames to run and drop + * @maixpy maix.nn.Speech.skip_frames + */ + void skip_frames(int num) { + _is_skip_frames = true; + this->run(num); + _is_skip_frames = false; + } + public: /** * Get mean value, list type @@ -576,46 +606,66 @@ enum SpeechDecoder { std::map _extra_info; image::Size _input_size; std::vector _inputs; - nn::SpeechDevice _dev_type = DEVICE_NONE; + nn::SpeechDevice _dev_type = SpeechDevice::DEVICE_NONE; bool _decoder_raw = false; bool _decoder_dig = false; bool _decoder_kws = false; bool _decoder_lvcsr = false; + std::function>, int)> _raw_callback; + std::function _digit_callback; + std::function, int)> _kws_callback; + std::function, int)> _lvcsr_callback; + + void deinit() + { + _dev_type = SpeechDevice::DEVICE_NONE; + _decoder_raw = false; + _decoder_dig = false; + _decoder_lvcsr = false; + _decoder_kws = false; + _raw_callback = nullptr; + _digit_callback = nullptr; + _kws_callback = nullptr; + _lvcsr_callback = nullptr; + ms_asr_deinit(); + sys::register_default_signal_handle(); + } + static void digit_callback_wrapper(void* data, int cnt) { - if (_digit_callback) { - _digit_callback(static_cast(data), cnt); + if (digit_callback && _is_skip_frames == false) { + digit_callback(static_cast(data), cnt); } } static void kws_callback_wrapper(void* data, int cnt) { - if (_kws_callback) { + if (kws_callback && _is_skip_frames == false) { std::vector kws_data; float* p = (float*) data; for(int i=0; i raw_data; + if (raw_callback && _is_skip_frames == false) { + std::vector> raw_data; pnyp_t* res = (pnyp_t*)data; for(int t=0; tidx, pp->p}); } - _raw_callback(raw_data, cnt); + raw_callback(raw_data, cnt); } } static void lvcsr_callback_wrapper(void* data, int cnt) { - if (_lvcsr_callback) { + if (lvcsr_callback && _is_skip_frames == false) { char* words = ((char**)data)[0]; char* pnys = ((char**)data)[1]; - _lvcsr_callback({words, pnys}, cnt); + lvcsr_callback({words, pnys}, cnt); } } @@ -647,32 +697,6 @@ enum SpeechDecoder { #endif #ifdef PLATFORM_LINUX -namespace maix::nn -{ - -/** - * @brief speech device - * @maixpy maix.nn.SpeechDevice - */ -enum SpeechDevice { - DEVICE_NONE = -1, - DEVICE_PCM, - DEVICE_MIC, - DEVICE_WAV, -}; - -/** - * @brief speech decoder type - * @maixpy maix.nn.SpeechDecoder - */ -enum SpeechDecoder { - DECODER_RAW = 1, - DECODER_DIG = 2, - DECODER_LVCSR = 4, - DECODER_KWS = 8, - DECODER_ALL = 65535, -}; - /** * Speech * @maixpy maix.nn.Speech @@ -717,7 +741,7 @@ enum SpeechDecoder { * @return err::Err type, if init success, return err::ERR_NONE * @maixpy maix.nn.Speech.init */ - err::Err init(nn::SpeechDevice dev_type, const string &device_name) + err::Err init(nn::SpeechDevice dev_type, const string &device_name = "") { return err::ERR_NONE; } @@ -736,19 +760,6 @@ enum SpeechDecoder { return err::ERR_NONE; } - /** - * Deinit the ASR library. - * @maixpy maix.nn.Speech.deinit - */ - void deinit() - { - _dev_type = DEVICE_NONE; - _decoder_raw = false; - _decoder_dig = false; - _decoder_lvcsr = false; - _decoder_kws = false; - } - /** * Deinit the decoder. * @param decoder decoder type want to deinit @@ -884,16 +895,6 @@ enum SpeechDecoder { return 0; } - /** - * Get the acoustic model dictionary. - * @return std::pair type, return the dictionary and length. - * @maixpy maix.nn.Speech.vocab - */ - std::pair vocab() - { - return {nullptr, 0}; - } - /** * Manually register mute words, and each pinyin can register up to 10 homophones, * please note that using this interface to register homophones will overwrite, @@ -908,6 +909,16 @@ enum SpeechDecoder { return err::ERR_NONE; } + /** + * Run some frames and drop, this can be used to avoid + * incorrect recognition results when switching decoders. + * @param num number of frames to run and drop + * @maixpy maix.nn.Speech.skip_frames + */ + void skip_frames(int num) { + return; + } + public: /** * Get mean value, list type @@ -934,12 +945,17 @@ enum SpeechDecoder { std::map _extra_info; image::Size _input_size; std::vector _inputs; - nn::SpeechDevice _dev_type = DEVICE_NONE; + nn::SpeechDevice _dev_type = SpeechDevice::DEVICE_NONE; bool _decoder_raw = false; bool _decoder_dig = false; bool _decoder_kws = false; bool _decoder_lvcsr = false; + void deinit() + { + return ; + } + static void digit_callback_wrapper(void* data, int cnt) { } diff --git a/examples/nn_speech/main/src/main.cpp b/examples/nn_speech/main/src/main.cpp index 4fbe47ac..0d92c346 100644 --- a/examples/nn_speech/main/src/main.cpp +++ b/examples/nn_speech/main/src/main.cpp @@ -36,7 +36,7 @@ int _main(int argc, char* argv[]) { nn::Speech speech("/root/models/am_3332_192_int8.mud"); - speech.init(nn::SpeechDevice::DEVICE_MIC, "hw:0,0"); // use mic device + speech.init(nn::SpeechDevice::DEVICE_MIC); // use mic device // speech.init(nn::SpeechDevice::DEVICE_WAV, "test.wav"); // use wav file // speech.init(nn::SpeechDevice::DEVICE_PCM, "test.pcm"); // use pcm file