change pool inference

This commit is contained in:
zhouwy19 2021-03-23 20:32:07 +08:00
parent cd3319edf4
commit 99b7fbeec1
1 changed files with 19 additions and 18 deletions

View File

@ -19,39 +19,40 @@ class Pool(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
assert dilation == None
assert return_indices == None
self.kernel_size = kernel_size
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
self.op = op
self.stride = stride if stride else kernel_size
self.padding = padding
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
self.ceil_mode = ceil_mode
self.count_include_pad = count_include_pad and padding != 0
def execute(self, x):
N,C,H,W = x.shape
if self.ceil_mode == False:
h = (H+self.padding*2-self.kernel_size)//self.stride+1
w = (W+self.padding*2-self.kernel_size)//self.stride+1
h = (H+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1
w = (W+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1
use_code_op = self.op in ['maximum', 'minimum']
# some second order avg_pool is require, so we don't use code op here
else:
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
h = (H+self.padding[0]*2-self.kernel_size[0] + self.stride[0] - 1)//self.stride[0]+1
w = (W+self.padding[1]*2-self.kernel_size[1] + self.stride[1] - 1)//self.stride[1]+1
use_code_op = self.op in ['maximum', 'minimum', 'mean']
if use_code_op:
if self.op == 'mean':
if self.count_include_pad:
count = f"int count = {self.kernel_size*self.kernel_size};"
count = f"int count = {self.kernel_size[0]*self.kernel_size[1]};"
else:
count = "int count = (k2_ - k2) * (k3_ - k3);"
count += "float32 rcount = 1.0f / count;"
else:
count = ""
forward_body = f'''{{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
int k3 = i3*{self.stride[1]}-{self.padding[1]};
int k2 = i2*{self.stride[0]}-{self.padding[0]};
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
@ -61,10 +62,10 @@ class Pool(Module):
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
}}'''
backward_body = f'''{{
int k3 = i3*{self.stride}-{self.padding};
int k2 = i2*{self.stride}-{self.padding};
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
int k3 = i3*{self.stride[1]}-{self.padding[1]};
int k2 = i2*{self.stride[0]}-{self.padding[0]};
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2);
k3 = max(0, k3);
k2 = max(0, k2);
{count}
@ -152,11 +153,11 @@ class Pool(Module):
return out
else:
# TODO: backward
xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
xx = x.reindex([N,C,h,w,self.kernel_size[0],self.kernel_size[1]], [
"i0", # Nid
"i1", # Cid
f"i2*{self.stride}-{self.padding}+i4", # Hid
f"i3*{self.stride}-{self.padding}+i5", # Wid
f"i2*{self.stride[0]}-{self.padding[0]}+i4", # Hid
f"i3*{self.stride[1]}-{self.padding[1]}+i5", # Wid
])
return xx.reduce(self.op, [4,5])