From 7fc4ffd2d3bcb655621ecce4a977820cd79c4f8c Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Sat, 14 Dec 2024 21:14:52 +0000
Subject: [PATCH 1/3] Improve ONNX version

---
 moonshine-onnx/src/model.py | 89 +++++++++++++++++++++++--------------
 1 file changed, 55 insertions(+), 34 deletions(-)

diff --git a/moonshine-onnx/src/model.py b/moonshine-onnx/src/model.py
index 7d7997a..5939269 100644
--- a/moonshine-onnx/src/model.py
+++ b/moonshine-onnx/src/model.py
@@ -1,11 +1,13 @@
 def _get_onnx_weights(model_name):
     from huggingface_hub import hf_hub_download
-
-    repo = "UsefulSensors/moonshine"
+
+    if model_name not in ["tiny", "base"]:
+        raise ValueError(f"Unknown model \"{model_name}\"")
+    repo = f"onnx-community/moonshine-{model_name}-ONNX"
 
     return (
-        hf_hub_download(repo, f"{x}.onnx", subfolder=f"onnx/{model_name}")
-        for x in ("preprocess", "encode", "uncached_decode", "cached_decode")
+        hf_hub_download(repo, f"{x}.onnx", subfolder="onnx")
+        for x in ("encoder_model", "decoder_model_merged_q4")
     )
 
 
@@ -17,18 +19,30 @@ def __init__(self, models_dir=None, model_name=None):
             assert (
                 model_name is not None
             ), "model_name should be specified if models_dir is not"
-            preprocess, encode, uncached_decode, cached_decode = (
+            encoder, decoder = (
                 self._load_weights_from_hf_hub(model_name)
             )
         else:
-            preprocess, encode, uncached_decode, cached_decode = [
+            encoder, decoder = [
                 f"{models_dir}/{x}.onnx"
-                for x in ["preprocess", "encode", "uncached_decode", "cached_decode"]
+                for x in ("encoder_model", "decoder_model_merged")
             ]
-        self.preprocess = onnxruntime.InferenceSession(preprocess)
-        self.encode = onnxruntime.InferenceSession(encode)
-        self.uncached_decode = onnxruntime.InferenceSession(uncached_decode)
-        self.cached_decode = onnxruntime.InferenceSession(cached_decode)
+        self.encoder = onnxruntime.InferenceSession(encoder)
+        self.decoder = onnxruntime.InferenceSession(decoder)
+
+        if 'tiny' in model_name:
+            self.num_layers = 6
+            self.num_key_value_heads = 8
+            self.head_dim = 36
+        elif 'base' in model_name:
+            self.num_layers = 8
+            self.num_key_value_heads = 8
+            self.head_dim = 52
+        else:
+            raise ValueError(f"Unknown model \"{model_name}\"")
+
+        self.decoder_start_token_id = 1
+        self.eos_token_id = 2
 
     def _load_weights_from_hf_hub(self, model_name):
         model_name = model_name.split("/")[-1]
@@ -39,31 +53,38 @@ def generate(self, audio, max_len=None):
         if max_len is None:
             # max 6 tokens per second of audio
             max_len = int((audio.shape[-1] / 16_000) * 6)
-        preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
-        seq_len = [preprocessed.shape[-2]]
-        context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]
-        inputs = [[1]]
-        seq_len = [1]
-        tokens = [1]
-        logits, *cache = self.uncached_decode.run(
-            [], dict(args_0=inputs, args_1=context, args_2=seq_len)
-        )
+
+        import numpy as np
+
+        last_hidden_state = self.encoder.run(None, dict(input_values=audio))[0]
+
+        past_key_values = {
+            f"past_key_values.{i}.{a}.{b}": np.zeros((0, self.num_key_value_heads, 1, self.head_dim), dtype=np.float32)
+            for i in range(self.num_layers)
+            for a in ("decoder", "encoder")
+            for b in ("key", "value")
+        }
+
+        tokens = [self.decoder_start_token_id]
+        input_ids = [tokens]
         for i in range(max_len):
-            next_token = logits.squeeze().argmax()
-            tokens.extend([next_token])
-            if next_token == 2:
+            use_cache_branch = i > 0
+            decoder_inputs = dict(
+                input_ids=input_ids,
+                encoder_hidden_states=last_hidden_state,
+                use_cache_branch=[use_cache_branch],
+                **past_key_values,
+            )
+            logits, *present_key_values = self.decoder.run(None, decoder_inputs)
+            next_token = logits[0, -1].argmax().item()
+            tokens.append(next_token)
+            if next_token == self.eos_token_id:
                 break
-            seq_len[0] += 1
-            inputs = [[next_token]]
-            logits, *cache = self.cached_decode.run(
-                [],
-                dict(
-                    args_0=inputs,
-                    args_1=context,
-                    args_2=seq_len,
-                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
-                ),
-            )
+
+            # Update values for next iteration
+            input_ids = [[next_token]]
+            for k, v in zip(past_key_values.keys(), present_key_values):
+                if not use_cache_branch or 'decoder' in k:
+                    past_key_values[k] = v
+
         return [tokens]

From 7f70a6c2293dc22f52a72f8e718b21dfecea0b76 Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Sat, 14 Dec 2024 21:29:17 +0000
Subject: [PATCH 2/3] Use fp32 by default

---
 moonshine-onnx/src/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/moonshine-onnx/src/model.py b/moonshine-onnx/src/model.py
index 5939269..ca5b954 100644
--- a/moonshine-onnx/src/model.py
+++ b/moonshine-onnx/src/model.py
@@ -7,7 +7,7 @@ def _get_onnx_weights(model_name):
 
     return (
         hf_hub_download(repo, f"{x}.onnx", subfolder="onnx")
-        for x in ("encoder_model", "decoder_model_merged_q4")
+        for x in ("encoder_model", "decoder_model_merged")
     )
 
 

From d215d9a1478984beea02d8897dbf4ac2042bcd46 Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Sat, 14 Dec 2024 21:53:49 +0000
Subject: [PATCH 3/3] Formatting

---
 moonshine-onnx/src/model.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/moonshine-onnx/src/model.py b/moonshine-onnx/src/model.py
index ca5b954..7af329a 100644
--- a/moonshine-onnx/src/model.py
+++ b/moonshine-onnx/src/model.py
@@ -30,11 +30,11 @@ def __init__(self, models_dir=None, model_name=None):
         self.encoder = onnxruntime.InferenceSession(encoder)
         self.decoder = onnxruntime.InferenceSession(decoder)
 
-        if 'tiny' in model_name:
+        if "tiny" in model_name:
             self.num_layers = 6
             self.num_key_value_heads = 8
             self.head_dim = 36
-        elif 'base' in model_name:
+        elif "base" in model_name:
             self.num_layers = 8
             self.num_key_value_heads = 8
             self.head_dim = 52
@@ -84,7 +84,7 @@ def generate(self, audio, max_len=None):
             # Update values for next iteration
             input_ids = [[next_token]]
             for k, v in zip(past_key_values.keys(), present_key_values):
-                if not use_cache_branch or 'decoder' in k:
+                if not use_cache_branch or "decoder" in k:
                     past_key_values[k] = v
 
         return [tokens]
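
A minimal usage sketch of the patched class, for reference only (not part of the
commits above). It assumes model.py is importable as `model` and that the input is a
16 kHz mono float32 waveform of shape (batch, num_samples):

    import numpy as np
    from model import MoonshineOnnxModel  # moonshine-onnx/src/model.py as patched above

    model = MoonshineOnnxModel(model_name="tiny")    # downloads the ONNX weights from the HF Hub
    audio = np.zeros((1, 16_000), dtype=np.float32)  # one second of 16 kHz silence (assumed shape)
    tokens = model.generate(audio)                   # greedy decode; stops at eos_token_id (2)
    print(tokens)                                    # [[1, ...]] -- token ids, not text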