diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index e01ca333..eb4f8240 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -9,7 +9,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.3.1.33'
+__version__ = '1.3.1.34'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
index 010309f5..6ab519a0 100644
--- a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
+++ b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
@@ -118,22 +118,12 @@ void CublasBatchedMatmulOp::jit_run() {
         k = bs[adim-2];
     }
     // a: [b,n,m], b: [b,m,k], c: [b,n,k]
+    #if CUDART_VERSION >= 11000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
     cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
-    #if CUDART_VERSION >= 11000
     if (use_tensorcore) {
         computeType = CUBLAS_COMPUTE_32F_FAST_16F;
     }
-    #endif
-
-    // checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
-    // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
-    // k, n, m, &alpha,
-    // b->ptr(), '@Trans_b' == 'N' ? k : m, k * m,
-    // a->ptr(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
-    // c->ptr(), k, k * n,
-    // batch_size));
-
     checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
     k, n, m, &alpha,
@@ -141,7 +131,15 @@ void CublasBatchedMatmulOp::jit_run() {
     a->ptr(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
     c->ptr(),get_dtype(c->dtype()), k, k * n,
     batch_size,computeType,algo));
-
+    #else
+    checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
+    CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+    k, n, m, &alpha,
+    b->ptr(), '@Trans_b' == 'N' ? k : m, k * m,
+    a->ptr(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
+    c->ptr(), k, k * n,
+    batch_size));
+    #endif
 }
 #endif
 #endif // JIT
diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
index ff340d62..0ed46bc4 100644
--- a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
+++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
@@ -75,21 +75,12 @@ void CublasMatmulOp::jit_run() {
         k = bs[0];
     }
     // a: [n,m], b: [m,k], c: [n,k]
+    #if CUDART_VERSION >= 11000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
     cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
-    #if CUDART_VERSION >= 11000
     if (use_tensorcore) {
         computeType = CUBLAS_COMPUTE_32F_FAST_16F;
     }
-    #endif
-
-    // checkCudaErrors(cublas@op@@gemm(handle_,
-    // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
-    // k, n, m, &alpha,
-    // b->ptr(), '@Trans_b' == 'N' ? k : m,
-    // a->ptr(), '@Trans_a' == 'N' ? m : n, &beta,
-    // c->ptr(), k));
-
     checkCudaErrors(cublasGemmEx(handle_,
     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
     k, n, m, &alpha,
@@ -97,6 +88,15 @@ void CublasMatmulOp::jit_run() {
     a->ptr(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
     c->ptr(),get_dtype(c->dtype()), k,
     computeType, algo));
+    #else
+    checkCudaErrors(cublas@op@@gemm(handle_,
+    CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+    k, n, m, &alpha,
+    b->ptr(), '@Trans_b' == 'N' ? k : m,
+    a->ptr(), '@Trans_a' == 'N' ? m : n, &beta,
+    c->ptr(), k));
+
+    #endif
 }
 #endif // JIT
diff --git a/python/jittor/src/op_compiler.cc b/python/jittor/src/op_compiler.cc
index 5e9b56bf..7945f0fe 100644
--- a/python/jittor/src/op_compiler.cc
+++ b/python/jittor/src/op_compiler.cc
@@ -294,10 +294,16 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<string,string>
("#if aa>1\n@Tx\n#else\n@Tx@@1\n#endif", "#if aa>1\nfloat\n#else\nfloat1\n#endif")

if __name__ == "__main__":
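Note on the cublas changes above: the *Ex entry points, as used here, take a cublasComputeType_t, which only exists in CUDA 11+, and the tensor-core fast path CUBLAS_COMPUTE_32F_FAST_16F is likewise CUDA 11-only. Hoisting the #if CUDART_VERSION >= 11000 guard above those declarations and reviving the previously commented-out typed cublas<T>gemm calls as the #else branch keeps both ops compiling against CUDA 10. The sketch below shows the same gating in isolation, without Jittor's JIT templating (the @op and @Trans_* markers are expanded by op_compiler before compilation); run_sgemm, the fixed float dtype, and the no-transpose setup are illustrative assumptions, not part of the patch.

// Standalone sketch (assumptions: float GEMM, no transposition, row-major
// buffers; run_sgemm is a hypothetical helper, not Jittor API).
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Row-major c[n,k] = a[n,m] * b[m,k], phrased for cuBLAS's column-major
// interface by computing c^T = b^T * a^T -- the same trick the ops use
// when they pass b before a and swap the roles of n and k.
cublasStatus_t run_sgemm(cublasHandle_t handle,
                         const float* a, const float* b, float* c,
                         int n, int m, int k, bool use_tensorcore) {
    const float alpha = 1.0f, beta = 0.0f;
#if CUDART_VERSION >= 11000
    // CUDA 11+: cublasGemmEx takes an explicit compute type, so enabling
    // tensor cores is just a different cublasComputeType_t value.
    cublasComputeType_t computeType = use_tensorcore
        ? CUBLAS_COMPUTE_32F_FAST_16F : CUBLAS_COMPUTE_32F;
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        k, n, m, &alpha,
                        b, CUDA_R_32F, k,
                        a, CUDA_R_32F, m, &beta,
                        c, CUDA_R_32F, k,
                        computeType, CUBLAS_GEMM_DEFAULT);
#else
    // CUDA 10: the typed legacy entry point has no compute-type argument,
    // so there is no tensor-core switch on this path.
    (void)use_tensorcore;
    return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                       k, n, m, &alpha,
                       b, k,
                       a, m, &beta,
                       c, k);
#endif
}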