commit 0625957143
guoye 2020-03-27 18:01:59 +08:00
parent ba8b1e6949
3 changed files with 260 additions and 39 deletions

python/jittor/nn.py (1 addition, 39 deletions)

@@ -13,6 +13,7 @@ import jittor as jt
from jittor import init, Module
import numpy as np
import math
from jittor.pool import Pool, pool

def matmul_transpose(a, b):
    '''
@@ -64,18 +65,6 @@ def batch_norm(x, is_train, eps=1e-5, momentum=0.1):
    return norm_x * w + b

def pool(x, size, op, padding, stride = 1):
    N,C,H,W = x.shape
    h = (H+padding*2-size)//stride+1
    w = (W+padding*2-size)//stride+1
    xx = x.reindex([N,C,h,w,size,size], [
        "i0", # Nid
        "i1", # Cid
        f"i2*{stride}-{padding}+i4", # Hid
        f"i3*{stride}-{padding}+i5", # Wid
    ])
    return xx.reduce(op, [4,5])

@jt.var_scope('conv')
def conv(x, in_planes, out_planes, kernel_size, padding, stride = 1, init_method=None):
    Kw = kernel_size
@@ -279,33 +268,6 @@ class BatchNorm(Module):
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b

class Pool(Module):
    def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"):
        assert dilation == None
        assert return_indices == None
        self.kernel_size = kernel_size
        self.op = op
        self.stride = stride if stride else kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode

    def execute(self, x):
        N,C,H,W = x.shape
        if (self.ceil_mode == False):
            h = (H+self.padding*2-self.kernel_size)//self.stride+1
            w = (W+self.padding*2-self.kernel_size)//self.stride+1
        else:
            h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
            w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
        xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
            "i0", # Nid
            "i1", # Cid
            f"i2*{self.stride}-{self.padding}+i4", # Hid
            f"i3*{self.stride}-{self.padding}+i5", # Wid
        ])
        return xx.reduce(self.op, [4,5])

Relu = jt.make_module(relu)
ReLU = Relu
Leaky_relu = jt.make_module(leaky_relu, 2)
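
Note: the removed `pool`/`Pool` implement pooling by reindexing x into [N,C,h,w,size,size] windows and reducing over the last two axes; they move to the new python/jittor/pool.py below, which keeps that path as a fallback and adds hand-written max/min kernels. The output size follows the usual arithmetic, floor((H + 2*padding - kernel_size) / stride) + 1, rounded up instead under ceil_mode. A small sketch of that arithmetic (the helper name pool_out_size is hypothetical, for illustration only):

    def pool_out_size(H, kernel_size, stride, padding, ceil_mode=False):
        # floor mode: floor((H + 2*padding - kernel_size) / stride) + 1
        if not ceil_mode:
            return (H + padding*2 - kernel_size) // stride + 1
        # ceil mode: adding stride-1 before the floor division rounds it up
        return (H + padding*2 - kernel_size + stride - 1) // stride + 1

    assert pool_out_size(157, 2, 2, 0) == 78                  # floor mode
    assert pool_out_size(157, 2, 2, 0, ceil_mode=True) == 79  # ceil mode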

python/jittor/pool.py Normal file (185 additions)

@@ -0,0 +1,185 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Wenyang Zhou <576825820@qq.com>
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import init, Module
import numpy as np
import math

class Pool(Module):
    def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"):
        assert dilation == None
        assert return_indices == None
        self.kernel_size = kernel_size
        self.op = op
        self.stride = stride if stride else kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode

    def execute(self, x):
        N,C,H,W = x.shape
        if (self.ceil_mode == False):
            h = (H+self.padding*2-self.kernel_size)//self.stride+1
            w = (W+self.padding*2-self.kernel_size)//self.stride+1
        else:
            h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
            w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
        if (self.op == 'maximum' or self.op == 'minimum'):
            if (self.op == 'maximum'):
                op = 'max'
            else:
                op = 'min'
            out = jt.code([N,C,h,w], x.dtype, [x],
                cuda_src=f'''
                    // forward: blockIdx.z/.y select (N, C); threads grid-stride over output pixels
                    __global__ static void kernel1(@ARGS_DEF) {{
                        @PRECALC
                        int p3 = threadIdx.x;
                        int s3 = blockDim.x;
                        int p2 = threadIdx.y + blockIdx.x * blockDim.y;
                        int s2 = blockDim.y * gridDim.x;
                        int i1 = blockIdx.y;
                        int i0 = blockIdx.z;
                        for (int i3 = p3; i3 < outshape3; i3 += s3)
                            for (int i2 = p2; i2 < outshape2; i2 += s2) {{
                                // pooling window [k2, k2_) x [k3, k3_), clipped to the input
                                int k3 = i3*{self.stride}-{self.padding};
                                int k2 = i2*{self.stride}-{self.padding};
                                int k3_ = min(k3 + {self.kernel_size}, in0shape3);
                                int k2_ = min(k2 + {self.kernel_size}, in0shape2);
                                k3 = max(0, k3);
                                k2 = max(0, k2);
                                @out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
                                for (int p = k2; p < k2_; ++p)
                                    for (int q = k3; q < k3_; ++q)
                                        @out(i0, i1, i2, i3) = {op}(@out(i0, i1, i2, i3), @in0(i0, i1, p, q));
                            }}
                    }}
                    int tx = min(1024, outshape3);
                    int ty = min(1024 / tx, outshape2);
                    int bx = (outshape2 - 1) / ty + 1;
                    int by = outshape1;
                    int bz = outshape0;
                    dim3 s1(bx, by, bz);
                    dim3 s2(tx, ty);
                    kernel1<<<s1, s2>>>(@ARGS);
                ''',
                cuda_grad_src=[f'''
                    // backward pass 1: zero the input gradient
                    __global__ static void kernel2(@ARGS_DEF) {{
                        @PRECALC
                        int p3 = threadIdx.x;
                        int s3 = blockDim.x;
                        int p2 = threadIdx.y + blockIdx.x * blockDim.y;
                        int s2 = blockDim.y * gridDim.x;
                        int i1 = blockIdx.y;
                        int i0 = blockIdx.z;
                        for (int i3 = p3; i3 < outshape3; i3 += s3)
                            for (int i2 = p2; i2 < outshape2; i2 += s2) {{
                                @out(i0, i1, i2, i3) = 0;
                            }}
                    }}
                    // backward pass 2: scatter each output gradient onto the first
                    // window element that equals the pooled extremum (@pout)
                    __global__ static void kernel3(@ARGS_DEF) {{
                        @PRECALC
                        int p3 = threadIdx.x;
                        int s3 = blockDim.x;
                        int p2 = threadIdx.y + blockIdx.x * blockDim.y;
                        int s2 = blockDim.y * gridDim.x;
                        int i1 = blockIdx.y;
                        int i0 = blockIdx.z;
                        for (int i3 = p3; i3 < poutshape3; i3 += s3)
                            for (int i2 = p2; i2 < poutshape2; i2 += s2) {{
                                int k3 = i3*{self.stride}-{self.padding};
                                int k2 = i2*{self.stride}-{self.padding};
                                int k3_ = min(k3 + {self.kernel_size}, in0shape3);
                                int k2_ = min(k2 + {self.kernel_size}, in0shape2);
                                k3 = max(0, k3);
                                k2 = max(0, k2);
                                int bo = 1;
                                for (int p = k2; p < k2_ && bo; ++p)
                                    for (int q = k3; q < k3_ && bo; ++q) {{
                                        if (@pout(i0,i1,i2,i3) == @in0(i0,i1,p,q)) {{
                                            // overlapping windows may hit the same cell: accumulate atomically
                                            atomicAdd(&@out(i0,i1,p,q), @dout(i0,i1,i2,i3));
                                            bo = 0;
                                        }}
                                    }}
                            }}
                    }}
                    int tx = min(1024, outshape3);
                    int ty = min(1024 / tx, outshape2);
                    int bx = (outshape2 - 1) / ty + 1;
                    int by = outshape1;
                    int bz = outshape0;
                    dim3 s1(bx, by, bz);
                    dim3 s2(tx, ty);
                    kernel2<<<s1, s2>>>(@ARGS);
                    tx = min(1024, poutshape3);
                    ty = min(1024 / tx, poutshape2);
                    bx = (poutshape2 - 1) / ty + 1;
                    by = poutshape1;
                    bz = poutshape0;
                    dim3 s1_(bx, by, bz);
                    dim3 s2_(tx, ty);
                    kernel3<<<s1_, s2_>>>(@ARGS);
                '''],
                cpu_src=f'''
                    for (int i0=0; i0<outshape0; i0++)
                        for (int i1=0; i1<outshape1; i1++)
                            for (int i2=0; i2<outshape2; i2++)
                                for (int i3=0; i3<outshape3; i3++) {{
                                    int k2 = i2*{self.stride}-{self.padding};
                                    int k3 = i3*{self.stride}-{self.padding};
                                    int k2_ = std::min(k2 + {self.kernel_size}, in0shape2);
                                    int k3_ = std::min(k3 + {self.kernel_size}, in0shape3);
                                    k2 = std::max(0, k2);
                                    k3 = std::max(0, k3);
                                    @out(i0, i1, i2, i3) = @in0(i0, i1, k2, k3);
                                    for (int p = k2; p < k2_; ++p)
                                        for (int q = k3; q < k3_; ++q)
                                            @out(i0, i1, i2, i3) = std::{op}(@out(i0, i1, i2, i3), @in0(i0, i1, p, q));
                                }}
                ''',
                cpu_grad_src=[f'''
                    // zero the input gradient, then scatter, mirroring the CUDA path
                    for (int i=0; i<outshape0; i++)
                        for (int j=0; j<outshape1; j++)
                            for (int k=0; k<outshape2; k++)
                                for (int l=0; l<outshape3; l++)
                                    @out(i,j,k,l) = 0;
                    for (int i0=0; i0<poutshape0; i0++)
                        for (int i1=0; i1<poutshape1; i1++)
                            for (int i2=0; i2<poutshape2; i2++)
                                for (int i3=0; i3<poutshape3; i3++) {{
                                    int k3 = i3*{self.stride}-{self.padding};
                                    int k2 = i2*{self.stride}-{self.padding};
                                    int k3_ = std::min(k3 + {self.kernel_size}, in0shape3);
                                    int k2_ = std::min(k2 + {self.kernel_size}, in0shape2);
                                    k3 = std::max(0, k3);
                                    k2 = std::max(0, k2);
                                    int bo = 1;
                                    for (int p = k2; p < k2_ && bo; ++p)
                                        for (int q = k3; q < k3_ && bo; ++q) {{
                                            if (@pout(i0,i1,i2,i3) == @in0(i0,i1,p,q)) {{
                                                @out(i0,i1,p,q) += @dout(i0,i1,i2,i3);
                                                bo = 0;
                                            }}
                                        }}
                                }}
                '''])
            return out
        else:
            # other ops (e.g. "mean") fall back to reindex + reduce
            xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
                "i0", # Nid
                "i1", # Cid
                f"i2*{self.stride}-{self.padding}+i4", # Hid
                f"i3*{self.stride}-{self.padding}+i5", # Wid
            ])
            return xx.reduce(self.op, [4,5])


def pool(x, size, op, padding, stride=1):
    return Pool(size, stride, padding, op=op)(x)
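
Note: in the jt.code call above, @out, @in0, @dout, and @pout are Jittor code-op macros for the op output, its first input, the incoming output gradient, and the saved forward output; names like outshape2 and in0shape3 expand to the corresponding dimensions. A minimal sketch of the same API pattern, assuming only the macros this file already uses (computes y = x*x with gradient 2*x*dy; illustrative, not part of this commit):

    import jittor as jt

    x = jt.random([4])
    y = jt.code(x.shape, x.dtype, [x],
        cpu_src='''
            for (int i=0; i<outshape0; i++)
                @out(i) = @in0(i) * @in0(i);
        ''',
        cpu_grad_src=['''
            // in a grad kernel, @out is the gradient w.r.t. in0
            for (int i=0; i<outshape0; i++)
                @out(i) = 2 * @in0(i) * @dout(i);
        '''])
    print(jt.grad(y.sum(), x).data)  # expect 2 * x

Design note: the max/min backward scatters each @dout onto the first window element equal to @pout (the bo flag stops the scan after a match), and the CUDA path accumulates with atomicAdd because overlapping windows can target the same input cell.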

python/jittor/test/test_arg_pool_op.py Normal file (74 additions)

@@ -0,0 +1,74 @@
# ***************************************************************
# Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
from jittor.nn import Pool, pool
import numpy as np
from .test_core import expect_error
from .test_grad import ngrad
from itertools import permutations
from jittor import compile_extern, Module
from .test_log import find_log_with_re
import random
import pickle as pk
skip_this_test = False
try:
    import torch
    from torch.nn import MaxPool2d, Sequential
except:
    skip_this_test = True

def check(jt_model, torch_model, shape, near_data):
    if (near_data):
        # small value range makes ties within pooling windows likely,
        # exercising the tie-breaking rule of the backward pass
        assert shape[0] * shape[1] * shape[2] * shape[3] % 8 == 0
        data = list(range(8)) * int((shape[0] * shape[1] * shape[2] * shape[3]) / 8)
        random.shuffle(data)
        x = jt.array(data).float32().reshape(shape)
    else:
        x = jt.random(shape)
    y = jt_model(x)
    g = jt.grad(y.sum(), x)
    x_ = torch.Tensor(x.data)
    x_.requires_grad = True
    y_ = torch_model(x_)
    y_.sum().backward()
    y__ = y_.detach().numpy()
    g__ = x_.grad.detach().numpy()
    assert np.allclose(y.data, y__)
    assert np.allclose(g.data, g__)

@unittest.skipIf(skip_this_test, "No Torch found")
class TestArgPoolOp(unittest.TestCase):
    @jt.flag_scope(use_cuda=1)
    def test_cuda(self):
        jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
        torch_model = Sequential(MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0, ceil_mode=True), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(3, 1, 1))
        shape = [64, 64, 300, 300]
        check(jt_model, torch_model, shape, False)
        shape = [32, 128, 157, 300]
        check(jt_model, torch_model, shape, False)
        for i in range(10):
            check(jt_model, torch_model, [1,1,300,300], True)

    def test_cpu_(self):
        x = jt.random([32, 128, 157, 300])
        x = jt.nn.pool(x, 2, "maximum", 0, 2)

    def test_cpu(self):
        jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
        torch_model = Sequential(MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0, ceil_mode=True), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(3, 1, 1))
        shape = [64, 64, 300, 300]
        check(jt_model, torch_model, shape, False)
        shape = [32, 128, 157, 300]
        check(jt_model, torch_model, shape, False)
        for i in range(10):
            check(jt_model, torch_model, [1,1,300,300], True)

if __name__ == "__main__":
    unittest.main()
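
Note: a quick manual check of the module and functional forms added in pool.py (shapes here are arbitrary, for illustration):

    import numpy as np
    import jittor as jt
    from jittor.pool import Pool, pool

    x = jt.random([2, 3, 8, 8])
    y1 = Pool(2, 2, 0)(x)             # module form; expect output shape [2, 3, 4, 4]
    y2 = pool(x, 2, "maximum", 0, 2)  # functional form, same computation
    assert np.allclose(y1.data, y2.data)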