mirror of https://github.com/Jittor/Jittor
add count_include_pad for pad
This commit is contained in:
parent
93ca5e9525
commit
a6dab87634
|
@ -9,51 +9,10 @@
|
|||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import pool
|
||||
|
||||
def argmax_pool(x, size, stride, padding=0):
|
||||
y_shape = list(x.shape)
|
||||
y_shape[2]=(x.shape[2]+padding*2-size)//stride+1
|
||||
y_shape[3]=(x.shape[3]+padding*2-size)//stride+1
|
||||
|
||||
y = jt.code(y_shape, x.dtype, [x],
|
||||
cpu_src=f'''
|
||||
for (int i=0; i<out_shape0; i++)
|
||||
for (int j=0; j<out_shape1; j++)
|
||||
for (int k=0; k<out_shape2; k++)
|
||||
for (int l=0; l<out_shape3; l++) {{
|
||||
int kx=k*{stride}+{size}/2-{padding};
|
||||
int ky=l*{stride}+{size}/2-{padding};
|
||||
@out(i,j,k,l) = @in0(i,j,kx,ky);
|
||||
for (int p=kx-{size}/2;p<=kx+{size}/2;p++)
|
||||
for (int q=ky-{size}/2;q<=ky+{size}/2;q++)
|
||||
if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
|
||||
if (@out(i,j,k,l) < @in0(i,j,p,q))
|
||||
@out(i,j,k,l) = @in0(i,j,p,q);
|
||||
}}
|
||||
''',
|
||||
cpu_grad_src = [f'''
|
||||
for (int i=0; i<out_shape0; i++)
|
||||
for (int j=0; j<out_shape1; j++)
|
||||
for (int k=0; k<out_shape2; k++)
|
||||
for (int l=0; l<out_shape3; l++) @out(i,j,k,l) = 0;
|
||||
|
||||
for (int i=0; i<pout_shape0; i++)
|
||||
for (int j=0; j<pout_shape1; j++)
|
||||
for (int k=0; k<pout_shape2; k++)
|
||||
for (int l=0; l<pout_shape3; l++) {{
|
||||
int kx=k*{stride}+{size}/2-{padding};
|
||||
int ky=l*{stride}+{size}/2-{padding};
|
||||
int bo=1;
|
||||
for (int p=kx-{size}/2;p<=kx+{size}/2 && bo;p++)
|
||||
for (int q=ky-{size}/2;q<=ky+{size}/2 && bo;q++)
|
||||
if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
|
||||
if (@pout(i,j,k,l) == @in0(i,j,p,q)) {{
|
||||
@out(i,j,p,q) += @dout(i,j,k,l);
|
||||
bo=0;
|
||||
}}
|
||||
}}
|
||||
'''])
|
||||
return y
|
||||
return pool.pool(x, size, 'maximum', padding, stride)
|
||||
|
||||
def concat(arr, dim):
|
||||
# TODO: low performance when concat lots of vars
|
||||
|
|
|
@ -15,7 +15,7 @@ import numpy as np
|
|||
import math
|
||||
|
||||
class Pool(Module):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
|
||||
assert dilation == None
|
||||
assert return_indices == None
|
||||
self.kernel_size = kernel_size
|
||||
|
@ -23,6 +23,7 @@ class Pool(Module):
|
|||
self.stride = stride if stride else kernel_size
|
||||
self.padding = padding
|
||||
self.ceil_mode = ceil_mode
|
||||
self.count_include_pad = count_include_pad and padding != 0
|
||||
|
||||
def execute(self, x):
|
||||
N,C,H,W = x.shape
|
||||
|
@ -33,7 +34,7 @@ class Pool(Module):
|
|||
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
|
||||
if self.op in ['maximum', 'minimum', 'mean']:
|
||||
if self.op in ['maximum', 'minimum', 'mean'] and not self.count_include_pad:
|
||||
forward_body = f'''{{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
|
|
|
@ -84,7 +84,13 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
from torch.nn import AvgPool2d
|
||||
jt_model = Pool(2, 2, 0, op="mean", ceil_mode=True)
|
||||
torch_model = AvgPool2d(2, 2, 0, ceil_mode=True)
|
||||
# shape = [64, 64, 300, 300]
|
||||
shape = (2, 16, 33, 33)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
|
||||
def test_cpu_avg_pool2(self):
|
||||
from torch.nn import AvgPool2d
|
||||
jt_model = Pool(3, 1, 1, op="mean", ceil_mode=True)
|
||||
torch_model = AvgPool2d(3, 1, 1, ceil_mode=True)
|
||||
shape = (2, 16, 33, 33)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
|
||||
|
|
Loading…
Reference in New Issue