mirror of https://github.com/Jittor/Jittor
add use_tensorcore flags
This commit is contained in:
parent
7b9113f828
commit
f20fa98e44
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.31'
|
||||
__version__ = '1.3.1.32'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -118,8 +118,13 @@ void CublasBatchedMatmulOp::jit_run() {
|
|||
k = bs[adim-2];
|
||||
}
|
||||
// a: [b,n,m], b: [b,m,k], c: [b,n,k]
|
||||
cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
|
||||
#if CUDART_VERSION >= 11000
|
||||
if (use_tensorcore) {
|
||||
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
}
|
||||
#endif
|
||||
|
||||
// checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
|
||||
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
|
|
|
@ -75,8 +75,13 @@ void CublasMatmulOp::jit_run() {
|
|||
k = bs[0];
|
||||
}
|
||||
// a: [n,m], b: [m,k], c: [n,k]
|
||||
cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
|
||||
#if CUDART_VERSION >= 11000
|
||||
if (use_tensorcore) {
|
||||
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
}
|
||||
#endif
|
||||
|
||||
// checkCudaErrors(cublas@op@@gemm(handle_,
|
||||
// CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
|
||||
|
|
|
@ -353,5 +353,48 @@ class TestMatmul(unittest.TestCase):
|
|||
b = linear(a)
|
||||
assert b.shape == (20,)
|
||||
|
||||
# def test_tensorcore(self):
|
||||
# import time
|
||||
# jt.flags.use_cuda = 1
|
||||
# # jt.flags.use_tensorcore = 1
|
||||
# a = jt.rand(4096, 4096)
|
||||
# b = jt.rand(4096, 4096)
|
||||
# for i in range(100):
|
||||
# c = jt.matmul(a, b)
|
||||
# c.sync()
|
||||
# jt.sync_all(True)
|
||||
|
||||
# start = time.time()
|
||||
# for i in range(1000):
|
||||
# c = jt.matmul(a, b)
|
||||
# c.sync()
|
||||
# jt.sync_all(True)
|
||||
# end = time.time() - start
|
||||
# gflops = 4096**3*2 * 1000 / end / 10**9
|
||||
# print(end, gflops)
|
||||
# # 14T vs 37T
|
||||
|
||||
# def test_conv(self):
|
||||
# import time
|
||||
# jt.flags.use_cuda = 1
|
||||
# # jt.flags.use_tensorcore = 1
|
||||
# a = jt.rand(160, 1024, 16, 16)
|
||||
# b = jt.rand(1024, 1024, 1, 1)
|
||||
# for i in range(100):
|
||||
# c = jt.nn.conv2d(a, b)
|
||||
# c.sync()
|
||||
# jt.sync_all(True)
|
||||
|
||||
# start = time.time()
|
||||
# for i in range(1000):
|
||||
# c = jt.nn.conv2d(a, b)
|
||||
# c.sync()
|
||||
# jt.sync_all(True)
|
||||
# end = time.time() - start
|
||||
# gflops = a.numel() * b.numel() * 2 / 1024 * 1000 / end / 10**9
|
||||
# print(end, gflops)
|
||||
# # 12T vs 30T
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue