add use_tensorcore flags

Dun Liang 2021-12-30 20:59:48 +08:00
parent 7b9113f828
commit f20fa98e44
4 changed files with 56 additions and 3 deletions


@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.3.1.31'
+__version__ = '1.3.1.32'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int


@@ -118,8 +118,13 @@ void CublasBatchedMatmulOp::jit_run() {
         k = bs[adim-2];
     }
     // a: [b,n,m], b: [b,m,k], c: [b,n,k]
-    cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+    cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
+    #if CUDART_VERSION >= 11000
+    if (use_tensorcore) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_16F;
+    }
+    #endif
     // checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
     //     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
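
What the new flag changes: the old code requested tensor cores through the CUBLAS_GEMM_DEFAULT_TENSOR_OP algorithm selector, which cuBLAS 11 deprecates; the new code keeps the default algorithm and requests them through the compute type instead. CUBLAS_COMPUTE_32F_FAST_16F keeps FP32 inputs and outputs but lets cuBLAS downconvert to FP16 tensor-core math internally. Below is a minimal sketch of how such an algo/computeType pair reaches a strided-batched GEMM call; it is illustrative, not Jittor's actual call site (that is the commented gemmStridedBatched template above), and the wrapper name and arguments are assumptions.

#include <cublas_v2.h>

// Illustrative wrapper, not Jittor code: FP32 batched matmul
// c[i] = a[i] * b[i] for row-major a:[b,n,m], b:[b,m,k], c:[b,n,k],
// with the compute type selected the same way as in the diff above.
cublasStatus_t batched_gemm_f32(cublasHandle_t handle, bool use_tensorcore,
                                int batch, int n, int m, int k,
                                const float* a, const float* b, float* c) {
    float alpha = 1.0f, beta = 0.0f;
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#if CUDART_VERSION >= 11000
    // CUDA 11+: I/O stays FP32; FAST_16F permits FP16 tensor cores inside.
    cublasComputeType_t computeType =
        use_tensorcore ? CUBLAS_COMPUTE_32F_FAST_16F : CUBLAS_COMPUTE_32F;
#else
    // Older toolkits pass the compute type as a cudaDataType_t instead.
    (void)use_tensorcore;
    cudaDataType_t computeType = CUDA_R_32F;
#endif
    // cuBLAS is column-major, so compute c^T = b^T * a^T per batch entry,
    // the same trick as the CUBLAS_OP_@Trans_b/@Trans_a template above.
    return cublasGemmStridedBatchedEx(handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        k, n, m,
        &alpha,
        b, CUDA_R_32F, k, (long long)m * k,
        a, CUDA_R_32F, m, (long long)n * m,
        &beta,
        c, CUDA_R_32F, k, (long long)n * k,
        batch, computeType, algo);
}

The trade-off is precision: FAST_16F rounds intermediate products through FP16, which is what buys the speedups recorded in the benchmark below.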


@@ -75,8 +75,13 @@ void CublasMatmulOp::jit_run() {
         k = bs[0];
     }
     // a: [n,m], b: [m,k], c: [n,k]
-    cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+    cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
+    #if CUDART_VERSION >= 11000
+    if (use_tensorcore) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_16F;
+    }
+    #endif
     // checkCudaErrors(cublas@op@@gemm(handle_,
     //     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
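
The single-matrix op gets the identical selection. The CUDART_VERSION guard exists because CUBLAS_COMPUTE_32F_FAST_16F is a cuBLAS 11 (CUDA 11.0) addition, so on older toolkits the flag silently falls back to plain FP32 math. One way to gauge what FAST_16F costs in accuracy is to run the same FP32 GEMM under both compute types and diff the outputs. This is a hypothetical CUDA 11+ experiment, not part of the commit, with error checking omitted for brevity.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
    const int n = 512;
    std::vector<float> ha(n * n), full(n * n), fast(n * n);
    for (auto& x : ha) x = (float)rand() / RAND_MAX - 0.5f;

    float *a, *c;
    cudaMalloc(&a, n * n * sizeof(float));
    cudaMalloc(&c, n * n * sizeof(float));
    cudaMemcpy(a, ha.data(), n * n * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    float alpha = 1.0f, beta = 0.0f;

    // Run c = a * a once per compute type and copy the result back.
    auto run = [&](cublasComputeType_t ct, float* out) {
        cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
                     &alpha, a, CUDA_R_32F, n, a, CUDA_R_32F, n,
                     &beta, c, CUDA_R_32F, n, ct, CUBLAS_GEMM_DEFAULT);
        cudaMemcpy(out, c, n * n * sizeof(float), cudaMemcpyDeviceToHost);
    };
    run(CUBLAS_COMPUTE_32F, full.data());
    run(CUBLAS_COMPUTE_32F_FAST_16F, fast.data());

    float max_err = 0.0f;
    for (int i = 0; i < n * n; i++)
        max_err = std::max(max_err, std::fabs(full[i] - fast[i]));
    printf("max |32F - 32F_FAST_16F| = %g\n", max_err);

    cublasDestroy(handle);
    cudaFree(a);
    cudaFree(c);
}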


@@ -353,5 +353,48 @@ class TestMatmul(unittest.TestCase):
         b = linear(a)
         assert b.shape == (20,)
+    # def test_tensorcore(self):
+    #     import time
+    #     jt.flags.use_cuda = 1
+    #     # jt.flags.use_tensorcore = 1
+    #     a = jt.rand(4096, 4096)
+    #     b = jt.rand(4096, 4096)
+    #     for i in range(100):
+    #         c = jt.matmul(a, b)
+    #         c.sync()
+    #     jt.sync_all(True)
+    #     start = time.time()
+    #     for i in range(1000):
+    #         c = jt.matmul(a, b)
+    #         c.sync()
+    #     jt.sync_all(True)
+    #     end = time.time() - start
+    #     gflops = 4096**3*2 * 1000 / end / 10**9
+    #     print(end, gflops)
+    #     # 14T vs 37T
+    # def test_conv(self):
+    #     import time
+    #     jt.flags.use_cuda = 1
+    #     # jt.flags.use_tensorcore = 1
+    #     a = jt.rand(160, 1024, 16, 16)
+    #     b = jt.rand(1024, 1024, 1, 1)
+    #     for i in range(100):
+    #         c = jt.nn.conv2d(a, b)
+    #         c.sync()
+    #     jt.sync_all(True)
+    #     start = time.time()
+    #     for i in range(1000):
+    #         c = jt.nn.conv2d(a, b)
+    #         c.sync()
+    #     jt.sync_all(True)
+    #     end = time.time() - start
+    #     gflops = a.numel() * b.numel() * 2 / 1024 * 1000 / end / 10**9
+    #     print(end, gflops)
+    #     # 12T vs 30T
 if __name__ == "__main__":
     unittest.main()
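
For reference, the GFLOPS formulas in the commented benchmarks above follow from standard FLOP counts: an n x n matmul does 2*n^3 FLOPs (2*4096^3, about 1.37e11 per call), and a 1x1 convolution does 2*N*H*W*Cin*Cout, which is what a.numel() * b.numel() * 2 / 1024 reduces to once the doubly counted Cin = 1024 is divided out. The "14T vs 37T" and "12T vs 30T" comments record TFLOPS without vs with use_tensorcore. A small sketch of that arithmetic; the elapsed time is a placeholder chosen to reproduce the 37T matmul figure, not a measurement.

#include <cstdio>

int main() {
    // FLOPs per call, using the same formulas as the test above.
    double matmul_flops = 2.0 * 4096 * 4096 * 4096;         // 2*n^3
    double conv_flops = 2.0 * 160 * 16 * 16 * 1024 * 1024;  // 2*N*H*W*Cin*Cout
    double elapsed = 3.7;  // hypothetical seconds for the 1000-iteration loop
    printf("matmul: %.1f TFLOPS\n", matmul_flops * 1000 / elapsed / 1e12);
    printf("conv:   %.1f TFLOPS\n", conv_flops * 1000 / elapsed / 1e12);
}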