polish pool interface

This commit is contained in:
Dun Liang 2021-08-04 19:28:25 +08:00
parent 317e07907f
commit 86a3feeaab
3 changed files with 27 additions and 22 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.87'
__version__ = '1.2.3.88'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -73,7 +73,7 @@ class Pool(Module):
for (int q = k3; q < k3_; ++q)
if (out_value < @in0(i0, i1, p, q)) {{
out_value = @in0(i0, i1, p, q);
out_index = (p - k2) * {self.kernel_size[1]} + (q - k3);
out_index = p * in0_shape3 + q;
}}
@out(i0, i1, i2, i3) = out_value;
@out1(i0, i1, i2, i3) = out_index;
@ -99,10 +99,7 @@ class Pool(Module):
'''
if self.return_indices:
return_shapes = [[N,C,h,w]] * 2
return_dtypes = [x.dtype, 'uint8']
ks = self.kernel_size[0] * self.kernel_size[1]
if ks > 65536: return_dtypes[1] = 'uint32'
elif ks > 256: return_dtypes[1] = 'uint16'
return_dtypes = [x.dtype, 'int32']
else:
return_shapes = [N,C,h,w]
return_dtypes = x.dtype
@ -260,7 +257,7 @@ class Pool3d(Module):
for (int r = k4; q < k4_; ++r)
if (out_value < @in0(i0, i1, p, q, r)) {{
out_value = @in0(i0, i1, p, q, r);
out_index = (p - k2) * {self.kernel_size[1]} * {self.kernel_size[2]} + (q - k3) * {self.kernel_size[2]} + (r - k4);
out_index = p * in0_shape3 * in0_shape4 + q * in0_shape4 + r;
}}
@out(i0, i1, i2, i3, i4) = out_value;
@out1(i0, i1, i2, i3, i4) = out_index;
@ -290,10 +287,7 @@ class Pool3d(Module):
'''
if self.return_indices:
return_shapes = [[N,C,d,h,w]] * 2
return_dtypes = [x.dtype, 'uint8']
ks = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
if ks > 65536: return_dtypes[1] = 'uint32'
elif ks > 256: return_dtypes[1] = 'uint16'
return_dtypes = [x.dtype, 'int32']
else:
return_shapes = [N,C,d,h,w]
return_dtypes = x.dtype
@ -420,8 +414,9 @@ class AdaptiveAvgPool2d(Module):
return xx.reduce("mean", [4,5])
class AdaptiveMaxPool2d(Module):
def __init__(self, output_size):
def __init__(self, output_size, return_indices=False):
self.output_size = output_size
self.return_indices = return_indices
def execute(self, x):
if isinstance(self.output_size, int):
@ -439,6 +434,10 @@ class AdaptiveMaxPool2d(Module):
self.sw = math.floor(W / ow)
self.ksh = H - (oh - 1) * self.sh
self.ksw = W - (ow - 1) * self.sw
if self.return_indices:
return MaxPool2d(
kernel_size=(self.ksh, self.ksw),
stride=(self.sh, self.sw), return_indices=True)(x)
h = (H-self.ksh)//self.sh+1
w = (W-self.ksw)//self.sw+1
xx = x.reindex([N,C,h,w,self.ksh,self.ksw], [
@ -478,12 +477,13 @@ class AdaptiveAvgPool3d(Module):
return xx.reduce("mean", [5,6,7])
class AdaptiveMaxPool3d(Module):
def __init__(self, output_size):
def __init__(self, output_size, return_indices=False):
self.output_size = _triple(output_size)
self.return_indices = return_indices
def execute(self, x):
od, oh, ow = self.output_size
if od == 1 and oh == 1 and ow == 1:
if od == 1 and oh == 1 and ow == 1 and not self.return_indices:
return x.reduce("maximum", [2,3,4], keepdims=True)
N,C,D,H,W = x.shape
self.sd = math.floor(D / od)
@ -492,6 +492,10 @@ class AdaptiveMaxPool3d(Module):
self.ksd = D - (od - 1) * self.sd
self.ksh = H - (oh - 1) * self.sh
self.ksw = W - (ow - 1) * self.sw
if self.return_indices:
return MaxPool3d(
kernel_size=(self.ksd, self.ksh, self.ksw),
stride=(self.sd, self.sh, self.sw), return_indices=True)(x)
d = (D-self.ksd)//self.sd+1
h = (H-self.ksh)//self.sh+1
w = (W-self.ksw)//self.sw+1
@ -597,15 +601,15 @@ class MaxUnpool2d(Module):
indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'],
extras=[id],
overflow_conditions=[
f'((i2%{kh})*{kw}+i3%{kw}) != @e0(i0,i1,i2/{kh},i3/{kw})'],
f'(i2*yshape3+i3) != @e0(i0,i1,i2/{kh},i3/{kw})'],
overflow_value=0)
else:
x = x.reindex_reduce(
op="add",
shape=[b, c, h, w],
indexes=['i0', 'i1',
f'i2*{sh}+@e0(i0,i1,i2,i3)/{kw}',
f'i3*{sw}+@e0(i0,i1,i2,i3)%{kw}'],
f'@e0(i0,i1,i2,i3)/xshape3',
f'@e0(i0,i1,i2,i3)%xshape3'],
extras=[id],
)
return x
@ -635,16 +639,16 @@ class MaxUnpool3d(Module):
indexes=['i0', 'i1', f'i2/{kd}', f'i3/{kh}', f'i4/{kw}'],
extras=[id],
overflow_conditions=[
f'((i2%{kd})*{kh*kw}+(i3%{kh})*{kw}+i4%{kw}) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'],
f'(i2*yshape3*yshape4+i3*yshape4+i4) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'],
overflow_value=0)
else:
x = x.reindex_reduce(
op="add",
shape=[b, c, d, h, w],
indexes=['i0', 'i1',
f'i2*{sd}+@e0(i0,i1,i2,i3,i4)/{kh*kw}',
f'i3*{sh}+@e0(i0,i1,i2,i3,i4)/{kw}%{kh}',
f'i4*{sw}+@e0(i0,i1,i2,i3,i4)%{kw}'],
f'@e0(i0,i1,i2,i3,i4)/(xshape4*xshape3)',
f'@e0(i0,i1,i2,i3,i4)/xshape4%xshape3',
f'@e0(i0,i1,i2,i3,i4)%xshape4'],
extras=[id],
)
return x

View File

@ -156,7 +156,7 @@ class TestArgPoolOp(unittest.TestCase):
0,0,0,0,
1,0,0,1]).reshape((1,1,4,4))
b, idx = pool(a)
assert (idx.data.reshape((4,)) == [0,1,2,3]).all()
assert (idx.data.reshape((4,)) == [0,3,12,15]).all()
def test_unpool(self):
from jittor import nn
@ -168,6 +168,7 @@ class TestArgPoolOp(unittest.TestCase):
[13, 14, 15, 16,0],
[0, 0, 0, 0, 0]]]])
output, indices = pool(input)
assert (indices == jt.array([[6,8],[16,18]])).all()
out = unpool(output, indices, output_size=input.shape)
assert (out == jt.array([[[[ 0., 0., 0., 0., 0.],
[ 0., 6., 0., 8., 0.],