Merge pull request #190 from Jittor/pool

change pool inference
This commit is contained in:
Xiang-Li Li 2021-03-24 22:13:18 +08:00 committed by GitHub
commit b031e5d617
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 18 deletions

View File

@ -19,39 +19,40 @@ class Pool(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
assert dilation == None
assert return_indices == None
self.kernel_size = kernel_size
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
self.op = op
self.stride = stride if stride else kernel_size
self.padding = padding
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad and padding != 0
def execute(self, x):
N,C,H,W = x.shape
if self.ceil_mode == False:
h = (H+self.padding*2-self.kernel_size)//self.stride+1
w = (W+self.padding*2-self.kernel_size)//self.stride+1
h = (H+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1
w = (W+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1
use_code_op = self.op in ['maximum', 'minimum']
# some second order avg_pool is required, so we don't use code op here
else:
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
h = (H+self.padding[0]*2-self.kernel_size[0] + self.stride[0] - 1)//self.stride[0]+1
w = (W+self.padding[1]*2-self.kernel_size[1] + self.stride[1] - 1)//self.stride[1]+1
use_code_op = self.op in ['maximum', 'minimum', 'mean']
if use_code_op:
if self.op == 'mean':
if self.count_include_pad:
count = f"int count = {self.kernel_size*self.kernel_size};"
count = f"int count = {self.kernel_size[0]*self.kernel_size[1]};"
else:
count = "int count = (k2_ - k2) * (k3_ - k3);"
count += "float32 rcount = 1.0f / count;"
else:
count = ""
forward_body = f'''{{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
int k3 = i3*{self.stride[1]}-{self.padding[1]};
int k2 = i2*{self.stride[0]}-{self.padding[0]};
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
@ -61,10 +62,10 @@ class Pool(Module):
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
}}'''
backward_body = f'''{{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
int k3 = i3*{self.stride[1]}-{self.padding[1]};
int k2 = i2*{self.stride[0]}-{self.padding[0]};
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
{count}
@ -152,11 +153,11 @@ class Pool(Module):
return out
else:
# TODO: backward
xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
xx = x.reindex([N,C,h,w,self.kernel_size[0],self.kernel_size[1]], [
"i0", # Nid
"i1", # Cid
f"i2*{self.stride}-{self.padding}+i4", # Hid
f"i3*{self.stride}-{self.padding}+i5", # Wid
f"i2*{self.stride[0]}-{self.padding[0]}+i4", # Hid
f"i3*{self.stride[1]}-{self.padding[1]}+i5", # Wid
])
return xx.reduce(self.op, [4,5])

View File

@ -90,6 +90,18 @@ class TestArgPoolOp(unittest.TestCase):
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)
def test_cuda_tuple(self):
    """Check Pool with tuple kernel/stride/padding against torch MaxPool2d (use_cuda=1)."""
    window, step, pad = (2, 3), (2, 3), (1, 1)
    # Same six-layer stack on both sides; the third layer uses ceil_mode and
    # the last one uses scalar arguments.
    jt_model = jt.nn.Sequential(
        Pool(window, step, pad),
        Pool(window, step, pad),
        Pool(window, step, pad, ceil_mode=True),
        Pool(window, step, pad),
        Pool(window, step, pad),
        Pool(3, 1, 1),
    )
    torch_model = Sequential(
        MaxPool2d(window, step, pad),
        MaxPool2d(window, step, pad),
        MaxPool2d(window, step, pad, ceil_mode=True),
        MaxPool2d(window, step, pad),
        MaxPool2d(window, step, pad),
        MaxPool2d(3, 1, 1),
    )
    for input_shape in ([2, 3, 300, 300], [2, 3, 157, 300]):
        check(jt_model, torch_model, input_shape, False)
    for _ in range(10):
        check(jt_model, torch_model, [1, 1, 300, 300], True)
@unittest.skipIf(True, "TODO: cannot pass this test, fix me")
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)
@ -120,6 +132,16 @@ class TestArgPoolOp(unittest.TestCase):
check(jt_model, torch_model, shape, False)
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
def test_cpu_tuple(self):
    """Check Pool with tuple kernel/stride/padding against torch MaxPool2d."""
    window, step, pad = (2, 3), (2, 3), (1, 1)
    # Same six-layer stack on both sides; the third layer uses ceil_mode and
    # the last one uses scalar arguments.
    jt_model = jt.nn.Sequential(
        Pool(window, step, pad),
        Pool(window, step, pad),
        Pool(window, step, pad, ceil_mode=True),
        Pool(window, step, pad),
        Pool(window, step, pad),
        Pool(3, 1, 1),
    )
    torch_model = Sequential(
        MaxPool2d(window, step, pad),
        MaxPool2d(window, step, pad),
        MaxPool2d(window, step, pad, ceil_mode=True),
        MaxPool2d(window, step, pad),
        MaxPool2d(window, step, pad),
        MaxPool2d(3, 1, 1),
    )
    for input_shape in ([2, 3, 300, 300], [2, 3, 157, 300]):
        check(jt_model, torch_model, input_shape, False)
    for _ in range(10):
        check(jt_model, torch_model, [1, 1, 300, 300], True)
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)