From 909c4af8953d006e9ab428bb262a57a2f5cfb8e2 Mon Sep 17 00:00:00 2001
From: He Peng
Date: Tue, 10 Oct 2023 12:12:33 +0800
Subject: [PATCH] Add lm decode for the Python API.

Expose the language-model settings to Python callers:

- online-decode-files.py: add the --lm and --lm-scale options and
  forward them to OnlineRecognizer.from_transducer(). --lm defaults to
  the empty string, matching the default of from_transducer(lm=...).
- online-recognizer.cc: make OnlineRecognizerConfig.lm_config
  readable and writable from Python.
- online_recognizer.py: from_transducer() accepts lm and scale, builds
  an OnlineLMConfig from them, and raises ValueError when lm is set but
  decoding_method is not modified_beam_search.
---
 python-api-examples/online-decode-files.py    | 20 +++++++++++++++++++
 sherpa-onnx/python/csrc/online-recognizer.cc  |  1 +
 .../python/sherpa_onnx/online_recognizer.py   | 15 ++++++++++++++
 3 files changed, 36 insertions(+)

diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py
index cdf7870fb..a03afef31 100755
--- a/python-api-examples/online-decode-files.py
+++ b/python-api-examples/online-decode-files.py
@@ -115,6 +115,24 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--lm",
+        type=str,
+        default="",
+        help="""Used only when --decoding-method is modified_beam_search.
+        path of language model.
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.1,
+        help="""Used only when --decoding-method is modified_beam_search.
+        scale of language model.
+        """,
+    )
+
     parser.add_argument(
         "--provider",
         type=str,
@@ -215,6 +233,8 @@ def main():
         feature_dim=80,
         decoding_method=args.decoding_method,
         max_active_paths=args.max_active_paths,
+        lm=args.lm,
+        scale=args.lm_scale,
         hotwords_file=args.hotwords_file,
         hotwords_score=args.hotwords_score,
     )
diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc
index 68e97b60a..9cfce8456 100644
--- a/sherpa-onnx/python/csrc/online-recognizer.cc
+++ b/sherpa-onnx/python/csrc/online-recognizer.cc
@@ -37,6 +37,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
            py::arg("hotwords_score") = 0)
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
+      .def_readwrite("lm_config", &PyClass::lm_config)
       .def_readwrite("endpoint_config", &PyClass::endpoint_config)
       .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
       .def_readwrite("decoding_method", &PyClass::decoding_method)
diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
index eabf99ec8..b576f1e61 100644
--- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py
@@ -10,6 +10,7 @@
     OnlineRecognizer as _Recognizer,
     OnlineRecognizerConfig,
     OnlineStream,
+    OnlineLMConfig,
     OnlineTransducerModelConfig,
 )
 
@@ -46,6 +47,8 @@ def from_transducer(
         hotwords_file: str = "",
         provider: str = "cpu",
         model_type: str = "",
+        lm: str = "",
+        scale: float = 0.1,
     ):
         """
         Please refer to
@@ -137,10 +140,22 @@ def from_transducer(
                 "Please use --decoding-method=modified_beam_search when using "
                 f"--hotwords-file. Currently given: {decoding_method}"
             )
+
+        if lm and decoding_method != "modified_beam_search":
+            raise ValueError(
+                "Please use --decoding-method=modified_beam_search when using "
+                f"--lm. Currently given: {decoding_method}"
+            )
+
+        lm_config = OnlineLMConfig(
+            model=lm,
+            scale=scale,
+        )
 
         recognizer_config = OnlineRecognizerConfig(
             feat_config=feat_config,
             model_config=model_config,
+            lm_config=lm_config,
             endpoint_config=endpoint_config,
             enable_endpoint=enable_endpoint_detection,
             decoding_method=decoding_method,