add count_include_pad for pad

This commit is contained in:
Dun Liang 2020-05-27 22:19:16 +08:00
parent 93ca5e9525
commit a6dab87634
3 changed files with 12 additions and 46 deletions

View File

@@ -9,51 +9,10 @@
# ***************************************************************
import jittor as jt
import numpy as np
from jittor import pool
def argmax_pool(x, size, stride, padding=0):
    """Max-pool ``x`` with a ``size`` x ``size`` window.

    Thin wrapper that delegates to the shared pooling implementation
    (``jittor.pool.pool``) with the ``'maximum'`` reduction, replacing the
    previous hand-written ``jt.code`` CPU kernel (behavior preserved:
    maximum over each window, with zero-padding positions skipped).

    Args:
        x: 4-D input var; presumably laid out (N, C, H, W) — TODO confirm
            against jittor.pool.pool.
        size: pooling window side length.
        stride: step between consecutive windows.
        padding: implicit padding added to both spatial dims (default 0).

    Returns:
        The pooled var produced by ``pool.pool``.
    """
    # NOTE(review): the scraped diff fused the deleted legacy kernel with
    # this added line, leaving it unreachable after `return y`; this is the
    # post-commit version of the function.
    return pool.pool(x, size, 'maximum', padding, stride)
def concat(arr, dim):
# TODO: low performance when concat lots of vars

View File

@@ -15,7 +15,7 @@ import numpy as np
import math
class Pool(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"):
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
assert dilation == None
assert return_indices == None
self.kernel_size = kernel_size
@@ -23,6 +23,7 @@ class Pool(Module):
self.stride = stride if stride else kernel_size
self.padding = padding
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad and padding != 0
def execute(self, x):
N,C,H,W = x.shape
@@ -33,7 +34,7 @@ class Pool(Module):
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
if self.op in ['maximum', 'minimum', 'mean']:
if self.op in ['maximum', 'minimum', 'mean'] and not self.count_include_pad:
forward_body = f'''{{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};

View File

@@ -84,7 +84,13 @@ class TestArgPoolOp(unittest.TestCase):
from torch.nn import AvgPool2d
jt_model = Pool(2, 2, 0, op="mean", ceil_mode=True)
torch_model = AvgPool2d(2, 2, 0, ceil_mode=True)
# shape = [64, 64, 300, 300]
shape = (2, 16, 33, 33)
check(jt_model, torch_model, shape, False)
def test_cpu_avg_pool2(self):
    """Compare Jittor mean-pool (3x3, stride 1, pad 1, ceil_mode) against
    torch.nn.AvgPool2d with the same configuration on a small CPU tensor."""
    from torch.nn import AvgPool2d
    # Identical hyper-parameters for both frameworks: kernel 3, stride 1,
    # padding 1, ceil_mode enabled.
    model_jt = Pool(3, 1, 1, op="mean", ceil_mode=True)
    model_torch = AvgPool2d(3, 1, 1, ceil_mode=True)
    input_shape = (2, 16, 33, 33)
    check(model_jt, model_torch, input_shape, False)