mirror of https://github.com/Jittor/Jittor
change pool inference
This commit is contained in:
parent
cd3319edf4
commit
99b7fbeec1
|
@ -19,39 +19,40 @@ class Pool(Module):
|
|||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
|
||||
assert dilation == None
|
||||
assert return_indices == None
|
||||
self.kernel_size = kernel_size
|
||||
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
|
||||
self.op = op
|
||||
self.stride = stride if stride else kernel_size
|
||||
self.padding = padding
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
|
||||
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
|
||||
self.ceil_mode = ceil_mode
|
||||
self.count_include_pad = count_include_pad and padding != 0
|
||||
|
||||
def execute(self, x):
|
||||
N,C,H,W = x.shape
|
||||
if self.ceil_mode == False:
|
||||
h = (H+self.padding*2-self.kernel_size)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size)//self.stride+1
|
||||
h = (H+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1
|
||||
w = (W+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1
|
||||
use_code_op = self.op in ['maximum', 'minimum']
|
||||
# some second order avg_pool is require, so we don't use code op here
|
||||
else:
|
||||
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
h = (H+self.padding[0]*2-self.kernel_size[0] + self.stride[0] - 1)//self.stride[0]+1
|
||||
w = (W+self.padding[1]*2-self.kernel_size[1] + self.stride[1] - 1)//self.stride[1]+1
|
||||
use_code_op = self.op in ['maximum', 'minimum', 'mean']
|
||||
|
||||
if use_code_op:
|
||||
if self.op == 'mean':
|
||||
if self.count_include_pad:
|
||||
count = f"int count = {self.kernel_size*self.kernel_size};"
|
||||
count = f"int count = {self.kernel_size[0]*self.kernel_size[1]};"
|
||||
else:
|
||||
count = "int count = (k2_ - k2) * (k3_ - k3);"
|
||||
count += "float32 rcount = 1.0f / count;"
|
||||
else:
|
||||
count = ""
|
||||
forward_body = f'''{{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
|
||||
int k3 = i3*{self.stride[1]}-{self.padding[1]};
|
||||
int k2 = i2*{self.stride[0]}-{self.padding[0]};
|
||||
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
|
||||
|
@ -61,10 +62,10 @@ class Pool(Module):
|
|||
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
|
||||
}}'''
|
||||
backward_body = f'''{{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
int k3_ = min(k3 + {self.kernel_size}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
|
||||
int k3 = i3*{self.stride[1]}-{self.padding[1]};
|
||||
int k2 = i2*{self.stride[0]}-{self.padding[0]};
|
||||
int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3);
|
||||
int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
{count}
|
||||
|
@ -152,11 +153,11 @@ class Pool(Module):
|
|||
return out
|
||||
else:
|
||||
# TODO: backward
|
||||
xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
|
||||
xx = x.reindex([N,C,h,w,self.kernel_size[0],self.kernel_size[1]], [
|
||||
"i0", # Nid
|
||||
"i1", # Cid
|
||||
f"i2*{self.stride}-{self.padding}+i4", # Hid
|
||||
f"i3*{self.stride}-{self.padding}+i5", # Wid
|
||||
f"i2*{self.stride[0]}-{self.padding[0]}+i4", # Hid
|
||||
f"i3*{self.stride[1]}-{self.padding[1]}+i5", # Wid
|
||||
])
|
||||
return xx.reduce(self.op, [4,5])
|
||||
|
||||
|
|
Loading…
Reference in New Issue