mirror of https://github.com/Jittor/Jittor
Polish cuBLAS Tensor Core handling for CUDA 10
This commit is contained in:
parent
f36693c797
commit
5b4576c4dd
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.33'
|
||||
__version__ = '1.3.1.34'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -118,22 +118,12 @@ void CublasBatchedMatmulOp::jit_run() {
|
|||
k = bs[adim-2];
|
||||
}
|
||||
// a: [b,n,m], b: [b,m,k], c: [b,n,k]
|
||||
#if CUDART_VERSION >= 11000
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
|
||||
#if CUDART_VERSION >= 11000
|
||||
if (use_tensorcore) {
|
||||
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
}
|
||||
#endif
|
||||
|
||||
// checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
|
||||
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
// k, n, m, &alpha,
|
||||
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
|
||||
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
|
||||
// c->ptr<T>(), k, k * n,
|
||||
// batch_size));
|
||||
|
||||
checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
|
@ -141,7 +131,15 @@ void CublasBatchedMatmulOp::jit_run() {
|
|||
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
|
||||
c->ptr<T>(),get_dtype(c->dtype()), k, k * n,
|
||||
batch_size,computeType,algo));
|
||||
|
||||
#else
|
||||
checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
|
||||
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
|
||||
c->ptr<T>(), k, k * n,
|
||||
batch_size));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
|
|
@ -75,21 +75,12 @@ void CublasMatmulOp::jit_run() {
|
|||
k = bs[0];
|
||||
}
|
||||
// a: [n,m], b: [m,k], c: [n,k]
|
||||
#if CUDART_VERSION >= 11000
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
|
||||
#if CUDART_VERSION >= 11000
|
||||
if (use_tensorcore) {
|
||||
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
}
|
||||
#endif
|
||||
|
||||
// checkCudaErrors(cublas@op@@gemm(handle_,
|
||||
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
// k, n, m, &alpha,
|
||||
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
|
||||
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
|
||||
// c->ptr<T>(), k));
|
||||
|
||||
checkCudaErrors(cublasGemmEx(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
|
@ -97,6 +88,15 @@ void CublasMatmulOp::jit_run() {
|
|||
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
|
||||
c->ptr<T>(),get_dtype(c->dtype()), k,
|
||||
computeType, algo));
|
||||
#else
|
||||
checkCudaErrors(cublas@op@@gemm(handle_,
|
||||
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
k, n, m, &alpha,
|
||||
b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
|
||||
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
|
||||
c->ptr<T>(), k));
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
#endif // JIT
|
||||
|
|
|
@ -294,10 +294,16 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
// #define xxx
|
||||
// i      jk  l
|
||||
auto j=i+1;
|
||||
while (j<src.size() && src[j] != ' ') j++;
|
||||
while (j<src.size() && (src[j] != ' ' && src[j] != '\n')) j++;
|
||||
auto mstr = src.substr(i,j-i);
|
||||
if (mstr == "#if" || mstr == "#else" || mstr == "#endif") {
|
||||
new_src += mstr;
|
||||
i = j-1;
|
||||
continue;
|
||||
}
|
||||
ASSERT(j<src.size());
|
||||
auto k=j+1;
|
||||
while (k<src.size() && src[k] == ' ') k++;
|
||||
while (k<src.size() && src[k] == ' ' && src[k] != '\n') k++;
|
||||
ASSERT(k<src.size());
|
||||
auto l=k+1;
|
||||
while (l<src.size() && (src[l] != '\n')) l++;
|
||||
|
|
|
@ -161,6 +161,11 @@ ncclBcast(..., @T_NCCL, ...)
|
|||
assert "ncclInt" in jit_precompile({"Tx":"int32"}, code)
|
||||
assert "ncclInt64" in jit_precompile({"Tx":"int64"}, code)
|
||||
|
||||
def test_mif(self):
|
||||
vars = {"Tx":"float"}
|
||||
check = lambda expr, result: \
|
||||
self.assertEqual(jit_precompile(vars, expr), result)
|
||||
check("#if aa>1\n@Tx\n#else\n@Tx@@1\n#endif", "#if aa>1\nfloat\n#else\nfloat1\n#endif")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue