mirror of https://github.com/Jittor/Jittor
polish pool interface
This commit is contained in:
parent
317e07907f
commit
86a3feeaab
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.87'
|
||||
__version__ = '1.2.3.88'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -73,7 +73,7 @@ class Pool(Module):
|
|||
for (int q = k3; q < k3_; ++q)
|
||||
if (out_value < @in0(i0, i1, p, q)) {{
|
||||
out_value = @in0(i0, i1, p, q);
|
||||
out_index = (p - k2) * {self.kernel_size[1]} + (q - k3);
|
||||
out_index = p * in0_shape3 + q;
|
||||
}}
|
||||
@out(i0, i1, i2, i3) = out_value;
|
||||
@out1(i0, i1, i2, i3) = out_index;
|
||||
|
@ -99,10 +99,7 @@ class Pool(Module):
|
|||
'''
|
||||
if self.return_indices:
|
||||
return_shapes = [[N,C,h,w]] * 2
|
||||
return_dtypes = [x.dtype, 'uint8']
|
||||
ks = self.kernel_size[0] * self.kernel_size[1]
|
||||
if ks > 65536: return_dtypes[1] = 'uint32'
|
||||
elif ks > 256: return_dtypes[1] = 'uint16'
|
||||
return_dtypes = [x.dtype, 'int32']
|
||||
else:
|
||||
return_shapes = [N,C,h,w]
|
||||
return_dtypes = x.dtype
|
||||
|
@ -260,7 +257,7 @@ class Pool3d(Module):
|
|||
for (int r = k4; q < k4_; ++r)
|
||||
if (out_value < @in0(i0, i1, p, q, r)) {{
|
||||
out_value = @in0(i0, i1, p, q, r);
|
||||
out_index = (p - k2) * {self.kernel_size[1]} * {self.kernel_size[2]} + (q - k3) * {self.kernel_size[2]} + (r - k4);
|
||||
out_index = p * in0_shape3 * in0_shape4 + q * in0_shape4 + r;
|
||||
}}
|
||||
@out(i0, i1, i2, i3, i4) = out_value;
|
||||
@out1(i0, i1, i2, i3, i4) = out_index;
|
||||
|
@ -290,10 +287,7 @@ class Pool3d(Module):
|
|||
'''
|
||||
if self.return_indices:
|
||||
return_shapes = [[N,C,d,h,w]] * 2
|
||||
return_dtypes = [x.dtype, 'uint8']
|
||||
ks = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
|
||||
if ks > 65536: return_dtypes[1] = 'uint32'
|
||||
elif ks > 256: return_dtypes[1] = 'uint16'
|
||||
return_dtypes = [x.dtype, 'int32']
|
||||
else:
|
||||
return_shapes = [N,C,d,h,w]
|
||||
return_dtypes = x.dtype
|
||||
|
@ -420,8 +414,9 @@ class AdaptiveAvgPool2d(Module):
|
|||
return xx.reduce("mean", [4,5])
|
||||
|
||||
class AdaptiveMaxPool2d(Module):
|
||||
def __init__(self, output_size):
|
||||
def __init__(self, output_size, return_indices=False):
|
||||
self.output_size = output_size
|
||||
self.return_indices = return_indices
|
||||
|
||||
def execute(self, x):
|
||||
if isinstance(self.output_size, int):
|
||||
|
@ -439,6 +434,10 @@ class AdaptiveMaxPool2d(Module):
|
|||
self.sw = math.floor(W / ow)
|
||||
self.ksh = H - (oh - 1) * self.sh
|
||||
self.ksw = W - (ow - 1) * self.sw
|
||||
if self.return_indices:
|
||||
return MaxPool2d(
|
||||
kernel_size=(self.ksh, self.ksw),
|
||||
stride=(self.sh, self.sw), return_indices=True)(x)
|
||||
h = (H-self.ksh)//self.sh+1
|
||||
w = (W-self.ksw)//self.sw+1
|
||||
xx = x.reindex([N,C,h,w,self.ksh,self.ksw], [
|
||||
|
@ -478,12 +477,13 @@ class AdaptiveAvgPool3d(Module):
|
|||
return xx.reduce("mean", [5,6,7])
|
||||
|
||||
class AdaptiveMaxPool3d(Module):
|
||||
def __init__(self, output_size):
|
||||
def __init__(self, output_size, return_indices=False):
|
||||
self.output_size = _triple(output_size)
|
||||
self.return_indices = return_indices
|
||||
|
||||
def execute(self, x):
|
||||
od, oh, ow = self.output_size
|
||||
if od == 1 and oh == 1 and ow == 1:
|
||||
if od == 1 and oh == 1 and ow == 1 and not self.return_indices:
|
||||
return x.reduce("maximum", [2,3,4], keepdims=True)
|
||||
N,C,D,H,W = x.shape
|
||||
self.sd = math.floor(D / od)
|
||||
|
@ -492,6 +492,10 @@ class AdaptiveMaxPool3d(Module):
|
|||
self.ksd = D - (od - 1) * self.sd
|
||||
self.ksh = H - (oh - 1) * self.sh
|
||||
self.ksw = W - (ow - 1) * self.sw
|
||||
if self.return_indices:
|
||||
return MaxPool3d(
|
||||
kernel_size=(self.ksd, self.ksh, self.ksw),
|
||||
stride=(self.sd, self.sh, self.sw), return_indices=True)(x)
|
||||
d = (D-self.ksd)//self.sd+1
|
||||
h = (H-self.ksh)//self.sh+1
|
||||
w = (W-self.ksw)//self.sw+1
|
||||
|
@ -597,15 +601,15 @@ class MaxUnpool2d(Module):
|
|||
indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'],
|
||||
extras=[id],
|
||||
overflow_conditions=[
|
||||
f'((i2%{kh})*{kw}+i3%{kw}) != @e0(i0,i1,i2/{kh},i3/{kw})'],
|
||||
f'(i2*yshape3+i3) != @e0(i0,i1,i2/{kh},i3/{kw})'],
|
||||
overflow_value=0)
|
||||
else:
|
||||
x = x.reindex_reduce(
|
||||
op="add",
|
||||
shape=[b, c, h, w],
|
||||
indexes=['i0', 'i1',
|
||||
f'i2*{sh}+@e0(i0,i1,i2,i3)/{kw}',
|
||||
f'i3*{sw}+@e0(i0,i1,i2,i3)%{kw}'],
|
||||
f'@e0(i0,i1,i2,i3)/xshape3',
|
||||
f'@e0(i0,i1,i2,i3)%xshape3'],
|
||||
extras=[id],
|
||||
)
|
||||
return x
|
||||
|
@ -635,16 +639,16 @@ class MaxUnpool3d(Module):
|
|||
indexes=['i0', 'i1', f'i2/{kd}', f'i3/{kh}', f'i4/{kw}'],
|
||||
extras=[id],
|
||||
overflow_conditions=[
|
||||
f'((i2%{kd})*{kh*kw}+(i3%{kh})*{kw}+i4%{kw}) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'],
|
||||
f'(i2*yshape3*yshape4+i3*yshape4+i4) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'],
|
||||
overflow_value=0)
|
||||
else:
|
||||
x = x.reindex_reduce(
|
||||
op="add",
|
||||
shape=[b, c, d, h, w],
|
||||
indexes=['i0', 'i1',
|
||||
f'i2*{sd}+@e0(i0,i1,i2,i3,i4)/{kh*kw}',
|
||||
f'i3*{sh}+@e0(i0,i1,i2,i3,i4)/{kw}%{kh}',
|
||||
f'i4*{sw}+@e0(i0,i1,i2,i3,i4)%{kw}'],
|
||||
f'@e0(i0,i1,i2,i3,i4)/(xshape4*xshape3)',
|
||||
f'@e0(i0,i1,i2,i3,i4)/xshape4%xshape3',
|
||||
f'@e0(i0,i1,i2,i3,i4)%xshape4'],
|
||||
extras=[id],
|
||||
)
|
||||
return x
|
||||
|
|
|
@ -156,7 +156,7 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
0,0,0,0,
|
||||
1,0,0,1]).reshape((1,1,4,4))
|
||||
b, idx = pool(a)
|
||||
assert (idx.data.reshape((4,)) == [0,1,2,3]).all()
|
||||
assert (idx.data.reshape((4,)) == [0,3,12,15]).all()
|
||||
|
||||
def test_unpool(self):
|
||||
from jittor import nn
|
||||
|
@ -168,6 +168,7 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
[13, 14, 15, 16,0],
|
||||
[0, 0, 0, 0, 0]]]])
|
||||
output, indices = pool(input)
|
||||
assert (indices == jt.array([[6,8],[16,18]])).all()
|
||||
out = unpool(output, indices, output_size=input.shape)
|
||||
assert (out == jt.array([[[[ 0., 0., 0., 0., 0.],
|
||||
[ 0., 6., 0., 8., 0.],
|
||||
|
|
Loading…
Reference in New Issue