5x pytorch performance increase on CPU, 1 thread #408
-
UPD: onnx < 1s.
P.S. Admittedly, this was a very quick proof-of-concept test, so I may have screwed something up, even though the result probabilities are all within 2e-6 of the reference probabilities.
-
Hi @IntendedConsequence , I managed to implement your idea in V5 and got a great speedup. The only problem I have is with exporting to onnx: I have to split the model into an encoder and a decoder, but other than that, no problems.

```python
import numpy as np
import onnxruntime

opts = onnxruntime.SessionOptions()
opts.inter_op_num_threads = 0
opts.intra_op_num_threads = 0
opts.log_severity_level = 2

encoder_session = onnxruntime.InferenceSession(
    "encoder.onnx",
    providers=["CPUExecutionProvider"],
    sess_options=opts,
)
decoder_session = onnxruntime.InferenceSession(
    "decoder.onnx",
    providers=["CPUExecutionProvider"],
    sess_options=opts,
)


def forward(x: np.ndarray, num_samples: int, context_size_samples: int):
    assert (
        x.ndim == 2
    ), "Input should be a 2D tensor with size (batch_size, num_samples)"
    assert (
        x.shape[1] % num_samples == 0
    ), "Input size should be a multiple of num_samples"
    num_audio = x.shape[0]
    state = np.zeros((2, num_audio, 128), dtype="float32")

    # Split the audio into windows and build each window's context from the
    # tail of the previous window (the first window gets a zero context).
    x = x.reshape(num_audio, -1, num_samples)
    context = x[..., -context_size_samples:].copy()  # copy so the input isn't zeroed in place
    context[:, -1] = 0
    context = np.roll(context, 1, 1)
    x = np.concatenate([context, x], 2)
    x = x.reshape(-1, num_samples + context_size_samples)

    # Stateless encoder: all windows of all audios in one batched call.
    x = encoder_session.run(None, {"input": x})[0]
    x = x.reshape(num_audio, -1, 128)

    # Stateful decoder: one window at a time, carrying a single state along.
    decoder_outputs = []
    for window in np.split(x, x.shape[1], axis=1):
        out, state = decoder_session.run(
            None, {"input": window.squeeze(1), "state": state}
        )
        decoder_outputs.append(out)

    out = np.stack(decoder_outputs, axis=1).squeeze(-1)
    return out
```

It accepts a whole input file; no batching is needed, as any system can handle the whole file in one pass.
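For reference, a minimal usage sketch for the forward() above (the audio file name, the use of soundfile for loading, and the 512-sample window / 64-sample context values for v5 at 16 kHz are assumptions, not part of the original snippet):

```python
# Hypothetical usage of forward() from the snippet above.
import numpy as np
import soundfile as sf

audio, sr = sf.read("example.wav", dtype="float32")  # mono, 16 kHz assumed
num_samples = 512             # v5 window size at 16 kHz
context_size_samples = 64     # assumed v5 context size at 16 kHz

# Pad so the length is a multiple of num_samples, then add a batch dimension.
pad = (-len(audio)) % num_samples
audio = np.pad(audio, (0, pad))

probs = forward(audio[None, :], num_samples, context_size_samples)
print(probs.shape)  # (1, num_windows): one speech probability per 512-sample window
```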
-
Hi, this is very interesting. I'm also trying this now for the …
I'm measuring the performance over a single chunk of 512 samples, the value recommended by the devs (in fact, it's the only value supported for the new v5 model). The fact that you initially had 512*3-sample windows could have made your inference more efficient.
I replicated the same 512-element chunk across the batch dimension for batch sizes 1-256. Set the … With …
So I could only get a 2x improvement from this change. Comments are welcome.
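For reference, a rough timing sketch of this kind of measurement (not the exact script used here; the model path, the `input`/`state`/`sr` input names, and the (2, batch, 128) state shape are assumptions about the v5 ONNX model and may need adjusting):

```python
import time

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    "silero_vad.onnx", providers=["CPUExecutionProvider"]
)
chunk = np.random.randn(1, 512).astype("float32")  # one 512-sample chunk

for batch_size in (1, 8, 32, 128, 256):
    x = np.repeat(chunk, batch_size, axis=0)       # replicate across the batch dim
    state = np.zeros((2, batch_size, 128), dtype="float32")
    sr = np.array(16000, dtype="int64")
    t0 = time.perf_counter()
    for _ in range(100):
        out, state = sess.run(None, {"input": x, "state": state, "sr": sr})
    dt = (time.perf_counter() - t0) / 100
    print(f"batch={batch_size}: {dt * 1e3:.2f} ms/call, "
          f"{dt * 1e3 / batch_size:.3f} ms/chunk")
```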
-
tl;dr: Smarter batching without splitting the whole audio. Use batched inference for the steps up to and including the encoder, then a batch size of 1 for the lstm and decoder, because those last two parts are the only ones that depend on state.
All that without affecting precision at all, as would be the case with regular batching, where one would have to split the entire audio into multiple independent slices, each processed using its own independent state. But if batching is only used on the part of the graph up to the encoder, you can still process the entire audio with a single state.
Longer version:
When I was looking inside the jit and onnx model internals, I noticed that the majority of the computation the model performs during inference does not depend on state (the hn and cn inputs to the lstm).
The model consists of roughly 6 steps. The first 4 of those depend only on the audio data; the state variables hn and cn are only actually used and updated at the lstm step. This means the first 4 steps are easy to isolate and parallelize, and only the last 2 steps need to be computed sequentially.
Note: technically the decoder can also be batched, resulting in more speedups, but I didn't test that.
While this is good news for GPU and multithreaded inference, I did not expect a 5x performance increase on a single thread.
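To make the split concrete, a minimal sketch of the idea (not the restored Silero code itself; `stateless_prefix`, `lstm`, `decoder`, the feature shapes, and the hidden size are placeholders, and the prefix is assumed to produce one feature vector per window):

```python
# Sketch: batch the state-free steps, then run lstm + decoder sequentially
# with a single state so precision is unaffected.
import torch

torch.set_grad_enabled(False)
torch.set_num_threads(1)


def run_vad(windows, stateless_prefix, lstm, decoder, hidden_size, batch_size=32):
    """windows: float32 tensor of shape (num_windows, window_samples)."""
    # 1) Stateless steps: many windows per call, order doesn't matter.
    feats = torch.cat(
        [stateless_prefix(windows[i:i + batch_size])
         for i in range(0, windows.shape[0], batch_size)],
        dim=0,
    )  # (num_windows, feat_dim)

    # 2) Stateful steps: lstm + decoder, strictly sequential, single state.
    h = torch.zeros(1, 1, hidden_size)
    c = torch.zeros(1, 1, hidden_size)
    probs = []
    for f in feats:
        out, (h, c) = lstm(f.view(1, 1, -1), (h, c))  # seq_len=1, batch=1
        probs.append(decoder(out.view(1, -1)))
    return torch.cat(probs, dim=0)
```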
Version: Silero VAD v3.1, 16kHz. Restored python code from .jit file (for implementing the model architecture in C/ggml)
Sample count: 1536
Pytorch version: 2.1-cpu

```python
torch.set_grad_enabled(False)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
```

Batch size for first 4 steps: 32
Batch size for lstm and decoder: 1
Result: 600s of audio data is processed in 2-3 seconds.
And this is just using my restored pytorch silero code without any TorchScript, no jit.script or jit.trace at all.
Running the official v3.1 silero_vad.jit with batch_size=1 on the same 600s of audio data takes 10.4 seconds.
A batch size above 32 doesn't result in much improvement on my cpu.
I didn't look at the V4 internals closely, but if the lstm is still near the bottom of the graph, the preceding steps can probably benefit from batching too.
I also didn't test what effect this kind of batching has on onnx model inference, and I didn't conduct any GPU/CUDA tests.