mirror of https://github.com/Jittor/Jittor
polish pool interface
This commit is contained in:
parent
317e07907f
commit
86a3feeaab
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.2.3.87'
|
__version__ = '1.2.3.88'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -73,7 +73,7 @@ class Pool(Module):
|
||||||
for (int q = k3; q < k3_; ++q)
|
for (int q = k3; q < k3_; ++q)
|
||||||
if (out_value < @in0(i0, i1, p, q)) {{
|
if (out_value < @in0(i0, i1, p, q)) {{
|
||||||
out_value = @in0(i0, i1, p, q);
|
out_value = @in0(i0, i1, p, q);
|
||||||
out_index = (p - k2) * {self.kernel_size[1]} + (q - k3);
|
out_index = p * in0_shape3 + q;
|
||||||
}}
|
}}
|
||||||
@out(i0, i1, i2, i3) = out_value;
|
@out(i0, i1, i2, i3) = out_value;
|
||||||
@out1(i0, i1, i2, i3) = out_index;
|
@out1(i0, i1, i2, i3) = out_index;
|
||||||
|
@ -99,10 +99,7 @@ class Pool(Module):
|
||||||
'''
|
'''
|
||||||
if self.return_indices:
|
if self.return_indices:
|
||||||
return_shapes = [[N,C,h,w]] * 2
|
return_shapes = [[N,C,h,w]] * 2
|
||||||
return_dtypes = [x.dtype, 'uint8']
|
return_dtypes = [x.dtype, 'int32']
|
||||||
ks = self.kernel_size[0] * self.kernel_size[1]
|
|
||||||
if ks > 65536: return_dtypes[1] = 'uint32'
|
|
||||||
elif ks > 256: return_dtypes[1] = 'uint16'
|
|
||||||
else:
|
else:
|
||||||
return_shapes = [N,C,h,w]
|
return_shapes = [N,C,h,w]
|
||||||
return_dtypes = x.dtype
|
return_dtypes = x.dtype
|
||||||
|
@ -260,7 +257,7 @@ class Pool3d(Module):
|
||||||
for (int r = k4; q < k4_; ++r)
|
for (int r = k4; q < k4_; ++r)
|
||||||
if (out_value < @in0(i0, i1, p, q, r)) {{
|
if (out_value < @in0(i0, i1, p, q, r)) {{
|
||||||
out_value = @in0(i0, i1, p, q, r);
|
out_value = @in0(i0, i1, p, q, r);
|
||||||
out_index = (p - k2) * {self.kernel_size[1]} * {self.kernel_size[2]} + (q - k3) * {self.kernel_size[2]} + (r - k4);
|
out_index = p * in0_shape3 * in0_shape4 + q * in0_shape4 + r;
|
||||||
}}
|
}}
|
||||||
@out(i0, i1, i2, i3, i4) = out_value;
|
@out(i0, i1, i2, i3, i4) = out_value;
|
||||||
@out1(i0, i1, i2, i3, i4) = out_index;
|
@out1(i0, i1, i2, i3, i4) = out_index;
|
||||||
|
@ -290,10 +287,7 @@ class Pool3d(Module):
|
||||||
'''
|
'''
|
||||||
if self.return_indices:
|
if self.return_indices:
|
||||||
return_shapes = [[N,C,d,h,w]] * 2
|
return_shapes = [[N,C,d,h,w]] * 2
|
||||||
return_dtypes = [x.dtype, 'uint8']
|
return_dtypes = [x.dtype, 'int32']
|
||||||
ks = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
|
|
||||||
if ks > 65536: return_dtypes[1] = 'uint32'
|
|
||||||
elif ks > 256: return_dtypes[1] = 'uint16'
|
|
||||||
else:
|
else:
|
||||||
return_shapes = [N,C,d,h,w]
|
return_shapes = [N,C,d,h,w]
|
||||||
return_dtypes = x.dtype
|
return_dtypes = x.dtype
|
||||||
|
@ -420,8 +414,9 @@ class AdaptiveAvgPool2d(Module):
|
||||||
return xx.reduce("mean", [4,5])
|
return xx.reduce("mean", [4,5])
|
||||||
|
|
||||||
class AdaptiveMaxPool2d(Module):
|
class AdaptiveMaxPool2d(Module):
|
||||||
def __init__(self, output_size):
|
def __init__(self, output_size, return_indices=False):
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
|
self.return_indices = return_indices
|
||||||
|
|
||||||
def execute(self, x):
|
def execute(self, x):
|
||||||
if isinstance(self.output_size, int):
|
if isinstance(self.output_size, int):
|
||||||
|
@ -439,6 +434,10 @@ class AdaptiveMaxPool2d(Module):
|
||||||
self.sw = math.floor(W / ow)
|
self.sw = math.floor(W / ow)
|
||||||
self.ksh = H - (oh - 1) * self.sh
|
self.ksh = H - (oh - 1) * self.sh
|
||||||
self.ksw = W - (ow - 1) * self.sw
|
self.ksw = W - (ow - 1) * self.sw
|
||||||
|
if self.return_indices:
|
||||||
|
return MaxPool2d(
|
||||||
|
kernel_size=(self.ksh, self.ksw),
|
||||||
|
stride=(self.sh, self.sw), return_indices=True)(x)
|
||||||
h = (H-self.ksh)//self.sh+1
|
h = (H-self.ksh)//self.sh+1
|
||||||
w = (W-self.ksw)//self.sw+1
|
w = (W-self.ksw)//self.sw+1
|
||||||
xx = x.reindex([N,C,h,w,self.ksh,self.ksw], [
|
xx = x.reindex([N,C,h,w,self.ksh,self.ksw], [
|
||||||
|
@ -478,12 +477,13 @@ class AdaptiveAvgPool3d(Module):
|
||||||
return xx.reduce("mean", [5,6,7])
|
return xx.reduce("mean", [5,6,7])
|
||||||
|
|
||||||
class AdaptiveMaxPool3d(Module):
|
class AdaptiveMaxPool3d(Module):
|
||||||
def __init__(self, output_size):
|
def __init__(self, output_size, return_indices=False):
|
||||||
self.output_size = _triple(output_size)
|
self.output_size = _triple(output_size)
|
||||||
|
self.return_indices = return_indices
|
||||||
|
|
||||||
def execute(self, x):
|
def execute(self, x):
|
||||||
od, oh, ow = self.output_size
|
od, oh, ow = self.output_size
|
||||||
if od == 1 and oh == 1 and ow == 1:
|
if od == 1 and oh == 1 and ow == 1 and not self.return_indices:
|
||||||
return x.reduce("maximum", [2,3,4], keepdims=True)
|
return x.reduce("maximum", [2,3,4], keepdims=True)
|
||||||
N,C,D,H,W = x.shape
|
N,C,D,H,W = x.shape
|
||||||
self.sd = math.floor(D / od)
|
self.sd = math.floor(D / od)
|
||||||
|
@ -492,6 +492,10 @@ class AdaptiveMaxPool3d(Module):
|
||||||
self.ksd = D - (od - 1) * self.sd
|
self.ksd = D - (od - 1) * self.sd
|
||||||
self.ksh = H - (oh - 1) * self.sh
|
self.ksh = H - (oh - 1) * self.sh
|
||||||
self.ksw = W - (ow - 1) * self.sw
|
self.ksw = W - (ow - 1) * self.sw
|
||||||
|
if self.return_indices:
|
||||||
|
return MaxPool3d(
|
||||||
|
kernel_size=(self.ksd, self.ksh, self.ksw),
|
||||||
|
stride=(self.sd, self.sh, self.sw), return_indices=True)(x)
|
||||||
d = (D-self.ksd)//self.sd+1
|
d = (D-self.ksd)//self.sd+1
|
||||||
h = (H-self.ksh)//self.sh+1
|
h = (H-self.ksh)//self.sh+1
|
||||||
w = (W-self.ksw)//self.sw+1
|
w = (W-self.ksw)//self.sw+1
|
||||||
|
@ -597,15 +601,15 @@ class MaxUnpool2d(Module):
|
||||||
indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'],
|
indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'],
|
||||||
extras=[id],
|
extras=[id],
|
||||||
overflow_conditions=[
|
overflow_conditions=[
|
||||||
f'((i2%{kh})*{kw}+i3%{kw}) != @e0(i0,i1,i2/{kh},i3/{kw})'],
|
f'(i2*yshape3+i3) != @e0(i0,i1,i2/{kh},i3/{kw})'],
|
||||||
overflow_value=0)
|
overflow_value=0)
|
||||||
else:
|
else:
|
||||||
x = x.reindex_reduce(
|
x = x.reindex_reduce(
|
||||||
op="add",
|
op="add",
|
||||||
shape=[b, c, h, w],
|
shape=[b, c, h, w],
|
||||||
indexes=['i0', 'i1',
|
indexes=['i0', 'i1',
|
||||||
f'i2*{sh}+@e0(i0,i1,i2,i3)/{kw}',
|
f'@e0(i0,i1,i2,i3)/xshape3',
|
||||||
f'i3*{sw}+@e0(i0,i1,i2,i3)%{kw}'],
|
f'@e0(i0,i1,i2,i3)%xshape3'],
|
||||||
extras=[id],
|
extras=[id],
|
||||||
)
|
)
|
||||||
return x
|
return x
|
||||||
|
@ -635,16 +639,16 @@ class MaxUnpool3d(Module):
|
||||||
indexes=['i0', 'i1', f'i2/{kd}', f'i3/{kh}', f'i4/{kw}'],
|
indexes=['i0', 'i1', f'i2/{kd}', f'i3/{kh}', f'i4/{kw}'],
|
||||||
extras=[id],
|
extras=[id],
|
||||||
overflow_conditions=[
|
overflow_conditions=[
|
||||||
f'((i2%{kd})*{kh*kw}+(i3%{kh})*{kw}+i4%{kw}) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'],
|
f'(i2*yshape3*yshape4+i3*yshape4+i4) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'],
|
||||||
overflow_value=0)
|
overflow_value=0)
|
||||||
else:
|
else:
|
||||||
x = x.reindex_reduce(
|
x = x.reindex_reduce(
|
||||||
op="add",
|
op="add",
|
||||||
shape=[b, c, d, h, w],
|
shape=[b, c, d, h, w],
|
||||||
indexes=['i0', 'i1',
|
indexes=['i0', 'i1',
|
||||||
f'i2*{sd}+@e0(i0,i1,i2,i3,i4)/{kh*kw}',
|
f'@e0(i0,i1,i2,i3,i4)/(xshape4*xshape3)',
|
||||||
f'i3*{sh}+@e0(i0,i1,i2,i3,i4)/{kw}%{kh}',
|
f'@e0(i0,i1,i2,i3,i4)/xshape4%xshape3',
|
||||||
f'i4*{sw}+@e0(i0,i1,i2,i3,i4)%{kw}'],
|
f'@e0(i0,i1,i2,i3,i4)%xshape4'],
|
||||||
extras=[id],
|
extras=[id],
|
||||||
)
|
)
|
||||||
return x
|
return x
|
||||||
|
|
|
@ -156,7 +156,7 @@ class TestArgPoolOp(unittest.TestCase):
|
||||||
0,0,0,0,
|
0,0,0,0,
|
||||||
1,0,0,1]).reshape((1,1,4,4))
|
1,0,0,1]).reshape((1,1,4,4))
|
||||||
b, idx = pool(a)
|
b, idx = pool(a)
|
||||||
assert (idx.data.reshape((4,)) == [0,1,2,3]).all()
|
assert (idx.data.reshape((4,)) == [0,3,12,15]).all()
|
||||||
|
|
||||||
def test_unpool(self):
|
def test_unpool(self):
|
||||||
from jittor import nn
|
from jittor import nn
|
||||||
|
@ -168,6 +168,7 @@ class TestArgPoolOp(unittest.TestCase):
|
||||||
[13, 14, 15, 16,0],
|
[13, 14, 15, 16,0],
|
||||||
[0, 0, 0, 0, 0]]]])
|
[0, 0, 0, 0, 0]]]])
|
||||||
output, indices = pool(input)
|
output, indices = pool(input)
|
||||||
|
assert (indices == jt.array([[6,8],[16,18]])).all()
|
||||||
out = unpool(output, indices, output_size=input.shape)
|
out = unpool(output, indices, output_size=input.shape)
|
||||||
assert (out == jt.array([[[[ 0., 0., 0., 0., 0.],
|
assert (out == jt.array([[[[ 0., 0., 0., 0., 0.],
|
||||||
[ 0., 6., 0., 8., 0.],
|
[ 0., 6., 0., 8., 0.],
|
||||||
|
|
Loading…
Reference in New Issue