mirror of https://github.com/Jittor/Jittor
fix avg pool compare with pytorch
This commit is contained in:
parent
87daadf34d
commit
1e867a5b7b
|
@ -26,19 +26,48 @@ class Pool(Module):
|
|||
|
||||
def execute(self, x):
|
||||
N,C,H,W = x.shape
|
||||
if (self.ceil_mode == False):
|
||||
if self.ceil_mode == False:
|
||||
h = (H+self.padding*2-self.kernel_size)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size)//self.stride+1
|
||||
else:
|
||||
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
|
||||
if (self.op == 'maximum' or self.op == 'minimum'):
|
||||
if (self.op == 'maximum'):
|
||||
op = 'max'
|
||||
else:
|
||||
op = 'min'
|
||||
if self.op in ['maximum', 'minimum', 'mean']:
|
||||
forward_body = f'''{{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
|
||||
{"int count = (k2_ - k2) * (k3_ - k3);" if self.op == "mean" else ""}
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
|
||||
}}'''
|
||||
backward_body = f'''{{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
{"int count = (k2_ - k2) * (k3_ - k3);" if self.op == "mean" else ""}
|
||||
int bo=1;
|
||||
for (int p = k2; p < k2_ && bo; ++p)
|
||||
for (int q = k3; q < k3_ && bo; ++q) {{
|
||||
{"atomicAdd(&@out(i0,i1,p,q), @dout(i0,i1,i2,i3)/count);"
|
||||
if self.op == "mean" else
|
||||
f"""if (@pout(i0,i1,i2,i3) == @in0(i0,i1,p,q)) {{
|
||||
atomicAdd(&@out(i0,i1,p,q), @dout(i0,i1,i2,i3)),
|
||||
bo=0;
|
||||
}}"""}
|
||||
}}
|
||||
}}'''
|
||||
out = jt.code([N,C,h,w], x.dtype, [x],
|
||||
cuda_header='#include <ops/binary_op_defs.h>',
|
||||
cuda_src=f'''
|
||||
__global__ static void kernel1(@ARGS_DEF) {{
|
||||
@PRECALC
|
||||
|
@ -49,18 +78,8 @@ class Pool(Module):
|
|||
int i1 = blockIdx.y;
|
||||
int i0 = blockIdx.z;
|
||||
for (int i3 = p3; i3 < out_shape3; i3 += s3)
|
||||
for (int i2 = p2; i2 < out_shape2; i2 += s2) {{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
@out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
@out(i0, i1, i2, i3) = {op}(@out(i0, i1, i2, i3), @in0(i0, i1, p, q));
|
||||
}}
|
||||
for (int i2 = p2; i2 < out_shape2; i2 += s2)
|
||||
{forward_body}
|
||||
}}
|
||||
int tx = min(1024, out_shape3);
|
||||
int ty = min(1024 / tx, out_shape2);
|
||||
|
@ -81,22 +100,8 @@ class Pool(Module):
|
|||
int i1 = blockIdx.y;
|
||||
int i0 = blockIdx.z;
|
||||
for (int i3 = p3; i3 < pout_shape3; i3 += s3)
|
||||
for (int i2 = p2; i2 < pout_shape2; i2 += s2) {{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
int bo=1;
|
||||
for (int p = k2; p < k2_ && bo; ++p)
|
||||
for (int q = k3; q < k3_ && bo; ++q) {{
|
||||
if (@pout(i0,i1,i2,i3) == @in0(i0,i1,p,q)) {{
|
||||
atomicAdd(&@out(i0,i1,p,q), @dout(i0,i1,i2,i3));
|
||||
bo=0;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
for (int i2 = p2; i2 < pout_shape2; i2 += s2)
|
||||
{backward_body}
|
||||
}}
|
||||
cudaMemsetAsync(out_p, 0, out->size);
|
||||
int tx = min(1024, pout_shape3);
|
||||
|
@ -108,48 +113,25 @@ class Pool(Module):
|
|||
dim3 s2_(tx, ty);
|
||||
kernel3<<<s1_, s2_>>>(@ARGS);
|
||||
'''],
|
||||
cpu_header='#include <ops/binary_op_defs.h>',
|
||||
cpu_src=f'''
|
||||
using namespace std;
|
||||
for (int i0=0; i0<out_shape0; i0++)
|
||||
for (int i1=0; i1<out_shape1; i1++)
|
||||
for (int i2=0; i2<out_shape2; i2++)
|
||||
for (int i3=0; i3<out_shape3; i3++) {{
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2_ = std::min(k2 + {self.kernel_size}, in0_shape2);
|
||||
int k3_ = std::min(k3 + {self.kernel_size}, in0_shape3);
|
||||
k2 = std::max(0, k2);
|
||||
k3 = std::max(0, k3);
|
||||
@out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
@out(i0, i1, i2, i3) = std::{op}(@out(i0, i1, i2, i3), @in0(i0, i1, p, q));
|
||||
}}
|
||||
for (int i3=0; i3<out_shape3; i3++)
|
||||
{forward_body}
|
||||
''',
|
||||
cpu_grad_src = [f'''
|
||||
for (int i=0; i<out_shape0; i++)
|
||||
for (int j=0; j<out_shape1; j++)
|
||||
for (int k=0; k<out_shape2; k++)
|
||||
for (int l=0; l<out_shape3; l++) @out(i,j,k,l) = 0;
|
||||
using namespace std;
|
||||
std::memset(out_p, 0, out->size);
|
||||
#define atomicAdd(a,b) (*a) += b
|
||||
|
||||
for (int i0=0; i0<pout_shape0; i0++)
|
||||
for (int i1=0; i1<pout_shape1; i1++)
|
||||
for (int i2=0; i2<pout_shape2; i2++)
|
||||
for (int i3=0; i3<pout_shape3; i3++) {{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = std::min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = std::min(k2 + {self.kernel_size}, in0_shape2);
|
||||
k3 = std::max(0, k3);
|
||||
k2 = std::max(0, k2);
|
||||
int bo=1;
|
||||
for (int p = k2; p < k2_ && bo; ++p)
|
||||
for (int q = k3; q < k3_ && bo; ++q) {{
|
||||
if (@pout(i0,i1,i2,i3) == @in0(i0,i1,p,q)) {{
|
||||
@out(i0,i1,p,q) += @dout(i0,i1,i2,i3);
|
||||
bo=0;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
for (int i3=0; i3<pout_shape3; i3++)
|
||||
{backward_body}
|
||||
'''])
|
||||
return out
|
||||
else:
|
||||
|
|
|
@ -75,5 +75,13 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
def test_cpu_avg_pool(self):
|
||||
from torch.nn import AvgPool2d
|
||||
jt_model = Pool(2, 2, 0, op="mean", ceil_mode=True)
|
||||
torch_model = AvgPool2d(2, 2, 0, ceil_mode=True)
|
||||
# shape = [64, 64, 300, 300]
|
||||
shape = (2, 16, 33, 33)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue