diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
index e8229c8a..6f93e343 100644
--- a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
+++ b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
@@ -128,6 +128,18 @@ void CublasBatchedMatmulOp::jit_run() {
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
         computeType = CUBLAS_COMPUTE_16F;
     }
+    #else
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+    cudaDataType_t computeType = CUDA_R_32F;
+    if (use_tensorcore) {
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    if (a->dtype() == ns_float16
+        || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+        computeType = CUDA_R_16F;
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    #endif
     checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
     k, n, m, &alpha,
@@ -135,15 +147,13 @@ void CublasBatchedMatmulOp::jit_run() {
     a->ptr(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
     c->ptr(),get_dtype(c->dtype()), k, k * n,
     batch_size,computeType,algo));
-    #else
-    checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
-    CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
-    k, n, m, &alpha,
-    b->ptr(), '@Trans_b' == 'N' ? k : m, k * m,
-    a->ptr(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
-    c->ptr(), k, k * n,
-    batch_size));
-    #endif
+    // checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
+    // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+    // k, n, m, &alpha,
+    // b->ptr(), '@Trans_b' == 'N' ? k : m, k * m,
+    // a->ptr(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
+    // c->ptr(), k, k * n,
+    // batch_size));
 }
 #endif
 #endif // JIT
diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
index babcfa1e..a6708225 100644
--- a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
+++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
@@ -85,6 +85,18 @@ void CublasMatmulOp::jit_run() {
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
         computeType = CUBLAS_COMPUTE_16F;
     }
+    #else
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+    cudaDataType_t computeType = CUDA_R_32F;
+    if (use_tensorcore) {
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    if (a->dtype() == ns_float16
+        || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+        computeType = CUDA_R_16F;
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    #endif
     checkCudaErrors(cublasGemmEx(handle_,
     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
     k, n, m, &alpha,
@@ -92,15 +104,13 @@ void CublasMatmulOp::jit_run() {
     a->ptr(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
     c->ptr(),get_dtype(c->dtype()), k,
     computeType, algo));
-    #else
-    checkCudaErrors(cublas@op@@gemm(handle_,
-    CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
-    k, n, m, &alpha,
-    b->ptr(), '@Trans_b' == 'N' ? k : m,
-    a->ptr(), '@Trans_a' == 'N' ? m : n, &beta,
-    c->ptr(), k));
+    // checkCudaErrors(cublas@op@@gemm(handle_,
+    // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+    // k, n, m, &alpha,
+    // b->ptr(), '@Trans_b' == 'N' ? k : m,
+    // a->ptr(), '@Trans_a' == 'N' ? m : n, &beta,
+    // c->ptr(), k));

-    #endif
 }
 #endif // JIT
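
For reviewers, the intent of the new `#else` branches: on toolkits where `cublasComputeType_t` is not available (the other side of the CUDA-version guard), both ops now fall back to the older `cublasGemmEx` / `cublasGemmStridedBatchedEx` signature, which takes a `cudaDataType_t` compute type plus an explicit `cublasGemmAlgo_t`, instead of calling the non-`Ex` `cublas@op@@gemm` / `cublas@op@@gemmStridedBatched` routines (those calls are left commented out). Below is a minimal, self-contained sketch of that pre-CUDA-11 calling pattern; it is illustrative only (plain FP32 data, no transpose, made-up sizes), not the Jittor op code.

```cpp
// Sketch of the pre-CUDA-11 cublasGemmEx calling pattern used by the new
// #else branch. All names and sizes here are assumptions for illustration.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
    const int m = 4, n = 4, k = 4;
    std::vector<float> hA(m * k, 1.0f), hB(k * n, 1.0f), hC(m * n, 0.0f);

    float *dA, *dB, *dC;
    cudaMalloc((void**)&dA, hA.size() * sizeof(float));
    cudaMalloc((void**)&dB, hB.size() * sizeof(float));
    cudaMalloc((void**)&dC, hC.size() * sizeof(float));
    cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    // Same knobs the patch sets: FP32 compute type by default, and a
    // tensor-op algorithm when tensor cores are requested.
    // Note: this is the pre-CUDA-11 signature; on CUDA 11+ the compute-type
    // parameter is a cublasComputeType_t instead.
    cudaDataType_t computeType = CUDA_R_32F;
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;  // CUBLAS_GEMM_DEFAULT_TENSOR_OP for tensor cores

    const float alpha = 1.0f, beta = 0.0f;
    // Column-major C = A * B.
    cublasStatus_t st = cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                     m, n, k, &alpha,
                                     dA, CUDA_R_32F, m,
                                     dB, CUDA_R_32F, k,
                                     &beta,
                                     dC, CUDA_R_32F, m,
                                     computeType, algo);
    printf("cublasGemmEx status: %d\n", (int)st);

    cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
    printf("C[0] = %f (expect %d)\n", hC[0], k);

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}
```

With FP16 inputs the same call would pass `CUDA_R_16F` for the matrix data types and the compute type, and `CUBLAS_GEMM_DEFAULT_TENSOR_OP` for the algorithm, mirroring the half-precision branch added in the patch.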