mirror of https://github.com/Jittor/Jittor
130 lines
3.9 KiB
Python
130 lines
3.9 KiB
Python
import jittor as jt
|
|
from jittor import nn
|
|
|
|
def can_softmax_v1(a, dim):
|
|
if not jt.flags.use_cuda:
|
|
return False
|
|
if dim != -1 and dim != len(a.shape)-1:
|
|
return False
|
|
if a.shape[len(a.shape)-1] > 10000:
|
|
return False
|
|
return True
|
|
|
|
def softmax_v1(a, log=False):
|
|
assert can_softmax_v1(a, -1)
|
|
length = a.shape[-1]
|
|
# tnum = 1024
|
|
tnum = 500 if length % 500 == 0 else 512
|
|
tnum = 125 if length % 125 == 0 else 128
|
|
# tnum = 125
|
|
# tnum = 1000 if length % 1000 == 0 else 1024
|
|
# tnum = 250
|
|
per_thread = (length-1) // tnum + 1
|
|
ILP = 1
|
|
for ilp in [8,4,2]:
|
|
if length % tnum == 0 and per_thread % ilp == 0:
|
|
ILP = ilp
|
|
per_thread //= ILP
|
|
break
|
|
for_loop = f"""
|
|
#pragma unroll
|
|
for (int i=0; i<{per_thread}; i++)
|
|
"""
|
|
if length % tnum != 0:
|
|
for_loop += f"if ((i*{tnum}+threadIdx.x)*{ILP} < len)\n"
|
|
|
|
return jt.code(a.shape, a.dtype, [a], cuda_header=f'''
|
|
#include <{jt.compile_extern.cub_home}cub/cub.cuh>
|
|
#include <type/fp16_compute.h>
|
|
''', cuda_src=f'''
|
|
__global__ void kernel(in0_type* x, out0_type* y, int len) {{
|
|
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
|
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
|
|
|
int id = blockIdx.x * len;
|
|
in0_type v[{per_thread}][{ILP}];
|
|
{for_loop}
|
|
vload<sizeof(in0_type)*{ILP}>(v[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]);
|
|
// v[i] = x[id+i*{tnum}+threadIdx.x];
|
|
float v1 = -1e30;
|
|
{for_loop}
|
|
#pragma unroll
|
|
for (int j=0; j<{ILP}; j++) {{
|
|
v1 = max(v1, float(v[i][j]));
|
|
}}
|
|
__shared__ float vmax;
|
|
auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());
|
|
if (threadIdx.x == 0)
|
|
vmax = tmp;
|
|
__syncthreads();
|
|
|
|
v1 = 0;
|
|
{for_loop}
|
|
#pragma unroll
|
|
for (int j=0; j<{ILP}; j++) {{
|
|
v[i][j] = expf(float(v[i][j]) - vmax);
|
|
v1 += float(v[i][j]);
|
|
}}
|
|
|
|
tmp = BlockReduce(temp_storage).Sum(v1);
|
|
__shared__ float vsum;
|
|
if (threadIdx.x == 0)
|
|
vsum = tmp;
|
|
__syncthreads();
|
|
|
|
{for_loop}
|
|
#pragma unroll
|
|
for (int j=0; j<{ILP}; j++)
|
|
v[i][j] = {
|
|
"@expand_op(log,@in0_type,float(v[i][j])/vsum)" if log
|
|
else "float(v[i][j])/vsum"
|
|
};
|
|
{for_loop}
|
|
vload<sizeof(in0_type)*{ILP}>(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]);
|
|
}}
|
|
int len = in0->shape[in0->shape.size()-1];
|
|
int bnum = in0->numel() / len;
|
|
cudaGetLastError();
|
|
kernel<<<bnum, {tnum}>>>(in0_p, out0_p, len);
|
|
CHECK(0 == cudaGetLastError());
|
|
''', cuda_grad_src=[f"""
|
|
__global__ void kernel(pout0_type* x, dout_type* y, out0_type* z, int len) {{
|
|
int id = blockIdx.x * len;
|
|
in0_type vx[{per_thread}][{ILP}];
|
|
in0_type vy[{per_thread}][{ILP}];
|
|
{for_loop} {{
|
|
vload<sizeof(in0_type)*{ILP}>(vx[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]);
|
|
vload<sizeof(in0_type)*{ILP}>(vy[i], &y[id+(i*{tnum}+threadIdx.x)*{ILP}]);
|
|
}}
|
|
float v1 = 0;
|
|
{for_loop}
|
|
#pragma unroll
|
|
for (int j=0; j<{ILP}; j++)
|
|
v1 += {"float(vy[i][j]);" if log else "float(vx[i][j]*vy[i][j]);"}
|
|
|
|
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
|
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
|
auto tmp = BlockReduce(temp_storage).Sum(v1);
|
|
__shared__ float reduce_var;
|
|
if (threadIdx.x == 0)
|
|
reduce_var = tmp;
|
|
__syncthreads();
|
|
|
|
{for_loop}
|
|
#pragma unroll
|
|
for (int j=0; j<{ILP}; j++)
|
|
vx[i][j] = {
|
|
"vy[i][j] - in0_type(expf(vx[i][j]) * reduce_var);" if log
|
|
else "vx[i][j] * (vy[i][j] - in0_type(reduce_var));"
|
|
}
|
|
|
|
{for_loop}
|
|
vload<sizeof(in0_type)*{ILP}>(&z[id+(i*{tnum}+threadIdx.x)*{ILP}],
|
|
vx[i]);
|
|
}}
|
|
int len = in0->shape[in0->shape.size()-1];
|
|
int bnum = in0->numel() / len;
|
|
cudaGetLastError();
|
|
kernel<<<bnum, {tnum}>>>(pout0_p, dout_p, out0_p, len);
|
|
CHECK(0 == cudaGetLastError());
|
|
"""]) |