fix avg pool compare with pytorch

Dun Liang 2020-05-27 16:59:00 +08:00
parent 87daadf34d
commit 1e867a5b7b
2 changed files with 56 additions and 66 deletions

@ -26,19 +26,48 @@ class Pool(Module):
def execute(self, x):
N,C,H,W = x.shape
if (self.ceil_mode == False):
if self.ceil_mode == False:
h = (H+self.padding*2-self.kernel_size)//self.stride+1
w = (W+self.padding*2-self.kernel_size)//self.stride+1
else:
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
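# output size: floor((H + 2*padding - kernel_size) / stride) + 1, or the ceil
# variant when ceil_mode is True (e.g. H=33, k=2, s=2, pad=0 -> 16 vs 17)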
if (self.op == 'maximum' or self.op == 'minimum'):
if (self.op == 'maximum'):
op = 'max'
else:
op = 'min'
if self.op in ['maximum', 'minimum', 'mean']:
forward_body = f'''{{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
{"int count = (k2_ - k2) * (k3_ - k3);" if self.op == "mean" else ""}
for (int p = k2; p < k2_; ++p)
for (int q = k3; q < k3_; ++q)
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
}}'''
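# forward_body clamps the pooling window to the input bounds, starts from
# init_<op>, and folds every in-window element in with the selected binary op;
# it is spliced into both the CUDA kernel and the CPU loops below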
backward_body = f'''{{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
{"int count = (k2_ - k2) * (k3_ - k3);" if self.op == "mean" else ""}
int bo=1;
for (int p = k2; p < k2_ && bo; ++p)
for (int q = k3; q < k3_ && bo; ++q) {{
{"atomicAdd(&@out(i0,i1,p,q), @dout(i0,i1,i2,i3)/count);"
if self.op == "mean" else
f"""if (@pout(i0,i1,i2,i3) == @in0(i0,i1,p,q)) {{
atomicAdd(&@out(i0,i1,p,q), @dout(i0,i1,i2,i3));
bo=0;
}}"""}
}}
}}'''
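# backward_body mirrors the window computation: for 'mean' the incoming gradient
# is spread as dout/count over the (clipped) window via atomicAdd, while for
# max/min it is routed to the first element equal to the pooled value (pout)
# and bo short-circuits the remaining scan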
out = jt.code([N,C,h,w], x.dtype, [x],
cuda_header='#include <ops/binary_op_defs.h>',
cuda_src=f'''
__global__ static void kernel1(@ARGS_DEF) {{
@PRECALC
@ -49,18 +78,8 @@ class Pool(Module):
int i1 = blockIdx.y;
int i0 = blockIdx.z;
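// i0/i1 (batch, channel) come from the block grid; i2/i3 stride over the
// output plane (p2/p3 and s2/s3 are set up in the lines elided above)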
for (int i3 = p3; i3 < out_shape3; i3 += s3)
for (int i2 = p2; i2 < out_shape2; i2 += s2) {{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
@out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
for (int p = k2; p < k2_; ++p)
for (int q = k3; q < k3_; ++q)
@out(i0, i1, i2, i3) = {op}(@out(i0, i1, i2, i3), @in0(i0, i1, p, q));
}}
for (int i2 = p2; i2 < out_shape2; i2 += s2)
{forward_body}
}}
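// launch configuration: a 2D thread block of at most 1024 threads tiling the
// output width (tx) and height (ty)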
int tx = min(1024, out_shape3);
int ty = min(1024 / tx, out_shape2);
@ -81,22 +100,8 @@ class Pool(Module):
int i1 = blockIdx.y;
int i0 = blockIdx.z;
for (int i3 = p3; i3 < pout_shape3; i3 += s3)
for (int i2 = p2; i2 < pout_shape2; i2 += s2) {{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
int bo=1;
for (int p = k2; p < k2_ && bo; ++p)
for (int q = k3; q < k3_ && bo; ++q) {{
if (@pout(i0,i1,i2,i3) == @in0(i0,i1,p,q)) {{
atomicAdd(&@out(i0,i1,p,q), @dout(i0,i1,i2,i3));
bo=0;
}}
}}
}}
for (int i2 = p2; i2 < pout_shape2; i2 += s2)
{backward_body}
}}
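// zero the input-gradient buffer before kernel3 scatters into it with atomicAdd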
cudaMemsetAsync(out_p, 0, out->size);
int tx = min(1024, pout_shape3);
@ -108,48 +113,25 @@ class Pool(Module):
dim3 s2_(tx, ty);
kernel3<<<s1_, s2_>>>(@ARGS);
'''],
cpu_header='#include <ops/binary_op_defs.h>',
cpu_src=f'''
using namespace std;
for (int i0=0; i0<out_shape0; i0++)
for (int i1=0; i1<out_shape1; i1++)
for (int i2=0; i2<out_shape2; i2++)
for (int i3=0; i3<out_shape3; i3++) {{
int k2 = i2*{self.stride}-{self.padding};
int k3 = i3*{self.stride}-{self.padding};
int k2_ = std::min(k2 + {self.kernel_size}, in0_shape2);
int k3_ = std::min(k3 + {self.kernel_size}, in0_shape3);
k2 = std::max(0, k2);
k3 = std::max(0, k3);
@out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
for (int p = k2; p < k2_; ++p)
for (int q = k3; q < k3_; ++q)
@out(i0, i1, i2, i3) = std::{op}(@out(i0, i1, i2, i3), @in0(i0, i1, p, q));
}}
for (int i3=0; i3<out_shape3; i3++)
{forward_body}
''',
cpu_grad_src = [f'''
for (int i=0; i<out_shape0; i++)
for (int j=0; j<out_shape1; j++)
for (int k=0; k<out_shape2; k++)
for (int l=0; l<out_shape3; l++) @out(i,j,k,l) = 0;
using namespace std;
std::memset(out_p, 0, out->size);
#define atomicAdd(a,b) (*a) += b
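// single-threaded CPU path: zero the gradient buffer and alias atomicAdd to a
// plain in-place add so backward_body can be reused verbatim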
for (int i0=0; i0<pout_shape0; i0++)
for (int i1=0; i1<pout_shape1; i1++)
for (int i2=0; i2<pout_shape2; i2++)
for (int i3=0; i3<pout_shape3; i3++) {{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = std::min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = std::min(k2 + {self.kernel_size}, in0_shape2);
k3 = std::max(0, k3);
k2 = std::max(0, k2);
int bo=1;
for (int p = k2; p < k2_ && bo; ++p)
for (int q = k3; q < k3_ && bo; ++q) {{
if (@pout(i0,i1,i2,i3) == @in0(i0,i1,p,q)) {{
@out(i0,i1,p,q) += @dout(i0,i1,i2,i3);
bo=0;
}}
}}
}}
for (int i3=0; i3<pout_shape3; i3++)
{backward_body}
'''])
return out
else:

@ -75,5 +75,13 @@ class TestArgPoolOp(unittest.TestCase):
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
def test_cpu_avg_pool(self):
from torch.nn import AvgPool2d
jt_model = Pool(2, 2, 0, op="mean", ceil_mode=True)
torch_model = AvgPool2d(2, 2, 0, ceil_mode=True)
# shape = [64, 64, 300, 300]
shape = (2, 16, 33, 33)
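# 33 is odd, so with kernel 2, stride 2 and ceil_mode=True the last window is
# partial; this exercises the count-based averaging being compared with PyTorch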
check(jt_model, torch_model, shape, False)
if __name__ == "__main__":
unittest.main()
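For reference, a minimal standalone parity check along the lines of what the `check` helper above presumably performs; the tolerance and the direct numpy comparison here are assumptions, not part of the commit:

import numpy as np
import jittor as jt
import torch
from jittor.nn import Pool
from torch.nn import AvgPool2d

# same configuration as test_cpu_avg_pool: 2x2 mean pooling, stride 2, no padding, ceil_mode on
x = np.random.rand(2, 16, 33, 33).astype("float32")
jt_out = Pool(2, 2, 0, op="mean", ceil_mode=True)(jt.array(x))
torch_out = AvgPool2d(2, 2, 0, ceil_mode=True)(torch.from_numpy(x))
np.testing.assert_allclose(jt_out.data, torch_out.numpy(), atol=1e-5)  # atol is an assumed tolerance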