polish matmul
li-xl committed Mar 18, 2022
1 parent ad57ec8 commit 1fa8977
Showing 2 changed files with 37 additions and 17 deletions.
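In both files, the visible additions give the pre-CUDA-11 path its own compute-type and algorithm selection (CUDA_R_16F plus CUBLAS_GEMM_DEFAULT_TENSOR_OP when any operand is float16, and the tensor-op algorithm when use_tensorcore is set) and route the multiply through the generic cublasGemmStridedBatchedEx / cublasGemmEx entry points, keeping the type-specific cublas<t>gemm calls as a fallback for older toolkits.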
28 changes: 19 additions & 9 deletions python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
@@ -128,22 +128,32 @@ void CublasBatchedMatmulOp::jit_run() {
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUBLAS_COMPUTE_16F;
}
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cudaDataType_t computeType = CUDA_R_32F;
if (use_tensorcore) {
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUDA_R_16F;
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
#endif
checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(), get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(), get_dtype(c->dtype()), k, k * n,
batch_size, computeType, algo));
#else
checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(), k, k * n,
batch_size));
#endif
}
#endif
#endif // JIT
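For context, here is a minimal standalone sketch (not part of the commit) of the call pattern this hunk adopts: cublasGemmStridedBatchedEx with explicit operand types, compute type, and algorithm. The shapes, the fp32 accumulation, and the CUDA-11-style cublasComputeType_t parameter are assumptions for illustration; the op itself selects CUBLAS_COMPUTE_16F for half inputs, as shown above.

// Standalone sketch: batched half-precision GEMM through
// cublasGemmStridedBatchedEx (assumes CUDA 11+; not Jittor code).
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

#define CHECK_CUBLAS(call) do { \
    cublasStatus_t s_ = (call); \
    if (s_ != CUBLAS_STATUS_SUCCESS) { \
        printf("cuBLAS error %d at line %d\n", (int)s_, __LINE__); \
        return 1; \
    } \
} while (0)

int main() {
    // Batch of row-major products: C[i] (MxN) = A[i] (MxK) * B[i] (KxN).
    const int M = 64, N = 32, K = 128, batch = 4;
    half *dA, *dB, *dC;
    cudaMalloc(&dA, sizeof(half) * M * K * batch);
    cudaMalloc(&dB, sizeof(half) * K * N * batch);
    cudaMalloc(&dC, sizeof(half) * M * N * batch);

    cublasHandle_t handle;
    cublasCreate(&handle);

    // alpha/beta must match the compute type: float for CUBLAS_COMPUTE_32F.
    const float alpha = 1.f, beta = 0.f;
    // Same operand order as the op: B first, dims (N, M, K), so the
    // column-major library computes C^T = B^T * A^T per batch entry.
    CHECK_CUBLAS(cublasGemmStridedBatchedEx(handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        N, M, K, &alpha,
        dB, CUDA_R_16F, N, (long long)K * N,   // "A" slot: B, ld = N
        dA, CUDA_R_16F, K, (long long)M * K,   // "B" slot: A, ld = K
        &beta,
        dC, CUDA_R_16F, N, (long long)M * N,   // C^T, ld = N
        batch,
        CUBLAS_COMPUTE_32F,                    // accumulate in fp32
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));       // allow tensor cores

    cudaDeviceSynchronize();
    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}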
26 changes: 18 additions & 8 deletions python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
@@ -85,22 +85,32 @@ void CublasMatmulOp::jit_run() {
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUBLAS_COMPUTE_16F;
}
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cudaDataType_t computeType = CUDA_R_32F;
if (use_tensorcore) {
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUDA_R_16F;
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
#endif
checkCudaErrors(cublasGemmEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m,
a->ptr<T>(), get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(), get_dtype(c->dtype()), k,
computeType, algo));
#else
checkCudaErrors(cublas@op@@gemm(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(), k));
#endif

}
#endif // JIT
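To make the operand order concrete, here is a small self-contained sketch (not Jittor code; assumes CUDA 11+ for the cublasComputeType_t parameter) of the row-major trick both hunks rely on: cuBLAS is column-major, so a row-major C = A * B is issued as C^T = B^T * A^T, which is why b is passed before a and the dimensions appear reversed. Only the leading dimensions change when an operand is transposed, matching the '@Trans_a' == 'N' ? m : n selections above.

// Standalone sketch: row-major C = A * B via a column-major cuBLAS call.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int M = 2, N = 3, K = 4;   // row-major: A (MxK), B (KxN), C (MxN)
    std::vector<float> A(M * K), B(K * N), C(M * N, 0.f);
    for (int i = 0; i < M * K; i++) A[i] = float(i + 1);
    for (int i = 0; i < K * N; i++) B[i] = float(i % 3);

    float *dA, *dB, *dC;
    cudaMalloc(&dA, A.size() * sizeof(float));
    cudaMalloc(&dB, B.size() * sizeof(float));
    cudaMalloc(&dC, C.size() * sizeof(float));
    cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.f, beta = 0.f;
    // B first and dims (N, M, K): the column-major library then computes
    // C^T (NxM) = B^T (NxK) * A^T (KxM), whose memory layout is row-major C.
    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                 N, M, K, &alpha,
                 dB, CUDA_R_32F, N,    // "A" slot: B, leading dim N
                 dA, CUDA_R_32F, K,    // "B" slot: A, leading dim K
                 &beta,
                 dC, CUDA_R_32F, N,    // C^T, leading dim N
                 CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);

    // Verify against a naive host matmul (exact: small integer values).
    for (int i = 0; i < M; i++)
        for (int j = 0; j < N; j++) {
            float ref = 0.f;
            for (int k = 0; k < K; k++) ref += A[i * K + k] * B[k * N + j];
            if (ref != C[i * N + j]) { printf("mismatch at (%d,%d)\n", i, j); return 1; }
        }
    printf("row-major GEMM via C^T = B^T * A^T: OK\n");
    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}

The host-side check at the end confirms the swapped call produces the row-major product without any explicit transposition.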
