Hi, Marlin is primarily optimized for generative inference (a few tokens at a time), which is memory-bound and can hence be sped up via weight quantization; e.g. input shapes of (16, 1, 11008). Note that for batchsize > 128 (meaning the overall number of tokens, in your case 16 * 1024), inference stops being memory-bound and weight-only quantization is generally no longer faster (though Marlin is sometimes still a bit faster for moderately large batch sizes, due to slightly better partitioning than the default torch kernels).
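To make the memory- vs. compute-bound argument concrete, here is a rough roofline-style back-of-the-envelope sketch (the peak-compute and bandwidth figures are placeholder values, not from this thread): at 16 tokens the weight traffic dominates, so shrinking weights from 16 to 4 bits cuts the estimated time roughly 4x, while at 16 * 1024 tokens the FLOPs dominate and the weight format barely matters.

```python
# Back-of-envelope roofline estimate (a sketch; the GPU figures below are
# rough placeholders, not measurements) showing why weight-only int4 helps
# at small token counts but not at ~16k tokens.

def matmul_times(tokens, in_features=11008, out_features=4096,
                 weight_bits=16, peak_tflops=150.0, bandwidth_gbs=1500.0):
    flops = 2 * tokens * in_features * out_features
    weight_bytes = in_features * out_features * weight_bits / 8
    act_bytes = 2 * tokens * (in_features + out_features)  # fp16 activations
    t_compute = flops / (peak_tflops * 1e12)
    t_memory = (weight_bytes + act_bytes) / (bandwidth_gbs * 1e9)
    return t_compute, t_memory

for tokens in (16, 16 * 1024):
    for bits in (16, 4):
        tc, tm = matmul_times(tokens, weight_bits=bits)
        bound = "memory" if tm > tc else "compute"
        print(f"tokens={tokens:6d} w{bits:2d}: compute {tc*1e6:8.1f}us, "
              f"memory {tm*1e6:8.1f}us -> {bound}-bound")
```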
Thanks for your answer @efrantar, understood. I am trying to integrate it with our quantization method; below are the benchmarks for the forward pass on a 3090 (Llama2-7B, batch-size=1, context-size=2048):
- fp16: 0.4604 (with model compile: 0.4218)
- int4 (torch compile): 0.4554
- Marlin (int4): 0.4221 (with model compile: 0.3841)
It is about 10% faster than fp16 with this setup, though the LLM eval score drops slightly (51.57 -> 51.27).
Is there a way to dequantize the weights without calling the matmul with the identity matrix?
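For reference, the identity-matrix workaround mentioned here looks roughly like the sketch below; `qlayer` stands for any quantized layer that behaves like `nn.Linear` (these names are assumptions, not the repo's exact API).

```python
import torch

# Sketch of the identity-matrix workaround: pushing an identity matrix
# through a quantized linear layer recovers its effective (dequantized)
# weight. `qlayer` is assumed to behave like nn.Linear (y = x @ W.T) and to
# expose in_features; adapt to the actual Marlin layer object.
@torch.no_grad()
def dequantize_via_identity(qlayer, device="cuda", dtype=torch.half):
    eye = torch.eye(qlayer.in_features, device=device, dtype=dtype)
    w_t = qlayer(eye)             # shape (in_features, out_features), i.e. W.T
    return w_t.t().contiguous()   # back to nn.Linear's (out_features, in_features)
```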
I have been running some benchmarks with Marlin, but the speed-up is far from what is reported; in fact, it is actually slower than fp16:
GPU: A6000 Ada
Code below:
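(The original snippet is not reproduced here; as a rough sketch, a layer-level comparison of this kind, timing an fp16 `nn.Linear` with CUDA events and with the Marlin-packed counterpart only stubbed out in a comment since its construction should follow the repo's README, could look like the following.)

```python
import torch
import torch.nn as nn

# Sketch of an fp16 vs. quantized-layer timing comparison. `quant_layer` is a
# placeholder for a Marlin-packed layer; build it with the Layer/pack
# utilities from the Marlin repo (see its README for the exact API).

def bench(fn, warmup=10, iters=100):
    # Average runtime of a callable in milliseconds, measured with CUDA events.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

batch, in_f, out_f = 16, 11008, 4096  # placeholder shapes
x = torch.randn(batch, in_f, dtype=torch.half, device="cuda")

fp16_layer = nn.Linear(in_f, out_f, bias=False).half().cuda()
# quant_layer = ...  # Marlin-packed equivalent of fp16_layer goes here

print("fp16  :", bench(lambda: fp16_layer(x)), "ms")
# print("marlin:", bench(lambda: quant_layer(x)), "ms")
```

Note that for a single small matmul, kernel-launch overhead and clock ramp-up can dominate such a measurement, so end-to-end model timings tend to be more representative.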
What could be the issue? Thanks in advance!