This commit is contained in:
guoye 2020-03-30 13:10:57 +08:00
parent 0625957143
commit 9d9895190e
1 changed files with 6 additions and 26 deletions

View File

@ -72,19 +72,6 @@ class Pool(Module):
kernel1<<<s1, s2>>>(@ARGS);
''',
cuda_grad_src=[f'''
__global__ static void kernel2(@ARGS_DEF) {{
@PRECALC
int p3 = threadIdx.x;
int s3 = blockDim.x;
int p2 = threadIdx.y + blockIdx.x * blockDim.y;
int s2 = blockDim.y * gridDim.x;
int i1 = blockIdx.y;
int i0 = blockIdx.z;
for (int i3 = p3; i3 < outshape3; i3 += s3)
for (int i2 = p2; i2 < outshape2; i2 += s2) {{
@out(i0, i1, i2, i3) = 0;
}}
}}
__global__ static void kernel3(@ARGS_DEF) {{
@PRECALC
int p3 = threadIdx.x;
@ -111,19 +98,12 @@ class Pool(Module):
}}
}}
}}
int tx = min(1024, outshape3);
int ty = min(1024 / tx, outshape2);
int bx = (outshape2 - 1) / ty + 1;
int by = outshape1;
int bz = outshape0;
dim3 s1(bx, by, bz);
dim3 s2(tx, ty);
kernel2<<<s1, s2>>>(@ARGS);
tx = min(1024, poutshape3);
ty = min(1024 / tx, poutshape2);
bx = (poutshape2 - 1) / ty + 1;
by = poutshape1;
bz = poutshape0;
cudaMemsetAsync(outp, 0, out->size);
int tx = min(1024, poutshape3);
int ty = min(1024 / tx, poutshape2);
int bx = (poutshape2 - 1) / ty + 1;
int by = poutshape1;
int bz = poutshape0;
dim3 s1_(bx, by, bz);
dim3 s2_(tx, ty);
kernel3<<<s1_, s2_>>>(@ARGS);