diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index e8b7df91..752857fe 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.3.87' +__version__ = '1.2.3.88' from jittor_utils import lock with lock.lock_scope(): ori_int = int diff --git a/python/jittor/pool.py b/python/jittor/pool.py index a34c564b..150aff68 100644 --- a/python/jittor/pool.py +++ b/python/jittor/pool.py @@ -73,7 +73,7 @@ class Pool(Module): for (int q = k3; q < k3_; ++q) if (out_value < @in0(i0, i1, p, q)) {{ out_value = @in0(i0, i1, p, q); - out_index = (p - k2) * {self.kernel_size[1]} + (q - k3); + out_index = p * in0_shape3 + q; }} @out(i0, i1, i2, i3) = out_value; @out1(i0, i1, i2, i3) = out_index; @@ -99,10 +99,7 @@ class Pool(Module): ''' if self.return_indices: return_shapes = [[N,C,h,w]] * 2 - return_dtypes = [x.dtype, 'uint8'] - ks = self.kernel_size[0] * self.kernel_size[1] - if ks > 65536: return_dtypes[1] = 'uint32' - elif ks > 256: return_dtypes[1] = 'uint16' + return_dtypes = [x.dtype, 'int32'] else: return_shapes = [N,C,h,w] return_dtypes = x.dtype @@ -260,7 +257,7 @@ class Pool3d(Module): for (int r = k4; q < k4_; ++r) if (out_value < @in0(i0, i1, p, q, r)) {{ out_value = @in0(i0, i1, p, q, r); - out_index = (p - k2) * {self.kernel_size[1]} * {self.kernel_size[2]} + (q - k3) * {self.kernel_size[2]} + (r - k4); + out_index = p * in0_shape3 * in0_shape4 + q * in0_shape4 + r; }} @out(i0, i1, i2, i3, i4) = out_value; @out1(i0, i1, i2, i3, i4) = out_index; @@ -290,10 +287,7 @@ class Pool3d(Module): ''' if self.return_indices: return_shapes = [[N,C,d,h,w]] * 2 - return_dtypes = [x.dtype, 'uint8'] - ks = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2] - if ks > 65536: return_dtypes[1] = 'uint32' - elif ks > 256: return_dtypes[1] = 'uint16' + return_dtypes = [x.dtype, 'int32'] else: return_shapes = [N,C,d,h,w] return_dtypes = x.dtype @@ -420,8 +414,9 @@ class AdaptiveAvgPool2d(Module): return xx.reduce("mean", [4,5]) class AdaptiveMaxPool2d(Module): - def __init__(self, output_size): + def __init__(self, output_size, return_indices=False): self.output_size = output_size + self.return_indices = return_indices def execute(self, x): if isinstance(self.output_size, int): @@ -439,6 +434,10 @@ class AdaptiveMaxPool2d(Module): self.sw = math.floor(W / ow) self.ksh = H - (oh - 1) * self.sh self.ksw = W - (ow - 1) * self.sw + if self.return_indices: + return MaxPool2d( + kernel_size=(self.ksh, self.ksw), + stride=(self.sh, self.sw), return_indices=True)(x) h = (H-self.ksh)//self.sh+1 w = (W-self.ksw)//self.sw+1 xx = x.reindex([N,C,h,w,self.ksh,self.ksw], [ @@ -478,12 +477,13 @@ class AdaptiveAvgPool3d(Module): return xx.reduce("mean", [5,6,7]) class AdaptiveMaxPool3d(Module): - def __init__(self, output_size): + def __init__(self, output_size, return_indices=False): self.output_size = _triple(output_size) + self.return_indices = return_indices def execute(self, x): od, oh, ow = self.output_size - if od == 1 and oh == 1 and ow == 1: + if od == 1 and oh == 1 and ow == 1 and not self.return_indices: return x.reduce("maximum", [2,3,4], keepdims=True) N,C,D,H,W = x.shape self.sd = math.floor(D / od) @@ -492,6 +492,10 @@ class AdaptiveMaxPool3d(Module): self.ksd = D - (od - 1) * self.sd self.ksh = H - (oh - 1) * self.sh self.ksw = W - (ow - 1) * self.sw + if self.return_indices: + return MaxPool3d( + kernel_size=(self.ksd, self.ksh, self.ksw), + stride=(self.sd, self.sh, self.sw), return_indices=True)(x) d = (D-self.ksd)//self.sd+1 h = (H-self.ksh)//self.sh+1 w = (W-self.ksw)//self.sw+1 @@ -597,15 +601,15 @@ class MaxUnpool2d(Module): indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'], extras=[id], overflow_conditions=[ - f'((i2%{kh})*{kw}+i3%{kw}) != @e0(i0,i1,i2/{kh},i3/{kw})'], + f'(i2*yshape3+i3) != @e0(i0,i1,i2/{kh},i3/{kw})'], overflow_value=0) else: x = x.reindex_reduce( op="add", shape=[b, c, h, w], indexes=['i0', 'i1', - f'i2*{sh}+@e0(i0,i1,i2,i3)/{kw}', - f'i3*{sw}+@e0(i0,i1,i2,i3)%{kw}'], + f'@e0(i0,i1,i2,i3)/xshape3', + f'@e0(i0,i1,i2,i3)%xshape3'], extras=[id], ) return x @@ -635,16 +639,16 @@ class MaxUnpool3d(Module): indexes=['i0', 'i1', f'i2/{kd}', f'i3/{kh}', f'i4/{kw}'], extras=[id], overflow_conditions=[ - f'((i2%{kd})*{kh*kw}+(i3%{kh})*{kw}+i4%{kw}) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'], + f'(i2*yshape3*yshape4+i3*yshape4+i4) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'], overflow_value=0) else: x = x.reindex_reduce( op="add", shape=[b, c, d, h, w], indexes=['i0', 'i1', - f'i2*{sd}+@e0(i0,i1,i2,i3,i4)/{kh*kw}', - f'i3*{sh}+@e0(i0,i1,i2,i3,i4)/{kw}%{kh}', - f'i4*{sw}+@e0(i0,i1,i2,i3,i4)%{kw}'], + f'@e0(i0,i1,i2,i3,i4)/(xshape4*xshape3)', + f'@e0(i0,i1,i2,i3,i4)/xshape4%xshape3', + f'@e0(i0,i1,i2,i3,i4)%xshape4'], extras=[id], ) return x diff --git a/python/jittor/test/test_arg_pool_op.py b/python/jittor/test/test_arg_pool_op.py index 1a07f7a8..4ccb692f 100644 --- a/python/jittor/test/test_arg_pool_op.py +++ b/python/jittor/test/test_arg_pool_op.py @@ -156,7 +156,7 @@ class TestArgPoolOp(unittest.TestCase): 0,0,0,0, 1,0,0,1]).reshape((1,1,4,4)) b, idx = pool(a) - assert (idx.data.reshape((4,)) == [0,1,2,3]).all() + assert (idx.data.reshape((4,)) == [0,3,12,15]).all() def test_unpool(self): from jittor import nn @@ -168,6 +168,7 @@ class TestArgPoolOp(unittest.TestCase): [13, 14, 15, 16,0], [0, 0, 0, 0, 0]]]]) output, indices = pool(input) + assert (indices == jt.array([[6,8],[16,18]])).all() out = unpool(output, indices, output_size=input.shape) assert (out == jt.array([[[[ 0., 0., 0., 0., 0.], [ 0., 6., 0., 8., 0.],