mirror of https://github.com/Jittor/Jittor
add max_pool with index and max_unpool2d
This commit is contained in:
parent
fba55c7e31
commit
bb2d187dd3
|
@ -13,6 +13,7 @@
|
|||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
static auto make_transpose = get_op_info("cutt_transpose")
|
||||
.get_constructor<VarPtr, Var*, NanoVector>();
|
||||
|
||||
|
@ -55,9 +56,19 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
return make_transpose(dout, reverse);
|
||||
}
|
||||
|
||||
|
||||
// Adds this op's fixed contribution to the JIT key under construction.
// NOTE(review): "[T:1]" is presumably a key/value tag understood by the JK
// machinery -- confirm its exact meaning against the JK definition.
void CuttTransposeOp::jit_prepare(JK& jk) {
|
||||
// Stream the constant fragment "[T:1]" into the key; nothing else is done here.
|
||||
jk << _CS("[T:1]");
|
||||
}
|
||||
|
||||
unordered_map<string, unsigned int> cutt_plan_cache;
|
||||
|
||||
void CuttTransposeOp::run() {
|
||||
#else // JIT
|
||||
|
||||
extern unordered_map<string, unsigned int> cutt_plan_cache;
|
||||
|
||||
void CuttTransposeOp::jit_run() {
|
||||
auto* __restrict__ xp = x->mem_ptr;
|
||||
auto* __restrict__ yp = y->mem_ptr;
|
||||
StackVector<int> x_shape;
|
||||
|
@ -99,5 +110,6 @@ void CuttTransposeOp::run() {
|
|||
cuttExecute(plan, xp, yp);
|
||||
}
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -19,7 +19,7 @@ struct CuttTransposeOp : Op {
|
|||
const char* name() const override { return "cutt_transpose"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
void run() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.60'
|
||||
__version__ = '1.2.2.61'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -18,7 +18,8 @@ import math
|
|||
class Pool(Module):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
|
||||
assert dilation == None
|
||||
assert return_indices == None
|
||||
assert return_indices == None or op == "maximum"
|
||||
self.return_indices = return_indices
|
||||
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
|
||||
self.op = op
|
||||
stride = stride if stride else kernel_size
|
||||
|
@ -48,20 +49,36 @@ class Pool(Module):
|
|||
count += "float32 rcount = 1.0f / count;"
|
||||
else:
|
||||
count = ""
|
||||
forward_body = f'''{{
|
||||
forward_body = f'''
|
||||
int k3 = i3*{self.stride[1]}-{self.padding[1]};
|
||||
int k2 = i2*{self.stride[0]}-{self.padding[0]};
|
||||
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
|
||||
{count}
|
||||
'''
|
||||
if not self.return_indices:
|
||||
forward_body += f'''
|
||||
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
|
||||
}}'''
|
||||
backward_body = f'''{{
|
||||
'''
|
||||
else:
|
||||
forward_body += f'''
|
||||
auto out_value = init_{self.op}(out_type);
|
||||
int out_index = -1;
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
if (out_value < @in0(i0, i1, p, q)) {{
|
||||
out_value = @in0(i0, i1, p, q);
|
||||
out_index = (p - k2) * {self.kernel_size[0]} + (q - k3);
|
||||
}}
|
||||
@out(i0, i1, i2, i3) = out_value;
|
||||
@out1(i0, i1, i2, i3) = out_index;
|
||||
'''
|
||||
backward_body = f'''
|
||||
int k3 = i3*{self.stride[1]}-{self.padding[1]};
|
||||
int k2 = i2*{self.stride[0]}-{self.padding[0]};
|
||||
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
|
||||
|
@ -79,8 +96,14 @@ class Pool(Module):
|
|||
bo=0;
|
||||
}}"""}
|
||||
}}
|
||||
}}'''
|
||||
out = jt.code([N,C,h,w], x.dtype, [x],
|
||||
'''
|
||||
if self.return_indices:
|
||||
return_shapes = [[N,C,h,w]] * 2
|
||||
return_dtypes = [x.dtype, 'uint8']
|
||||
else:
|
||||
return_shapes = [N,C,h,w]
|
||||
return_dtypes = x.dtype
|
||||
out = jt.code(return_shapes, return_dtypes, [x],
|
||||
cuda_header="""
|
||||
#include <ops/binary_op_defs.h>
|
||||
#include <misc/cuda_limits.h>
|
||||
|
@ -96,7 +119,7 @@ class Pool(Module):
|
|||
int i0 = blockIdx.z;
|
||||
for (int i3 = p3; i3 < out_shape3; i3 += s3)
|
||||
for (int i2 = p2; i2 < out_shape2; i2 += s2)
|
||||
{forward_body}
|
||||
{{ {forward_body} }}
|
||||
}}
|
||||
int tx = min(1024, out_shape3);
|
||||
int ty = min(1024 / tx, out_shape2);
|
||||
|
@ -118,7 +141,7 @@ class Pool(Module):
|
|||
int i0 = blockIdx.z;
|
||||
for (int i3 = p3; i3 < pout_shape3; i3 += s3)
|
||||
for (int i2 = p2; i2 < pout_shape2; i2 += s2)
|
||||
{backward_body}
|
||||
{{ {backward_body} }}
|
||||
}}
|
||||
cudaMemsetAsync(out_p, 0, out->size);
|
||||
int tx = min(1024, pout_shape3);
|
||||
|
@ -137,7 +160,7 @@ class Pool(Module):
|
|||
for (int i1=0; i1<out_shape1; i1++)
|
||||
for (int i2=0; i2<out_shape2; i2++)
|
||||
for (int i3=0; i3<out_shape3; i3++)
|
||||
{forward_body}
|
||||
{{ {forward_body} }}
|
||||
''',
|
||||
cpu_grad_src = [f'''
|
||||
using namespace std;
|
||||
|
@ -148,7 +171,7 @@ class Pool(Module):
|
|||
for (int i1=0; i1<pout_shape1; i1++)
|
||||
for (int i2=0; i2<pout_shape2; i2++)
|
||||
for (int i3=0; i3<pout_shape3; i3++)
|
||||
{backward_body}
|
||||
{{ {backward_body} }}
|
||||
'''])
|
||||
return out
|
||||
else:
|
||||
|
@ -213,4 +236,53 @@ class MaxPool2d(Module):
|
|||
return self._layer(x)
|
||||
|
||||
def max_pool2d(x, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
    ''' Functional form of MaxPool2d.

    Builds a MaxPool2d layer from the given arguments (forwarded
    positionally, in the order shown in the signature) and applies it
    once to ``x``.

    :param x: input passed to the freshly built layer.
    :return: whatever the MaxPool2d layer returns for ``x``.
    '''
    layer = MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
    return layer(x)
|
||||
|
||||
class MaxUnpool2d(Module):
    def __init__(self, kernel_size, stride=None):
        ''' MaxUnpool2d is the inverse of MaxPool2d with indices.
        It takes the output index of MaxPool2d as input.
        The element will be zero if it is not the max pooled value.

        Only non-overlapping windows are supported: ``stride`` must be
        equal to ``kernel_size`` (it defaults to ``kernel_size``).

        :param kernel_size: int or pair ``(kh, kw)``; an int is expanded to a square window.
        :param stride: int, pair or None; None means "same as kernel_size".

        Example::

            >>> import jittor as jt
            >>> from jittor import nn

            >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True)
            >>> unpool = nn.MaxUnpool2d(2, stride=2)
            >>> input = jt.array([[[[ 1.,  2,  3,  4,0],
                                    [ 5,  6,  7,  8,0],
                                    [ 9, 10, 11, 12,0],
                                    [13, 14, 15, 16,0],
                                    [0,  0,  0,  0, 0]]]])
            >>> output, indices = pool(input)
            >>> unpool(output, indices, output_size=input.shape)
            jt.array([[[[   0.,  0.,   0.,   0.,   0.],
                        [   0.,  6.,   0.,   8.,   0.],
                        [   0.,  0.,   0.,   0.,   0.],
                        [   0., 14.,   0.,  16.,   0.],
                        [   0.,  0.,   0.,   0.,   0.]]]])
        '''
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if stride is None:
            stride = kernel_size
        elif isinstance(stride, int):
            stride = (stride, stride)
        # Compare as tuples so that a list and a tuple describing the same
        # window (e.g. [2, 2] vs (2, 2)) are not rejected as "different".
        kernel_size, stride = tuple(kernel_size), tuple(stride)
        assert stride == kernel_size, "Different stride and kernel is not supported yet."
        self.kernel_size = kernel_size

    def execute(self, x, id, output_size=None):
        ''' Scatter pooled values back to the positions they were taken from.

        :param x: pooled values; its shape is unpacked as ``(b, c, ph, pw)``.
        :param id: index var produced together with ``x`` by the pooling
            layer. Each entry is compared against
            ``row_in_window * kw + col_in_window``.
        :param output_size: optional target shape; only its last two
            entries ``(h, w)`` are used. When omitted (or empty) the
            result is ``(ph * kh, pw * kw)``.
        :return: var of shape ``[b, c, h, w]`` holding the pooled value at
            the recorded position of every window and 0 elsewhere.
        '''
        b, c, ph, pw = x.shape
        kh, kw = self.kernel_size
        if output_size:
            h, w = output_size[-2:]
        else:
            h, w = ph * kh, pw * kw
        # Output pixel (i2, i3) reads the pooled value of the window that
        # contains it; the extra condition zeroes it unless its in-window
        # offset matches the index recorded for that window.
        x = x.reindex(shape=[b, c, h, w],
                      indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'],
                      extras=[id],
                      overflow_conditions=[
                          f'((i2%{kh})*{kw}+i3%{kw}) != @e0(i0,i1,i2/{kh},i3/{kw})'],
                      overflow_value=0)
        return x
|
|
@ -143,6 +143,39 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
def test_index_pool(self):
    # Smoke test: pooling with return_indices=True must yield a
    # (values, indices) pair whose index var can be evaluated.
    indexed_pool = jt.nn.Pool(2, return_indices=True)
    data = jt.randn([10,3,100,100])
    _, indices = indexed_pool(data)
    indices.sync()
|
||||
|
||||
def test_index_pool2(self):
    # One 1 in each outer corner of a 4x4 map: every 2x2 window then has
    # its maximum at a different in-window offset (0, 1, 2, 3).
    corners = [1,0,0,1,
               0,0,0,0,
               0,0,0,0,
               1,0,0,1]
    indexed_pool = jt.nn.Pool(2, return_indices=True)
    feature_map = jt.array(corners).reshape((1,1,4,4))
    _, indices = indexed_pool(feature_map)
    flat_indices = indices.data.reshape((4,))
    assert (flat_indices == [0,1,2,3]).all()
|
||||
|
||||
def test_unpool(self):
    # Round trip: MaxPool2d with indices followed by MaxUnpool2d must put
    # each window maximum back in place and leave zeros everywhere else.
    from jittor import nn
    source = jt.array([[[[ 1.,  2,  3,  4,0],
                         [ 5,  6,  7,  8,0],
                         [ 9, 10, 11, 12,0],
                         [13, 14, 15, 16,0],
                         [0,  0,  0,  0, 0]]]])
    expected = jt.array([[[[   0.,  0.,   0.,   0.,   0.],
                           [   0.,  6.,   0.,   8.,   0.],
                           [   0.,  0.,   0.,   0.,   0.],
                           [   0., 14.,   0.,  16.,   0.],
                           [   0.,  0.,   0.,   0.,   0.]]]])
    pool = nn.MaxPool2d(2, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(2, stride=2)
    pooled, positions = pool(source)
    restored = unpool(pooled, positions, output_size=source.shape)
    assert (restored == expected).all()
|
||||
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda_avg_pool(self):
|
||||
|
|
Loading…
Reference in New Issue