Use cublasGemmGroupedBatchedEx in cublas 12.5 #6

Open · zhuzilin wants to merge 4 commits into main
Conversation

zhuzilin

Hi!

This PR is an attempt to use the cublasGemmGroupedBatchedEx API introduced in cuBLAS 12.5 to compute the grouped GEMM. The code has passed op_test.py.

There is a potential optimization that is not implemented yet. The original grouped_gemm requires a batch_sizes variable on the CPU. However, for cublasGemmGroupedBatchedEx, the Aarray, Barray and Carray need to be located on the device, which means the CPU array has to be moved back to the GPU. I think we could instead allow batch_sizes to stay on the GPU for this branch and compute d_Aarray (and friends) in torch from tensor.data_ptr() and batch_sizes.
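A minimal sketch of that idea on the C++ extension side, assuming row-major a: [tokens, k], b: [num_experts, k, n], c: [tokens, n] and an int64 batch_sizes kept on the GPU (the helper name and shapes are my own illustration, not code from this PR):

```cpp
#include <torch/extension.h>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from this PR): build the device-resident pointer arrays
// that cublasGemmGroupedBatchedEx expects, directly from a batch_sizes tensor that
// stays on the GPU, so the routing information never round-trips through the host.
std::vector<torch::Tensor> BuildDevicePointerArrays(
    torch::Tensor a,             // [tokens, k]
    torch::Tensor b,             // [num_experts, k, n]
    torch::Tensor c,             // [tokens, n]
    torch::Tensor batch_sizes) { // [num_experts], int64, on the same GPU as a/b/c
  const int64_t k = a.size(1);
  const int64_t n = b.size(2);

  // Starting row of each expert's slice of `a` (and `c`), computed on the GPU.
  torch::Tensor row_offsets = torch::cumsum(batch_sizes, 0) - batch_sizes;

  // Encode raw device addresses as int64 tensors; their storage can later be
  // reinterpreted as `void* const*` when handed to cuBLAS.
  torch::Tensor d_Aarray =
      reinterpret_cast<int64_t>(a.data_ptr()) + row_offsets * (k * a.element_size());
  torch::Tensor d_Carray =
      reinterpret_cast<int64_t>(c.data_ptr()) + row_offsets * (n * c.element_size());
  torch::Tensor d_Barray =
      reinterpret_cast<int64_t>(b.data_ptr()) +
      torch::arange(b.size(0), batch_sizes.options()) * (k * n * b.element_size());
  return {d_Aarray, d_Barray, d_Carray};
}
```

The resulting int64 tensors never leave the GPU, so nothing has to wait for a device-to-host copy of batch_sizes.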

Keeping everything on the GPU would reduce the synchronization across streams during training and potentially make training faster, but it may require more changes to the current codebase. I wonder if you could share your preference on this? Thank you!

Also, it would be great if you could tell me which benchmark I should use to compare this code with the original branch :)

@StudyingShao
Collaborator

Hi @zhuzilin, thanks for your contribution! We have already noticed the new grouped GEMM API introduced in cuBLAS 12.5, but we haven't had time to run any tests yet. We are certainly willing to try something new, and your PR can help us do it faster.

Back to the topic: on paper, cublasGemmGroupedBatchedEx only needs a device-side batch_sizes and doesn't need multiple streams, so it can eliminate both the synchronization between CUDA streams and the synchronization between host and device.

My concern is that, for the grouped GEMM problems in LLMs, the GEMM shapes are usually very large, so we don't need to worry about low GPU utilization. In that case, we have found that shape-tailored GEMM kernels always outperform a grouped GEMM kernel. I don't know whether this holds for the new API; we will give it a try based on your PR first.

One more piece of news: we have decided to deprecate this personal repo in the future and move all of its functionality to TransformerEngine. So this repo may no longer be maintained except for bug fixes.

Thank you again for your contribution! 😊

@StudyingShao
Collaborator

After a brief look at your code, I found that m_array, n_array and k_array, which correspond to batch_sizes in the original implementation, are all on the CPU side. That means the synchronization between host and device is still required.

> However, for cublasGemmGroupedBatchedEx, the Aarray, Barray and Carray need to be located on the device, which means the CPU array has to be moved back to the GPU.

There may be some misunderstanding here: the Aarray, Barray and Carray in the original implementation are already on the device.

@zhuzilin
Author

zhuzilin commented Jul 30, 2024

@StudyingShao

> My concern is that, for the grouped GEMM problems in LLMs, the GEMM shapes are usually very large, so we don't need to worry about low GPU utilization.

It seems that MoE architectures are leaning toward finer granularity: lots of small experts. A good example is DeepSeek-V2, which has 160 experts. In that case, there will be many small GEMMs in each MoE layer.

> the Aarray, Barray and Carray in the original implementation are already on the device.

Oh, the m, n, k arrays, which contain the shape of each GEMM, need to be on the CPU, just like the old batch_sizes, while the arrays of pointers to the actual tensor memory (Aarray, Barray and Carray) need to be on the device, which makes this cuBLAS API kind of tricky to use...
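For reference, a rough sketch of how the arguments line up, based on the cuBLAS 12.5 documentation (the wrapper function and the FP32 data/compute types are illustrative choices, not code from this PR):

```cpp
#include <cublas_v2.h>

// Sketch (based on the cuBLAS 12.5 docs; FP32 data and FP32 accumulation are just
// example choices) of which arguments of cublasGemmGroupedBatchedEx live on the
// host and which must live on the device.
cublasStatus_t RunGroupedGemm(
    cublasHandle_t handle, int group_count,
    // Host-side metadata, one entry per group:
    const cublasOperation_t* transa, const cublasOperation_t* transb,
    const int* m, const int* n, const int* k,
    const float* alpha, const float* beta,
    const int* lda, const int* ldb, const int* ldc,
    const int* group_size,
    // Device-side pointer arrays, one entry per problem:
    const void* const* d_Aarray, const void* const* d_Barray, void* const* d_Carray) {
  return cublasGemmGroupedBatchedEx(
      handle,
      transa, transb,               // host
      m, n, k,                      // host (like the old CPU-side batch_sizes)
      alpha,                        // host
      d_Aarray, CUDA_R_32F, lda,    // pointer array on device, lda on host
      d_Barray, CUDA_R_32F, ldb,    // pointer array on device, ldb on host
      beta,                         // host
      d_Carray, CUDA_R_32F, ldc,    // pointer array on device, ldc on host
      group_count, group_size,      // host
      CUBLAS_COMPUTE_32F);
}
```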
