add max_pool with index and max_unpool2d

Dun Liang 2021-04-16 16:57:52 +08:00
parent fba55c7e31
commit bb2d187dd3
5 changed files with 132 additions and 15 deletions


@@ -13,6 +13,7 @@
namespace jittor {
#ifndef JIT
static auto make_transpose = get_op_info("cutt_transpose")
.get_constructor<VarPtr, Var*, NanoVector>();
@@ -55,9 +56,19 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return make_transpose(dout, reverse);
}
void CuttTransposeOp::jit_prepare(JK& jk) {
// do nothing
jk << _CS("[T:1]");
}
unordered_map<string, unsigned int> cutt_plan_cache;
void CuttTransposeOp::run() {
#else // JIT
extern unordered_map<string, unsigned int> cutt_plan_cache;
void CuttTransposeOp::jit_run() {
auto* __restrict__ xp = x->mem_ptr;
auto* __restrict__ yp = y->mem_ptr;
StackVector<int> x_shape;
@@ -99,5 +110,6 @@ void CuttTransposeOp::run() {
cuttExecute(plan, xp, yp);
}
}
#endif // JIT
} // jittor
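
With this change cutt_transpose participates in the JIT machinery like other ops: jit_prepare now emits a dummy [T:1] key instead of nothing, and the cutt plan cache is shared between the JIT and non-JIT translation units via extern. A minimal sketch of exercising this path from Python, assuming a CUDA build with the cutt extension enabled:

import jittor as jt
jt.flags.use_cuda = 1              # the cutt transpose path is CUDA-only

x = jt.random((2, 3, 4))
y = jt.transpose(x, (2, 0, 1))     # axis permutation; may lower to cutt_transpose
print(y.shape)                     # permuted shape [4,2,3,]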


@@ -19,7 +19,7 @@ struct CuttTransposeOp : Op {
const char* name() const override { return "cutt_transpose"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
void run() override;
DECLARE_jit_run;
};
} // jittor


@@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.60'
__version__ = '1.2.2.61'
from . import lock
with lock.lock_scope():
ori_int = int


@@ -18,7 +18,8 @@ import math
class Pool(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
assert dilation == None
assert return_indices == None
assert return_indices == None or op == "maximum"
self.return_indices = return_indices
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
self.op = op
stride = stride if stride else kernel_size
@@ -48,20 +49,36 @@ class Pool(Module):
count += "float32 rcount = 1.0f / count;"
else:
count = ""
forward_body = f'''{{
forward_body = f'''
int k3 = i3*{self.stride[1]}-{self.padding[1]};
int k2 = i2*{self.stride[0]}-{self.padding[0]};
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
{count}
'''
if not self.return_indices:
forward_body += f'''
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
for (int p = k2; p < k2_; ++p)
for (int q = k3; q < k3_; ++q)
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
}}'''
backward_body = f'''{{
'''
else:
forward_body += f'''
auto out_value = init_{self.op}(out_type);
int out_index = -1;
for (int p = k2; p < k2_; ++p)
for (int q = k3; q < k3_; ++q)
if (out_value < @in0(i0, i1, p, q)) {{
out_value = @in0(i0, i1, p, q);
out_index = (p - k2) * {self.kernel_size[1]} + (q - k3); // row offset scaled by window width
}}
@out(i0, i1, i2, i3) = out_value;
@out1(i0, i1, i2, i3) = out_index;
'''
backward_body = f'''
int k3 = i3*{self.stride[1]}-{self.padding[1]};
int k2 = i2*{self.stride[0]}-{self.padding[0]};
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
@@ -79,8 +96,14 @@ class Pool(Module):
bo=0;
}}"""}
}}
}}'''
out = jt.code([N,C,h,w], x.dtype, [x],
'''
if self.return_indices:
return_shapes = [[N,C,h,w]] * 2
return_dtypes = [x.dtype, 'uint8']
else:
return_shapes = [N,C,h,w]
return_dtypes = x.dtype
out = jt.code(return_shapes, return_dtypes, [x],
cuda_header="""
#include <ops/binary_op_defs.h>
#include <misc/cuda_limits.h>
@@ -96,7 +119,7 @@ class Pool(Module):
int i0 = blockIdx.z;
for (int i3 = p3; i3 < out_shape3; i3 += s3)
for (int i2 = p2; i2 < out_shape2; i2 += s2)
{forward_body}
{{ {forward_body} }}
}}
int tx = min(1024, out_shape3);
int ty = min(1024 / tx, out_shape2);
@@ -118,7 +141,7 @@ class Pool(Module):
int i0 = blockIdx.z;
for (int i3 = p3; i3 < pout_shape3; i3 += s3)
for (int i2 = p2; i2 < pout_shape2; i2 += s2)
{backward_body}
{{ {backward_body} }}
}}
cudaMemsetAsync(out_p, 0, out->size);
int tx = min(1024, pout_shape3);
@@ -137,7 +160,7 @@ class Pool(Module):
for (int i1=0; i1<out_shape1; i1++)
for (int i2=0; i2<out_shape2; i2++)
for (int i3=0; i3<out_shape3; i3++)
{forward_body}
{{ {forward_body} }}
''',
cpu_grad_src = [f'''
using namespace std;
@@ -148,7 +171,7 @@ class Pool(Module):
for (int i1=0; i1<pout_shape1; i1++)
for (int i2=0; i2<pout_shape2; i2++)
for (int i3=0; i3<pout_shape3; i3++)
{backward_body}
{{ {backward_body} }}
'''])
return out
else:
@@ -213,4 +236,53 @@ class MaxPool2d(Module):
return self._layer(x)
def max_pool2d(x, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode)(x)
class MaxUnpool2d(Module):
def __init__(self, kernel_size, stride=None):
''' MaxUnpool2d is the inverse of MaxPool2d with return_indices=True.
It takes the output values and indices of MaxPool2d as input.
Elements that were not the maximum of their pooling window are set to zero.
Example::
>>> import jittor as jt
>>> from jittor import nn
>>> pool = nn.MaxPool2d(2, stride=2, return_indices=True)
>>> unpool = nn.MaxUnpool2d(2, stride=2)
>>> input = jt.array([[[[ 1., 2, 3, 4,0],
[ 5, 6, 7, 8,0],
[ 9, 10, 11, 12,0],
[13, 14, 15, 16,0],
[0, 0, 0, 0, 0]]]])
>>> output, indices = pool(input)
>>> unpool(output, indices, output_size=input.shape)
jt.array([[[[ 0., 0., 0., 0., 0.],
[ 0., 6., 0., 8., 0.],
[ 0., 0., 0., 0., 0.],
[ 0., 14., 0., 16., 0.],
[ 0., 0., 0., 0., 0.]]]])
'''
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride)
if stride is None: stride = kernel_size
assert stride == kernel_size, "A stride different from kernel_size is not supported yet."
self.kernel_size = kernel_size
def execute(self, x, indices, output_size=None):
b, c, ph, pw = x.shape
kh, kw = self.kernel_size
if output_size:
h, w = output_size[-2:]
else:
h, w = ph * kh, pw * kw
x = x.reindex(shape=[b, c, h, w],
indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'],
extras=[indices],
overflow_conditions=[
f'((i2%{kh})*{kw}+i3%{kw}) != @e0(i0,i1,i2/{kh},i3/{kw})'],
overflow_value=0)
return x
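
For reference, the index recorded by the forward kernel is relative to the (clamped) pooling window and flattened row-major by the window width, and MaxUnpool2d's overflow condition ((i2%kh)*kw+i3%kw) != @e0(...) inverts exactly that encoding when stride equals kernel_size. A small pure-Python sketch of the round trip; decode_window_index is a hypothetical helper, not part of the commit:

# Hypothetical helper: maps a stored window-relative index back to absolute
# input coordinates; (k2, k3) is the clamped window origin, kw the window width.
def decode_window_index(idx, k2, k3, kw):
    p = k2 + idx // kw   # row:    idx = (p - k2) * kw + (q - k3)
    q = k3 + idx % kw    # column
    return p, q

# Round trip for a window at origin (2, 4) with width 3:
k2, k3, kw = 2, 4, 3
idx = (3 - k2) * kw + (5 - k3)                        # encode (p, q) = (3, 5)
assert decode_window_index(idx, k2, k3, kw) == (3, 5)

Since the index is stored as uint8, a window can hold at most 256 cells (e.g. 16x16).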


@@ -143,6 +143,39 @@ class TestArgPoolOp(unittest.TestCase):
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
def test_index_pool(self):
pool = jt.nn.Pool(2, return_indices=True)
a = jt.randn([10,3,100,100])
b, idx = pool(a)
idx.sync()
def test_index_pool2(self):
pool = jt.nn.Pool(2, return_indices=True)
a = jt.array([1,0,0,1,
0,0,0,0,
0,0,0,0,
1,0,0,1]).reshape((1,1,4,4))
b, idx = pool(a)
assert (idx.data.reshape((4,)) == [0,1,2,3]).all()
def test_unpool(self):
from jittor import nn
pool = nn.MaxPool2d(2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(2, stride=2)
input = jt.array([[[[ 1., 2, 3, 4,0],
[ 5, 6, 7, 8,0],
[ 9, 10, 11, 12,0],
[13, 14, 15, 16,0],
[0, 0, 0, 0, 0]]]])
output, indices = pool(input)
out = unpool(output, indices, output_size=input.shape)
assert (out == jt.array([[[[ 0., 0., 0., 0., 0.],
[ 0., 6., 0., 8., 0.],
[ 0., 0., 0., 0., 0.],
[ 0., 14., 0., 16., 0.],
[ 0., 0., 0., 0., 0.]]]])).all()
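
A further check one could add (hypothetical, not part of this commit) pins down the row-major encoding for non-square windows, where the flattened index must be scaled by the window width rather than its height:

def test_index_pool_rect(self):
    # Hypothetical extra test: with a 2x3 window the stored index must be
    # (p-k2)*3 + (q-k3), i.e. the row offset is scaled by the window width.
    pool = jt.nn.Pool((2, 3), return_indices=True)
    a = jt.array([0, 0, 0, 0, 0, 1,
                  1, 0, 0, 0, 0, 0]).reshape((1, 1, 2, 6))
    b, idx = pool(a)
    # left window:  max at row 1, col 0 -> 1*3+0 = 3
    # right window: max at row 0, col 2 -> 0*3+2 = 2
    assert (idx.data.reshape((2,)) == [3, 2]).all()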
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)
def test_cuda_avg_pool(self):