Skip to content

Commit

Permalink
Polish tensor core support of cuBLAS in CUDA 10
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Jan 10, 2022
1 parent f36693c commit 5b4576c
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 25 deletions.
2 changes: 1 addition & 1 deletion python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.1.33'
__version__ = '1.3.1.34'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
Expand Down
22 changes: 10 additions & 12 deletions python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,30 +118,28 @@ void CublasBatchedMatmulOp::jit_run() {
k = bs[adim-2];
}
// a: [b,n,m], b: [b,m,k], c: [b,n,k]
#if CUDART_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
#if CUDART_VERSION >= 11000
if (use_tensorcore) {
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
}
#endif

// checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
// k, n, m, &alpha,
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
// c->ptr<T>(), k, k * n,
// batch_size));

checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(),get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k, k * n,
batch_size,computeType,algo));

#else
checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(), k, k * n,
batch_size));
#endif
}
#endif
#endif // JIT
Expand Down
20 changes: 10 additions & 10 deletions python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,28 @@ void CublasMatmulOp::jit_run() {
k = bs[0];
}
// a: [n,m], b: [m,k], c: [n,k]
#if CUDART_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
#if CUDART_VERSION >= 11000
if (use_tensorcore) {
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
}
#endif

// checkCudaErrors(cublas@op@@gemm(handle_,
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
// k, n, m, &alpha,
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
// c->ptr<T>(), k));

checkCudaErrors(cublasGemmEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(),get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m,
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k,
computeType, algo));
#else
checkCudaErrors(cublas@op@@gemm(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(), k));

#endif

}
#endif // JIT
Expand Down
10 changes: 8 additions & 2 deletions python/jittor/src/op_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,16 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
// #define xxx
// i jk l
auto j=i+1;
while (j<src.size() && src[j] != ' ') j++;
while (j<src.size() && (src[j] != ' ' && src[j] != '\n')) j++;
auto mstr = src.substr(i,j-i);
if (mstr == "#if" || mstr == "#else" || mstr == "#endif") {
new_src += mstr;
i = j-1;
continue;
}
ASSERT(j<src.size());
auto k=j+1;
while (k<src.size() && src[k] == ' ') k++;
while (k<src.size() && src[k] == ' ' && src[k] != '\n') k++;
ASSERT(k<src.size());
auto l=k+1;
while (l<src.size() && (src[l] != '\n')) l++;
Expand Down
5 changes: 5 additions & 0 deletions python/jittor/test/test_op_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def test_strcmp(self):
assert "ncclInt" in jit_precompile({"Tx":"int32"}, code)
assert "ncclInt64" in jit_precompile({"Tx":"int64"}, code)

def test_mif(self):
    # Preprocessor conditionals (#if / #else / #endif) must pass through the
    # JIT precompiler untouched, while @-macros inside each branch still expand
    # (@Tx -> float, and @@ concatenation -> float1).
    defs = {"Tx": "float"}
    src = "#if aa>1\n@Tx\n#else\n@Tx@@1\n#endif"
    expected = "#if aa>1\nfloat\n#else\nfloat1\n#endif"
    self.assertEqual(jit_precompile(defs, src), expected)


if __name__ == "__main__":
Expand Down

0 comments on commit 5b4576c

Please sign in to comment.