polish tensorcore of cublas in cuda 10

This commit is contained in:
Dun Liang 2022-01-10 14:35:30 +08:00
parent f36693c797
commit 5b4576c4dd
5 changed files with 34 additions and 25 deletions

View File

@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.3.1.33' __version__ = '1.3.1.34'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@@ -118,22 +118,12 @@ void CublasBatchedMatmulOp::jit_run() {
k = bs[adim-2]; k = bs[adim-2];
} }
// a: [b,n,m], b: [b,m,k], c: [b,n,k] // a: [b,n,m], b: [b,m,k], c: [b,n,k]
#if CUDART_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
#if CUDART_VERSION >= 11000
if (use_tensorcore) { if (use_tensorcore) {
computeType = CUBLAS_COMPUTE_32F_FAST_16F; computeType = CUBLAS_COMPUTE_32F_FAST_16F;
} }
#endif
// checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
// k, n, m, &alpha,
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
// c->ptr<T>(), k, k * n,
// batch_size));
checkCudaErrors(cublasGemmStridedBatchedEx(handle_, checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha, k, n, m, &alpha,
@@ -141,7 +131,15 @@ void CublasBatchedMatmulOp::jit_run() {
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta, a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k, k * n, c->ptr<T>(),get_dtype(c->dtype()), k, k * n,
batch_size,computeType,algo)); batch_size,computeType,algo));
#else
checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(), k, k * n,
batch_size));
#endif
} }
#endif #endif
#endif // JIT #endif // JIT

View File

@@ -75,21 +75,12 @@ void CublasMatmulOp::jit_run() {
k = bs[0]; k = bs[0];
} }
// a: [n,m], b: [m,k], c: [n,k] // a: [n,m], b: [m,k], c: [n,k]
#if CUDART_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
#if CUDART_VERSION >= 11000
if (use_tensorcore) { if (use_tensorcore) {
computeType = CUBLAS_COMPUTE_32F_FAST_16F; computeType = CUBLAS_COMPUTE_32F_FAST_16F;
} }
#endif
// checkCudaErrors(cublas@op@@gemm(handle_,
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
// k, n, m, &alpha,
// b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
// a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
// c->ptr<T>(), k));
checkCudaErrors(cublasGemmEx(handle_, checkCudaErrors(cublasGemmEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha, k, n, m, &alpha,
@@ -97,6 +88,15 @@ void CublasMatmulOp::jit_run() {
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta, a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k, c->ptr<T>(),get_dtype(c->dtype()), k,
computeType, algo)); computeType, algo));
#else
checkCudaErrors(cublas@op@@gemm(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(), k));
#endif
} }
#endif // JIT #endif // JIT

View File

@@ -294,10 +294,16 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
// #define xxx // #define xxx
// i jk l // i jk l
auto j=i+1; auto j=i+1;
while (j<src.size() && src[j] != ' ') j++; while (j<src.size() && (src[j] != ' ' && src[j] != '\n')) j++;
auto mstr = src.substr(i,j-i);
if (mstr == "#if" || mstr == "#else" || mstr == "#endif") {
new_src += mstr;
i = j-1;
continue;
}
ASSERT(j<src.size()); ASSERT(j<src.size());
auto k=j+1; auto k=j+1;
while (k<src.size() && src[k] == ' ') k++; while (k<src.size() && src[k] == ' ' && src[k] != '\n') k++;
ASSERT(k<src.size()); ASSERT(k<src.size());
auto l=k+1; auto l=k+1;
while (l<src.size() && (src[l] != '\n')) l++; while (l<src.size() && (src[l] != '\n')) l++;

View File

@@ -161,6 +161,11 @@ ncclBcast(..., @T_NCCL, ...)
assert "ncclInt" in jit_precompile({"Tx":"int32"}, code) assert "ncclInt" in jit_precompile({"Tx":"int32"}, code)
assert "ncclInt64" in jit_precompile({"Tx":"int64"}, code) assert "ncclInt64" in jit_precompile({"Tx":"int64"}, code)
def test_mif(self):
vars = {"Tx":"float"}
check = lambda expr, result: \
self.assertEqual(jit_precompile(vars, expr), result)
check("#if aa>1\n@Tx\n#else\n@Tx@@1\n#endif", "#if aa>1\nfloat\n#else\nfloat1\n#endif")
if __name__ == "__main__": if __name__ == "__main__":