Avoiding .contiguous call before matmul #1965

Closed
EricLBuehler opened this issue Mar 29, 2024 · 12 comments

@EricLBuehler
Member

During my work with Candle I have noticed that matmul seems to accept only contiguous tensors. That is, after calling .reshape, matmul will fail because the tensor is not contiguous in the right dimensions. The logical fix (#1948) seems to be to call .contiguous, and this is what the provided models do.

It is likely that supporting non-contiguous tensors would yield a good performance increase (#1939), since avoiding unnecessary tensor copies is always a win. After reading the dfdx code, it seems like they are able to handle non-contiguous tensors.

Is there a way to implement this technique in Candle? Thanks!
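
For concreteness, here is a minimal sketch of the pattern in question (the shapes and names are made up for illustration): after a permute the tensor is no longer contiguous, and the workaround used in the provided models is an explicit `.contiguous()` before the matmul.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::randn(0f32, 1f32, (2, 3, 4), &dev)?;
    let b = Tensor::randn(0f32, 1f32, (4, 2, 5), &dev)?;

    // Moving the batch dim to the front leaves `b` with non-standard strides,
    // so a plain `a.matmul(&b)` can be rejected for this layout.
    let b = b.permute((1, 0, 2))?; // shape (2, 4, 5), non-contiguous

    // The workaround used in the provided models: materialize a contiguous copy.
    let c = a.matmul(&b.contiguous()?)?;
    println!("{:?}", c.shape()); // (2, 3, 5)
    Ok(())
}
```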

@LaurentMazare
Collaborator

Matmuls actually also support transposed layouts in addition to contiguous ones, so we enforce contiguous or transposed-contiguous because these are supported on all devices (plus mkl and accelerate) and are likely to be the fastest. Supporting arbitrary strides is dodgy, e.g. on mkl, hence the limitation at the moment.
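
For instance (a quick sketch, not taken from the codebase), the common `x.matmul(&w.t()?)` pattern already works on the transposed view without an explicit copy, since the transposed-contiguous layout maps onto the gemm transpose flag:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1f32, (3, 4), &dev)?;
    let w = Tensor::randn(0f32, 1f32, (5, 4), &dev)?;
    // `w.t()?` is a strided view (shape (4, 5), transposed-contiguous); matmul
    // accepts it directly instead of requiring a `.contiguous()` copy.
    let y = x.matmul(&w.t()?)?;
    println!("{:?}", y.shape()); // (3, 5)
    Ok(())
}
```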

@EricLBuehler
Member Author

I plan on writing this for CUDA; would that work? Do you know of any good resources or places where I can find a sample implementation?

@LaurentMazare
Collaborator

I'm not sure it would. We use cublas for the matmul (actually this function), and it doesn't let you work with arbitrary strides except for the batch dimensions; you can only specify transpose or no-transpose, which is what we're doing.
The recent lazy-reshape work should make the situation a bit better though, and avoid some unnecessary copies (and there is one more change to come on this front that will add a fast copy mode in some specific cases).
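
Purely as an illustration (this is not candle's actual code), the set of layouts a cublas-style gemm can consume boils down to a per-operand choice between "no transpose" and "transpose"; anything else has to be copied first:

```rust
// Hypothetical helper: classify the last-two-dim strides of an operand into
// the only layouts a gemm transpose flag can express.
#[derive(Debug, PartialEq)]
enum GemmLayout {
    NoTranspose, // row-major contiguous: strides (n, 1) for an (m, n) matrix
    Transpose,   // transposed-contiguous / column-major: strides (1, m)
    NeedsCopy,   // arbitrary strides: not expressible as a gemm flag
}

fn classify(shape: (usize, usize), strides: (usize, usize)) -> GemmLayout {
    let (m, n) = shape;
    let (s0, s1) = strides;
    if s1 == 1 && s0 == n {
        GemmLayout::NoTranspose
    } else if s0 == 1 && s1 == m {
        GemmLayout::Transpose
    } else {
        GemmLayout::NeedsCopy
    }
}

fn main() {
    assert_eq!(classify((3, 4), (4, 1)), GemmLayout::NoTranspose);
    assert_eq!(classify((4, 3), (1, 4)), GemmLayout::Transpose); // a `.t()` view
    assert_eq!(classify((3, 4), (8, 2)), GemmLayout::NeedsCopy); // sliced/strided
}
```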

@EricLBuehler
Member Author

That sounds exciting! I am trying to build a fast LLM engine with Candle, aiming to be faster than llama.cpp, so any optimization is beneficial. I have been using my fork as a playground to test out new ideas and have implemented some fused kernels. Are there any directions I could look into for optimizing performance, specifically on CUDA?

@LaurentMazare
Collaborator

Random thoughts on performance:

  • Be sure to use the latest github version, as CUDA performance for transformer architectures has improved a lot over the last two weeks (as has Metal).
  • You should also enable flash-attn if possible (see the sketch right after this list).
  • Fused kernels are certainly good.
  • The recently introduced inplace-ops also make it possible to avoid allocations and reduce memory bandwidth, though they are a bit tricky to use manually at the moment.
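
A minimal call sketch for the flash-attn point above, assuming the candle-flash-attn crate is built and available (CUDA only, f16/bf16 tensors); the exact shapes and the `flash_attn` signature should be checked against the current crate docs:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn attend(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    // flash-attn expects (batch, seq_len, num_heads, head_dim) tensors in f16/bf16.
    let softmax_scale = 1.0 / (q.dim(3)? as f32).sqrt();
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, /* causal */ true)
}

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let q = Tensor::randn(0f32, 1f32, (1, 128, 32, 64), &dev)?.to_dtype(DType::F16)?;
    let k = q.clone();
    let v = q.clone();
    let out = attend(&q, &k, &v)?;
    println!("{:?}", out.shape());
    Ok(())
}
```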

@EricLBuehler
Member Author

I have been following the recent changes and look forward to trying them out! I am running what is essentially the quantized_llama code as my test case, so I don't think flash attention is an option there, unfortunately (unless we implemented an f32 version).

Regarding inplace-ops, could that be implemented for CUDA by passing the src slice = dst slice (depending on kernel impl, of course)?

@LaurentMazare
Collaborator

Regarding inplace-ops, could that be implemented for CUDA by passing the src slice = dst slice (depending on kernel impl, of course)?

Yes, we don't have a cuda test to show this, but here is the cpu one: custom_op_test.
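
A rough CPU-side sketch of that idea, modeled on the inplace-op test mentioned above (the `InplaceOp1` trait and `inplace_op1` method names are assumed from that test, so double-check them against the current API): the storage handed to `cpu_fwd` is both source and destination, which is the same src == dst idea a CUDA kernel would use.

```rust
use candle_core::{bail, CpuStorage, Device, InplaceOp1, Layout, Result, Tensor};

struct InplaceRelu;

impl InplaceOp1 for InplaceRelu {
    fn name(&self) -> &'static str {
        "inplace-relu"
    }

    fn cpu_fwd(&self, storage: &mut CpuStorage, _layout: &Layout) -> Result<()> {
        // For simplicity this assumes a contiguous f32 tensor; a real op would
        // honor the layout's offset and strides.
        match storage {
            CpuStorage::F32(vs) => {
                for v in vs.iter_mut() {
                    *v = v.max(0.0); // overwrite the source storage in place
                }
                Ok(())
            }
            _ => bail!("inplace-relu only supports f32"),
        }
    }
}

fn main() -> Result<()> {
    let t = Tensor::new(&[-1.0f32, 0.5, -2.0, 3.0], &Device::Cpu)?;
    t.inplace_op1(&InplaceRelu)?;
    println!("{}", t);
    Ok(())
}
```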

Besides this, I've spent quite some time with nsys profile and nsys-ui lately, and it's really great for seeing where the time is actually spent on the cuda side. For the in-house transformer model I've been looking at, ~85% of the time is spent in the matmul kernels, so there is not much to be gained by optimizing the other kernels.

@EricLBuehler
Member Author

I ran a flamegraph with mistral.rs, and unfortunately it shows that much of the time is spent in the call that synchronizes the GPU (almost 2x more than the model itself). It seems strange that llama.cpp can be so much faster, but looking at their code I think they have their own matmul kernel which allows non-contiguous tensors: https://github.com/ggerganov/llama.cpp/blob/c342d070c64a1ffe35d22c1b16b672e684a30297/ggml-cuda.cu#L1111. Do you think that avoiding the .contiguous calls and using a kernel like this could improve performance? I wonder what measures can currently be taken to optimize candle.

@LaurentMazare
Collaborator

Is it a normal cpu flamegraph? If so, it's not representative of where the time is actually spent when you use cuda. I've optimized the synchronization points with the gpu a bit (removing most of them in #1886), and it actually doesn't make much of a difference except for very small transformers (which we do use internally, so that was actually quite a win there); for a 7b it's negligible, at least in my setup. You certainly see them disappearing in nsys-ui, but that doesn't change much, as the time is spent in the matmul kernels.

Here is the kernel occupation on my RTX 2080 for the quantized example using a q4 7b model; you can see that almost all the time is spent in the dequantize_mul_mat_vec_q4_0... kernels that do the matmul, while the copy kernels are almost invisible. On a full run I get 95% in the matmul, 1% in the rms-norm, 0.9% in the optimized kv-concatenation kernel, and the actual contiguous ops are reported as 0.0% in nsys-ui.

[Attached screenshot: 20240330-trace (nsys-ui kernel timeline)]

@EricLBuehler
Member Author

EricLBuehler commented Mar 30, 2024

Is it a normal cpu flamegraph?

Yes

I wonder what llama.cpp is doing that makes it so much faster. Clearly, candle is spending most of its time in matmul, so there is probably not much more to trim. I noticed that they use some sort of graph for model execution; do you know of any other techniques they employ?

@LaurentMazare
Collaborator

I'm not sure what llama.cpp is doing differently; I'm certainly interested if you manage to dig some of this out.
Also, just to mention: I double-checked, and with the latest backend tweaks all the contiguous ops in the inference loop for the quantized example are actually no-ops (they don't trigger any cuda kernel or anything). I'm looking at removing them as there are some spurious checks, but I'm pretty sure there isn't much to be gained on that front.

@EricLBuehler
Member Author

Thanks for the help!
