Use cublasGemmGroupedBatchedEx in cublas 12.5 #6
base: main
Conversation
Hi @zhuzilin, thanks for your contribution! We have already noticed the new grouped gemm API introduced in cublas 12.5, but we haven't had time to run any tests yet. We are certainly willing to try something new, and your PR helps us do that faster.

Back to the topic: my concern is that for the grouped gemm problems in LLMs, the gemm shapes are usually very large, so we don't need to worry about low GPU utilization. In that case, we found that shape-tailored gemm kernels consistently outperform a grouped gemm kernel. I don't know whether this holds for the new API; we will give it a try based on your PR first.

One other piece of news: we have decided to deprecate this personal repo and move all of its functionality to TransformerEngine, so this repo may not be maintained anymore except for bug fixes. Thank you again for your contribution! 😊
After a brief look at your code, I found that both
There may be some misunderstanding: the Aarray, Barray and Carray in the original implementation are on the device.
It seems that MoE architectures are leaning toward finer granularity, i.e. lots of small experts. A good example is DeepSeek-V2, which has 160 experts. In that case, there will be many small gemms in each MoE layer.
Oh, the m, n, k arrays, which contain the shape of each gemm, need to be on the CPU, just like the old `batch_sizes`.
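For reference, here is a minimal sketch of the host/device split this implies when driving `cublasGemmGroupedBatchedEx` with one gemm per expert (so `group_size[i] == 1`). The function name, shapes and FP32 type choices are illustrative rather than code from this PR, and the argument order and supported type/compute-type combinations should be verified against the cuBLAS 12.5 documentation.

```cpp
// Minimal sketch: one gemm per expert, so every group holds exactly one
// problem (group_size[i] == 1).  All per-group metadata (transa/transb,
// m/n/k, alpha/beta, leading dimensions, group_size) lives in HOST memory,
// while Aarray/Barray/Carray are DEVICE arrays of DEVICE pointers.
#include <vector>
#include <cublas_v2.h>

// d_Aarray, d_Barray, d_Carray: device arrays holding one device pointer per
// expert.  m, n, k: per-expert problem sizes, host vectors.
cublasStatus_t run_grouped_gemm_fp32(cublasHandle_t handle, int num_experts,
                                     const std::vector<int>& m,
                                     const std::vector<int>& n,
                                     const std::vector<int>& k,
                                     const void* const* d_Aarray,
                                     const void* const* d_Barray,
                                     void* const* d_Carray) {
  std::vector<cublasOperation_t> transa(num_experts, CUBLAS_OP_N);
  std::vector<cublasOperation_t> transb(num_experts, CUBLAS_OP_N);
  std::vector<float> alpha(num_experts, 1.0f), beta(num_experts, 0.0f);
  std::vector<int> lda(num_experts), ldb(num_experts), ldc(num_experts);
  std::vector<int> group_size(num_experts, 1);
  for (int i = 0; i < num_experts; ++i) {
    lda[i] = m[i];  // column-major, no transpose: lda >= m, ldb >= k, ldc >= m
    ldb[i] = k[i];
    ldc[i] = m[i];
  }
  return cublasGemmGroupedBatchedEx(
      handle, transa.data(), transb.data(), m.data(), n.data(), k.data(),
      alpha.data(),
      d_Aarray, CUDA_R_32F, lda.data(),
      d_Barray, CUDA_R_32F, ldb.data(),
      beta.data(),
      d_Carray, CUDA_R_32F, ldc.data(),
      num_experts, group_size.data(), CUBLAS_COMPUTE_32F);
}
```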
Hi!

This PR is an attempt to use the `cublasGemmGroupedBatchedEx` API introduced in cublas 12.5 to calculate the grouped gemm. The code has passed `op_test.py`.

There is a potential optimization that is not implemented yet. The original `grouped_gemm` requires a `batch_sizes` variable on the CPU. However, for `cublasGemmGroupedBatchedEx`, the `Aarray`, `Barray` and `Carray` need to be located on the device, which means the CPU array has to be moved back to the GPU. I think we could instead allow `batch_sizes` to live on the GPU for this branch and calculate the `d_Aarray` on the torch side from `tensor.data_ptr()` and `batch_sizes` (see the sketch below). Keeping everything on the GPU would reduce the synchronization across streams during training and potentially make training faster, but it may require more changes to the current codebase. I wonder if you could share your preference on this? Thank you!
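To make the proposal concrete, here is a minimal sketch using the torch C++ API; the helper names, shapes and layout assumptions are hypothetical, not code from this branch. It builds the `Aarray`/`Barray`/`Carray` pointer tables from a GPU-resident `batch_sizes` with tensor arithmetic, so no device-to-host copy is needed for them.

```cpp
// Sketch of the proposed optimization (hypothetical helper names, not code
// from this branch): keep batch_sizes on the GPU and build the pointer
// tables with tensor arithmetic, avoiding a device-to-host round trip.
// Assumed layout: a is (tokens, k), b is (num_experts, k, n), c is
// (tokens, n), batch_sizes is an int64 CUDA tensor of length num_experts.
#include <cstdint>
#include <tuple>
#include <torch/torch.h>

// base pointer of `t` plus per-group byte offsets, computed entirely on the
// GPU; the result is an int64 CUDA tensor whose storage can be reinterpreted
// as a device array of device pointers.
static torch::Tensor device_ptr_array(const torch::Tensor& t,
                                      const torch::Tensor& index,
                                      int64_t stride_bytes) {
  const auto base =
      static_cast<int64_t>(reinterpret_cast<std::uintptr_t>(t.data_ptr()));
  return index * stride_bytes + base;
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> build_ptr_arrays(
    const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& c,
    const torch::Tensor& batch_sizes) {
  // Starting row of each expert's slice of `a` and `c`: exclusive prefix sum.
  const auto row_offsets = torch::cumsum(batch_sizes, 0) - batch_sizes;
  // `b` is indexed by expert id rather than by token offset.
  const auto expert_ids = torch::arange(b.size(0), batch_sizes.options());

  auto a_ptrs = device_ptr_array(a, row_offsets, a.stride(0) * a.element_size());
  auto c_ptrs = device_ptr_array(c, row_offsets, c.stride(0) * c.element_size());
  auto b_ptrs = device_ptr_array(b, expert_ids, b.stride(0) * b.element_size());
  // a_ptrs.data_ptr() / b_ptrs.data_ptr() / c_ptrs.data_ptr() can then be
  // reinterpreted as the Aarray / Barray / Carray arguments of
  // cublasGemmGroupedBatchedEx.
  return {a_ptrs, b_ptrs, c_ptrs};
}
```

Note that, as mentioned above, the m/n/k arrays still have to be available on the host for the cuBLAS call itself, so this sketch only removes the host round trip for the pointer tables.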
Also, it would be great if you could tell me which benchmark I should use to compare this code with the original branch :)