polish pool interface

This commit is contained in:
Dun Liang 2021-08-04 19:28:25 +08:00
parent 317e07907f
commit 86a3feeaab
3 changed files with 27 additions and 22 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.2.3.87' __version__ = '1.2.3.88'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -73,7 +73,7 @@ class Pool(Module):
for (int q = k3; q < k3_; ++q) for (int q = k3; q < k3_; ++q)
if (out_value < @in0(i0, i1, p, q)) {{ if (out_value < @in0(i0, i1, p, q)) {{
out_value = @in0(i0, i1, p, q); out_value = @in0(i0, i1, p, q);
out_index = (p - k2) * {self.kernel_size[1]} + (q - k3); out_index = p * in0_shape3 + q;
}} }}
@out(i0, i1, i2, i3) = out_value; @out(i0, i1, i2, i3) = out_value;
@out1(i0, i1, i2, i3) = out_index; @out1(i0, i1, i2, i3) = out_index;
@ -99,10 +99,7 @@ class Pool(Module):
''' '''
if self.return_indices: if self.return_indices:
return_shapes = [[N,C,h,w]] * 2 return_shapes = [[N,C,h,w]] * 2
return_dtypes = [x.dtype, 'uint8'] return_dtypes = [x.dtype, 'int32']
ks = self.kernel_size[0] * self.kernel_size[1]
if ks > 65536: return_dtypes[1] = 'uint32'
elif ks > 256: return_dtypes[1] = 'uint16'
else: else:
return_shapes = [N,C,h,w] return_shapes = [N,C,h,w]
return_dtypes = x.dtype return_dtypes = x.dtype
@ -260,7 +257,7 @@ class Pool3d(Module):
for (int r = k4; q < k4_; ++r) for (int r = k4; q < k4_; ++r)
if (out_value < @in0(i0, i1, p, q, r)) {{ if (out_value < @in0(i0, i1, p, q, r)) {{
out_value = @in0(i0, i1, p, q, r); out_value = @in0(i0, i1, p, q, r);
out_index = (p - k2) * {self.kernel_size[1]} * {self.kernel_size[2]} + (q - k3) * {self.kernel_size[2]} + (r - k4); out_index = p * in0_shape3 * in0_shape4 + q * in0_shape4 + r;
}} }}
@out(i0, i1, i2, i3, i4) = out_value; @out(i0, i1, i2, i3, i4) = out_value;
@out1(i0, i1, i2, i3, i4) = out_index; @out1(i0, i1, i2, i3, i4) = out_index;
@ -290,10 +287,7 @@ class Pool3d(Module):
''' '''
if self.return_indices: if self.return_indices:
return_shapes = [[N,C,d,h,w]] * 2 return_shapes = [[N,C,d,h,w]] * 2
return_dtypes = [x.dtype, 'uint8'] return_dtypes = [x.dtype, 'int32']
ks = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
if ks > 65536: return_dtypes[1] = 'uint32'
elif ks > 256: return_dtypes[1] = 'uint16'
else: else:
return_shapes = [N,C,d,h,w] return_shapes = [N,C,d,h,w]
return_dtypes = x.dtype return_dtypes = x.dtype
@ -420,8 +414,9 @@ class AdaptiveAvgPool2d(Module):
return xx.reduce("mean", [4,5]) return xx.reduce("mean", [4,5])
class AdaptiveMaxPool2d(Module): class AdaptiveMaxPool2d(Module):
def __init__(self, output_size): def __init__(self, output_size, return_indices=False):
self.output_size = output_size self.output_size = output_size
self.return_indices = return_indices
def execute(self, x): def execute(self, x):
if isinstance(self.output_size, int): if isinstance(self.output_size, int):
@ -439,6 +434,10 @@ class AdaptiveMaxPool2d(Module):
self.sw = math.floor(W / ow) self.sw = math.floor(W / ow)
self.ksh = H - (oh - 1) * self.sh self.ksh = H - (oh - 1) * self.sh
self.ksw = W - (ow - 1) * self.sw self.ksw = W - (ow - 1) * self.sw
if self.return_indices:
return MaxPool2d(
kernel_size=(self.ksh, self.ksw),
stride=(self.sh, self.sw), return_indices=True)(x)
h = (H-self.ksh)//self.sh+1 h = (H-self.ksh)//self.sh+1
w = (W-self.ksw)//self.sw+1 w = (W-self.ksw)//self.sw+1
xx = x.reindex([N,C,h,w,self.ksh,self.ksw], [ xx = x.reindex([N,C,h,w,self.ksh,self.ksw], [
@ -478,12 +477,13 @@ class AdaptiveAvgPool3d(Module):
return xx.reduce("mean", [5,6,7]) return xx.reduce("mean", [5,6,7])
class AdaptiveMaxPool3d(Module): class AdaptiveMaxPool3d(Module):
def __init__(self, output_size): def __init__(self, output_size, return_indices=False):
self.output_size = _triple(output_size) self.output_size = _triple(output_size)
self.return_indices = return_indices
def execute(self, x): def execute(self, x):
od, oh, ow = self.output_size od, oh, ow = self.output_size
if od == 1 and oh == 1 and ow == 1: if od == 1 and oh == 1 and ow == 1 and not self.return_indices:
return x.reduce("maximum", [2,3,4], keepdims=True) return x.reduce("maximum", [2,3,4], keepdims=True)
N,C,D,H,W = x.shape N,C,D,H,W = x.shape
self.sd = math.floor(D / od) self.sd = math.floor(D / od)
@ -492,6 +492,10 @@ class AdaptiveMaxPool3d(Module):
self.ksd = D - (od - 1) * self.sd self.ksd = D - (od - 1) * self.sd
self.ksh = H - (oh - 1) * self.sh self.ksh = H - (oh - 1) * self.sh
self.ksw = W - (ow - 1) * self.sw self.ksw = W - (ow - 1) * self.sw
if self.return_indices:
return MaxPool3d(
kernel_size=(self.ksd, self.ksh, self.ksw),
stride=(self.sd, self.sh, self.sw), return_indices=True)(x)
d = (D-self.ksd)//self.sd+1 d = (D-self.ksd)//self.sd+1
h = (H-self.ksh)//self.sh+1 h = (H-self.ksh)//self.sh+1
w = (W-self.ksw)//self.sw+1 w = (W-self.ksw)//self.sw+1
@ -597,15 +601,15 @@ class MaxUnpool2d(Module):
indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'], indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'],
extras=[id], extras=[id],
overflow_conditions=[ overflow_conditions=[
f'((i2%{kh})*{kw}+i3%{kw}) != @e0(i0,i1,i2/{kh},i3/{kw})'], f'(i2*yshape3+i3) != @e0(i0,i1,i2/{kh},i3/{kw})'],
overflow_value=0) overflow_value=0)
else: else:
x = x.reindex_reduce( x = x.reindex_reduce(
op="add", op="add",
shape=[b, c, h, w], shape=[b, c, h, w],
indexes=['i0', 'i1', indexes=['i0', 'i1',
f'i2*{sh}+@e0(i0,i1,i2,i3)/{kw}', f'@e0(i0,i1,i2,i3)/xshape3',
f'i3*{sw}+@e0(i0,i1,i2,i3)%{kw}'], f'@e0(i0,i1,i2,i3)%xshape3'],
extras=[id], extras=[id],
) )
return x return x
@ -635,16 +639,16 @@ class MaxUnpool3d(Module):
indexes=['i0', 'i1', f'i2/{kd}', f'i3/{kh}', f'i4/{kw}'], indexes=['i0', 'i1', f'i2/{kd}', f'i3/{kh}', f'i4/{kw}'],
extras=[id], extras=[id],
overflow_conditions=[ overflow_conditions=[
f'((i2%{kd})*{kh*kw}+(i3%{kh})*{kw}+i4%{kw}) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'], f'(i2*yshape3*yshape4+i3*yshape4+i4) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'],
overflow_value=0) overflow_value=0)
else: else:
x = x.reindex_reduce( x = x.reindex_reduce(
op="add", op="add",
shape=[b, c, d, h, w], shape=[b, c, d, h, w],
indexes=['i0', 'i1', indexes=['i0', 'i1',
f'i2*{sd}+@e0(i0,i1,i2,i3,i4)/{kh*kw}', f'@e0(i0,i1,i2,i3,i4)/(xshape4*xshape3)',
f'i3*{sh}+@e0(i0,i1,i2,i3,i4)/{kw}%{kh}', f'@e0(i0,i1,i2,i3,i4)/xshape4%xshape3',
f'i4*{sw}+@e0(i0,i1,i2,i3,i4)%{kw}'], f'@e0(i0,i1,i2,i3,i4)%xshape4'],
extras=[id], extras=[id],
) )
return x return x

View File

@ -156,7 +156,7 @@ class TestArgPoolOp(unittest.TestCase):
0,0,0,0, 0,0,0,0,
1,0,0,1]).reshape((1,1,4,4)) 1,0,0,1]).reshape((1,1,4,4))
b, idx = pool(a) b, idx = pool(a)
assert (idx.data.reshape((4,)) == [0,1,2,3]).all() assert (idx.data.reshape((4,)) == [0,3,12,15]).all()
def test_unpool(self): def test_unpool(self):
from jittor import nn from jittor import nn
@ -168,6 +168,7 @@ class TestArgPoolOp(unittest.TestCase):
[13, 14, 15, 16,0], [13, 14, 15, 16,0],
[0, 0, 0, 0, 0]]]]) [0, 0, 0, 0, 0]]]])
output, indices = pool(input) output, indices = pool(input)
assert (indices == jt.array([[6,8],[16,18]])).all()
out = unpool(output, indices, output_size=input.shape) out = unpool(output, indices, output_size=input.shape)
assert (out == jt.array([[[[ 0., 0., 0., 0., 0.], assert (out == jt.array([[[[ 0., 0., 0., 0., 0.],
[ 0., 6., 0., 8., 0.], [ 0., 6., 0., 8., 0.],