mirror of https://github.com/Jittor/Jittor
polish fp16
This commit is contained in:
parent
5efb222dd3
commit
ad57ec890f
|
@ -89,7 +89,7 @@ void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
|
|||
jk << _CS("[T:") << a->dtype();
|
||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
||||
jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
|
||||
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ void CublasMatmulOp::jit_prepare(JK& jk) {
|
|||
jk << _CS("[T:") << a->dtype();
|
||||
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
|
||||
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
|
||||
jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
|
||||
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ __global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{
|
|||
{for_loop}
|
||||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++)
|
||||
v1 += {"vy[i][j];" if log else "vx[i][j]*vy[i][j];"}
|
||||
v1 += {"float(vy[i][j]);" if log else "float(vx[i][j]*vy[i][j]);"}
|
||||
|
||||
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
@ -114,8 +114,8 @@ __global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{
|
|||
#pragma unroll
|
||||
for (int j=0; j<{ILP}; j++)
|
||||
vx[i][j] = {
|
||||
"vy[i][j] - expf(vx[i][j]) * reduce_var;" if log
|
||||
else "vx[i][j] * (vy[i][j] - reduce_var);"
|
||||
"vy[i][j] - in0_type(expf(vx[i][j]) * reduce_var);" if log
|
||||
else "vx[i][j] * (vy[i][j] - in0_type(reduce_var));"
|
||||
}
|
||||
|
||||
{for_loop}
|
||||
|
|
Loading…
Reference in New Issue