mirror of https://github.com/Jittor/Jittor
add count_include_pad for pad
This commit is contained in:
parent
93ca5e9525
commit
a6dab87634
|
@ -9,51 +9,10 @@
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from jittor import pool
|
||||||
|
|
||||||
def argmax_pool(x, size, stride, padding=0):
|
def argmax_pool(x, size, stride, padding=0):
|
||||||
y_shape = list(x.shape)
|
return pool.pool(x, size, 'maximum', padding, stride)
|
||||||
y_shape[2]=(x.shape[2]+padding*2-size)//stride+1
|
|
||||||
y_shape[3]=(x.shape[3]+padding*2-size)//stride+1
|
|
||||||
|
|
||||||
y = jt.code(y_shape, x.dtype, [x],
|
|
||||||
cpu_src=f'''
|
|
||||||
for (int i=0; i<out_shape0; i++)
|
|
||||||
for (int j=0; j<out_shape1; j++)
|
|
||||||
for (int k=0; k<out_shape2; k++)
|
|
||||||
for (int l=0; l<out_shape3; l++) {{
|
|
||||||
int kx=k*{stride}+{size}/2-{padding};
|
|
||||||
int ky=l*{stride}+{size}/2-{padding};
|
|
||||||
@out(i,j,k,l) = @in0(i,j,kx,ky);
|
|
||||||
for (int p=kx-{size}/2;p<=kx+{size}/2;p++)
|
|
||||||
for (int q=ky-{size}/2;q<=ky+{size}/2;q++)
|
|
||||||
if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
|
|
||||||
if (@out(i,j,k,l) < @in0(i,j,p,q))
|
|
||||||
@out(i,j,k,l) = @in0(i,j,p,q);
|
|
||||||
}}
|
|
||||||
''',
|
|
||||||
cpu_grad_src = [f'''
|
|
||||||
for (int i=0; i<out_shape0; i++)
|
|
||||||
for (int j=0; j<out_shape1; j++)
|
|
||||||
for (int k=0; k<out_shape2; k++)
|
|
||||||
for (int l=0; l<out_shape3; l++) @out(i,j,k,l) = 0;
|
|
||||||
|
|
||||||
for (int i=0; i<pout_shape0; i++)
|
|
||||||
for (int j=0; j<pout_shape1; j++)
|
|
||||||
for (int k=0; k<pout_shape2; k++)
|
|
||||||
for (int l=0; l<pout_shape3; l++) {{
|
|
||||||
int kx=k*{stride}+{size}/2-{padding};
|
|
||||||
int ky=l*{stride}+{size}/2-{padding};
|
|
||||||
int bo=1;
|
|
||||||
for (int p=kx-{size}/2;p<=kx+{size}/2 && bo;p++)
|
|
||||||
for (int q=ky-{size}/2;q<=ky+{size}/2 && bo;q++)
|
|
||||||
if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
|
|
||||||
if (@pout(i,j,k,l) == @in0(i,j,p,q)) {{
|
|
||||||
@out(i,j,p,q) += @dout(i,j,k,l);
|
|
||||||
bo=0;
|
|
||||||
}}
|
|
||||||
}}
|
|
||||||
'''])
|
|
||||||
return y
|
|
||||||
|
|
||||||
def concat(arr, dim):
|
def concat(arr, dim):
|
||||||
# TODO: low performance when concat lots of vars
|
# TODO: low performance when concat lots of vars
|
||||||
|
|
|
@ -15,7 +15,7 @@ import numpy as np
|
||||||
import math
|
import math
|
||||||
|
|
||||||
class Pool(Module):
|
class Pool(Module):
|
||||||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"):
|
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
|
||||||
assert dilation == None
|
assert dilation == None
|
||||||
assert return_indices == None
|
assert return_indices == None
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
|
@ -23,6 +23,7 @@ class Pool(Module):
|
||||||
self.stride = stride if stride else kernel_size
|
self.stride = stride if stride else kernel_size
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.ceil_mode = ceil_mode
|
self.ceil_mode = ceil_mode
|
||||||
|
self.count_include_pad = count_include_pad and padding != 0
|
||||||
|
|
||||||
def execute(self, x):
|
def execute(self, x):
|
||||||
N,C,H,W = x.shape
|
N,C,H,W = x.shape
|
||||||
|
@ -33,7 +34,7 @@ class Pool(Module):
|
||||||
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||||
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||||
|
|
||||||
if self.op in ['maximum', 'minimum', 'mean']:
|
if self.op in ['maximum', 'minimum', 'mean'] and not self.count_include_pad:
|
||||||
forward_body = f'''{{
|
forward_body = f'''{{
|
||||||
int k3 = i3*{self.stride}-{self.padding};
|
int k3 = i3*{self.stride}-{self.padding};
|
||||||
int k2 = i2*{self.stride}-{self.padding};
|
int k2 = i2*{self.stride}-{self.padding};
|
||||||
|
|
|
@ -84,7 +84,13 @@ class TestArgPoolOp(unittest.TestCase):
|
||||||
from torch.nn import AvgPool2d
|
from torch.nn import AvgPool2d
|
||||||
jt_model = Pool(2, 2, 0, op="mean", ceil_mode=True)
|
jt_model = Pool(2, 2, 0, op="mean", ceil_mode=True)
|
||||||
torch_model = AvgPool2d(2, 2, 0, ceil_mode=True)
|
torch_model = AvgPool2d(2, 2, 0, ceil_mode=True)
|
||||||
# shape = [64, 64, 300, 300]
|
shape = (2, 16, 33, 33)
|
||||||
|
check(jt_model, torch_model, shape, False)
|
||||||
|
|
||||||
|
def test_cpu_avg_pool2(self):
|
||||||
|
from torch.nn import AvgPool2d
|
||||||
|
jt_model = Pool(3, 1, 1, op="mean", ceil_mode=True)
|
||||||
|
torch_model = AvgPool2d(3, 1, 1, ceil_mode=True)
|
||||||
shape = (2, 16, 33, 33)
|
shape = (2, 16, 33, 33)
|
||||||
check(jt_model, torch_model, shape, False)
|
check(jt_model, torch_model, shape, False)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue