Merge branch 'master' of github.com:Jittor/jittor

Dun Liang 2022-03-21 17:49:24 +08:00
commit 99d6d6b6e1
3 changed files with 42 additions and 22 deletions


@@ -89,7 +89,7 @@ void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
     jk << _CS("[T:") << a->dtype();
     jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
     jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
-    jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
+    jk << _CS("][op:") << (a->dtype().dsize() == 2 ? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
     jk << ']';
 }
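
The new ternary chain picks the cuBLAS kernel letter for the JIT key from the element size, so float16 inputs now compile against the Hgemm variants. A standalone sketch of the same mapping (hypothetical helper name `op_letter`, not Jittor code):

#include <cstddef>

// Hypothetical helper mirroring the inline ternary above:
// 2-byte elements -> 'H' (half), 4-byte -> 'S' (single), otherwise 'D' (double).
char op_letter(size_t dsize) {
    if (dsize == 2) return 'H';
    if (dsize == 4) return 'S';
    return 'D';
}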
@@ -128,6 +128,18 @@ void CublasBatchedMatmulOp::jit_run() {
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
         computeType = CUBLAS_COMPUTE_16F;
     }
+#else
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+    cudaDataType_t computeType = CUDA_R_32F;
+    if (use_tensorcore) {
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    if (a->dtype() == ns_float16
+        || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+        computeType = CUDA_R_16F;
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+#endif
     checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
         CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
         k, n, m, &alpha,
@@ -135,15 +147,13 @@ void CublasBatchedMatmulOp::jit_run() {
         a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
         c->ptr<T>(),get_dtype(c->dtype()), k, k * n,
         batch_size,computeType,algo));
-#else
-    checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
-        CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
-        k, n, m, &alpha,
-        b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
-        a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
-        c->ptr<T>(), k, k * n,
-        batch_size));
-#endif
+    // checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
+    //     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+    //     k, n, m, &alpha,
+    //     b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
+    //     a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
+    //     c->ptr<T>(), k, k * n,
+    //     batch_size));
 }
 #endif
 #endif // JIT
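
The net effect: cublasGemmStridedBatchedEx is now called unconditionally, and on toolkits older than CUDA 11.0 the new #else branch spells the compute type as a cudaDataType_t and requests tensor cores via the algo hint. A minimal self-contained sketch of that pre-CUDA-11 half-precision call (illustrative shapes and allocation; error handling elided; not Jittor's op code):

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cstdio>

int main() {
    const int m = 64, n = 64, k = 64, batch = 8;
    cublasHandle_t handle;
    cublasCreate(&handle);

    __half *A, *B, *C;
    cudaMalloc(&A, sizeof(__half) * m * k * batch);
    cudaMalloc(&B, sizeof(__half) * k * n * batch);
    cudaMalloc(&C, sizeof(__half) * m * n * batch);

    // Mirrors the added #else branch: half operands -> CUDA_R_16F compute
    // type plus the tensor-core algo hint (pre-cublasComputeType_t spelling).
    cudaDataType_t computeType = CUDA_R_16F;
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;

    // With a 16F compute type, alpha/beta must themselves be __half.
    __half alpha = __float2half(1.f), beta = __float2half(0.f);
    cublasStatus_t st = cublasGemmStridedBatchedEx(handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        m, n, k, &alpha,
        A, CUDA_R_16F, m, (long long)m * k,
        B, CUDA_R_16F, k, (long long)k * n, &beta,
        C, CUDA_R_16F, m, (long long)m * n,
        batch, computeType, algo);
    printf("gemm status: %d\n", (int)st);

    cudaFree(A); cudaFree(B); cudaFree(C);
    cublasDestroy(handle);
    return 0;
}

Build with nvcc and link -lcublas; the same call compiles on CUDA 11+ through the backward-compatible overloads, where CUBLAS_COMPUTE_16F is the native spelling (the #if branch kept as context above).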


@@ -50,7 +50,7 @@ void CublasMatmulOp::jit_prepare(JK& jk) {
     jk << _CS("[T:") << a->dtype();
     jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
     jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
-    jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
+    jk << _CS("][op:") << (a->dtype().dsize() == 2 ? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
     jk << ']';
 }
@@ -85,6 +85,18 @@ void CublasMatmulOp::jit_run() {
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
         computeType = CUBLAS_COMPUTE_16F;
     }
+#else
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+    cudaDataType_t computeType = CUDA_R_32F;
+    if (use_tensorcore) {
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    if (a->dtype() == ns_float16
+        || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+        computeType = CUDA_R_16F;
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+#endif
     checkCudaErrors(cublasGemmEx(handle_,
         CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
         k, n, m, &alpha,
@@ -92,15 +104,13 @@ void CublasMatmulOp::jit_run() {
         a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
         c->ptr<T>(),get_dtype(c->dtype()), k,
         computeType, algo));
-#else
-    checkCudaErrors(cublas@op@@gemm(handle_,
-        CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
-        k, n, m, &alpha,
-        b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
-        a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
-        c->ptr<T>(), k));
-#endif
+    // checkCudaErrors(cublas@op@@gemm(handle_,
+    //     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+    //     k, n, m, &alpha,
+    //     b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
+    //     a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
+    //     c->ptr<T>(), k));
 }
 #endif // JIT
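
The single-GEMM op gets the same treatment. Note that when the operands are float32, the added branch leaves computeType at CUDA_R_32F and only flips the algo hint when use_tensorcore is set. A hedged sketch of that variant, using the pre-CUDA-11 cublasGemmEx signature the #else branch targets (hypothetical wrapper, not the Jittor op; column-major, no transposes):

#include <cublas_v2.h>

// FP32 data, FP32 compute type, optional tensor-core hint; mirrors the
// first `if (use_tensorcore)` in the added branch.
cublasStatus_t gemm_f32(cublasHandle_t h, bool use_tensorcore,
                        int m, int n, int k,
                        const float* A, const float* B, float* C) {
    const float alpha = 1.f, beta = 0.f;
    cublasGemmAlgo_t algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP
                                           : CUBLAS_GEMM_DEFAULT;
    return cublasGemmEx(h, CUBLAS_OP_N, CUBLAS_OP_N,
                        m, n, k, &alpha,
                        A, CUDA_R_32F, m,   // A is m x k, lda = m
                        B, CUDA_R_32F, k,   // B is k x n, ldb = k
                        &beta,
                        C, CUDA_R_32F, m,   // C is m x n, ldc = m
                        CUDA_R_32F, algo);
}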


@@ -100,7 +100,7 @@ __global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{
     {for_loop}
         #pragma unroll
         for (int j=0; j<{ILP}; j++)
-            v1 += {"vy[i][j];" if log else "vx[i][j]*vy[i][j];"}
+            v1 += {"float(vy[i][j]);" if log else "float(vx[i][j]*vy[i][j]);"}
     typedef cub::BlockReduce<float, {tnum}> BlockReduce;
     __shared__ typename BlockReduce::TempStorage temp_storage;
@@ -114,8 +114,8 @@ __global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{
         #pragma unroll
         for (int j=0; j<{ILP}; j++)
             vx[i][j] = {
-                "vy[i][j] - expf(vx[i][j]) * reduce_var;" if log
-                else "vx[i][j] * (vy[i][j] - reduce_var);"
+                "vy[i][j] - in0_type(expf(vx[i][j]) * reduce_var);" if log
+                else "vx[i][j] * (vy[i][j] - in0_type(reduce_var));"
             }
     {for_loop}
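
Both softmax-gradient changes are mixed-precision fixes: the per-thread partial sum v1 now accumulates in float even when the generated in0_type is half (half saturates near 65504 and loses precision when summing many terms), and the write-back casts the float reduction result to in0_type explicitly instead of relying on implicit conversion. A reduced standalone CUDA illustration of the pattern (one thread per row; not the generated Jittor kernel, which reduces with cub::BlockReduce):

#include <cuda_fp16.h>

// Accumulate half-precision products in float; cast back to half only on
// the final store, as the patched kernel template does.
__global__ void row_dot(const __half* x, const __half* y, __half* out,
                        int rows, int cols) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= rows) return;
    float acc = 0.f;  // float accumulator avoids __half overflow/rounding
    for (int j = 0; j < cols; j++)
        acc += __half2float(x[row * cols + j]) * __half2float(y[row * cols + j]);
    out[row] = __float2half(acc);  // narrow to half only at the end
}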