diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc index 47874ef3..e8229c8a 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc @@ -89,7 +89,7 @@ void CublasBatchedMatmulOp::jit_prepare(JK& jk) { jk << _CS("[T:") << a->dtype(); jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); - jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D'); + jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); jk << ']'; } diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc index 95de20f7..babcfa1e 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc @@ -50,7 +50,7 @@ void CublasMatmulOp::jit_prepare(JK& jk) { jk << _CS("[T:") << a->dtype(); jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); - jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D'); + jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); jk << ']'; } diff --git a/python/jittor/other/code_softmax.py b/python/jittor/other/code_softmax.py index 8534f0cb..837bd648 100644 --- a/python/jittor/other/code_softmax.py +++ b/python/jittor/other/code_softmax.py @@ -100,7 +100,7 @@ def softmax_v1(a, log=False): {for_loop} #pragma unroll for (int j=0; j<{ILP}; j++) - v1 += {"vy[i][j];" if log else "vx[i][j]*vy[i][j];"} + v1 += {"float(vy[i][j]);" if log else "float(vx[i][j]*vy[i][j]);"} typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -114,8 +114,8 @@ def softmax_v1(a, log=False): #pragma unroll for (int j=0; j<{ILP}; j++) vx[i][j] = { - "vy[i][j] - expf(vx[i][j]) * reduce_var;" if log - else "vx[i][j] * (vy[i][j] - reduce_var);" + "vy[i][j] - in0_type(expf(vx[i][j]) * reduce_var);" if log + else "vx[i][j] * (vy[i][j] - in0_type(reduce_var));" } {for_loop}