mirror of https://github.com/Jittor/Jittor
set to 0
This commit is contained in:
parent
0625957143
commit
9d9895190e
|
@ -72,19 +72,6 @@ class Pool(Module):
|
|||
kernel1<<<s1, s2>>>(@ARGS);
|
||||
''',
|
||||
cuda_grad_src=[f'''
|
||||
__global__ static void kernel2(@ARGS_DEF) {{
|
||||
@PRECALC
|
||||
int p3 = threadIdx.x;
|
||||
int s3 = blockDim.x;
|
||||
int p2 = threadIdx.y + blockIdx.x * blockDim.y;
|
||||
int s2 = blockDim.y * gridDim.x;
|
||||
int i1 = blockIdx.y;
|
||||
int i0 = blockIdx.z;
|
||||
for (int i3 = p3; i3 < outshape3; i3 += s3)
|
||||
for (int i2 = p2; i2 < outshape2; i2 += s2) {{
|
||||
@out(i0, i1, i2, i3) = 0;
|
||||
}}
|
||||
}}
|
||||
__global__ static void kernel3(@ARGS_DEF) {{
|
||||
@PRECALC
|
||||
int p3 = threadIdx.x;
|
||||
|
@ -111,19 +98,12 @@ class Pool(Module):
|
|||
}}
|
||||
}}
|
||||
}}
|
||||
int tx = min(1024, outshape3);
|
||||
int ty = min(1024 / tx, outshape2);
|
||||
int bx = (outshape2 - 1) / ty + 1;
|
||||
int by = outshape1;
|
||||
int bz = outshape0;
|
||||
dim3 s1(bx, by, bz);
|
||||
dim3 s2(tx, ty);
|
||||
kernel2<<<s1, s2>>>(@ARGS);
|
||||
tx = min(1024, poutshape3);
|
||||
ty = min(1024 / tx, poutshape2);
|
||||
bx = (poutshape2 - 1) / ty + 1;
|
||||
by = poutshape1;
|
||||
bz = poutshape0;
|
||||
cudaMemsetAsync(outp, 0, out->size);
|
||||
int tx = min(1024, poutshape3);
|
||||
int ty = min(1024 / tx, poutshape2);
|
||||
int bx = (poutshape2 - 1) / ty + 1;
|
||||
int by = poutshape1;
|
||||
int bz = poutshape0;
|
||||
dim3 s1_(bx, by, bz);
|
||||
dim3 s2_(tx, ty);
|
||||
kernel3<<<s1_, s2_>>>(@ARGS);
|
||||
|
|
Loading…
Reference in New Issue