mirror of https://github.com/Jittor/Jittor
depthwise_conv
This commit is contained in:
parent
14324f364b
commit
895f6ecee2
|
@ -98,6 +98,7 @@ def install_cub(root_folder):
|
|||
return dirname
|
||||
|
||||
def setup_cub():
|
||||
global cub_home
|
||||
from pathlib import Path
|
||||
cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
|
||||
cub_home = install_cub(cub_path)
|
||||
|
|
|
@ -1,23 +1,27 @@
|
|||
|
||||
import jittor as jt
|
||||
from jittor import init
|
||||
from jittor import nn
|
||||
from jittor import Function
|
||||
|
||||
def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_conv=0):
|
||||
N,C,H,W = data.shape
|
||||
o,i,Kh,Kw = weights.shape
|
||||
assert(i == 1)
|
||||
class DepthwiseConv(Function):
|
||||
def __init__(self, stride=1, padding=0, dilation=1):
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
|
||||
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
|
||||
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
|
||||
|
||||
def execute(self, x, weight):
|
||||
self.save_vars = x, weight
|
||||
N,C,H,W = x.shape
|
||||
o,i,Kh,Kw = weight.shape
|
||||
assert(o == C)
|
||||
if not isinstance(stride, tuple):
|
||||
stride = (stride, stride)
|
||||
|
||||
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
|
||||
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
|
||||
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
|
||||
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
|
||||
filter_height, filter_width = Kh, Kw
|
||||
self.Khw = Kh, Kw
|
||||
output = jt.code(
|
||||
[N, C, oh, ow],
|
||||
data.dtype,
|
||||
[data, weights],
|
||||
x.dtype,
|
||||
[x, weight],
|
||||
cuda_header = """
|
||||
template <typename T,
|
||||
int filter_height,
|
||||
|
@ -76,6 +80,61 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
|
|||
}
|
||||
}
|
||||
}
|
||||
""",
|
||||
cuda_src=f"""
|
||||
@alias(input, in0)
|
||||
@alias(filter, in1)
|
||||
@alias(output, out)
|
||||
|
||||
const int batch_size = input_shape0;
|
||||
const int input_channels = input_shape1;
|
||||
const int input_height = input_shape2;
|
||||
const int input_width = input_shape3;
|
||||
const int output_channels = output_shape1;
|
||||
const int output_height = output_shape2;
|
||||
const int output_width = output_shape3;
|
||||
const int ksize_height = {Kh};
|
||||
const int ksize_width = {Kw};
|
||||
const int stride_height = {self.stride[0]};
|
||||
const int stride_width = {self.stride[1]};
|
||||
const int padding_height = {self.padding[0]};
|
||||
const int padding_width = {self.padding[1]};
|
||||
const int dilate_height = {self.dilation[0]};
|
||||
const int dilate_width = {self.dilation[1]};
|
||||
|
||||
int thread = 512;
|
||||
if (output_width > 1024 && output_width <= 2048)
|
||||
thread = (output_width - 1) / 2 + 1;
|
||||
else if (output_width > 512 && output_width <= 1024)
|
||||
thread = output_width;
|
||||
int blocks = std::min(std::max(thread / output_width, 1), output_height);
|
||||
dim3 threads(std::min(output_width, thread), blocks, 1);
|
||||
dim3 grid(output_channels, batch_size, 1);
|
||||
KernelDepthwiseConv<
|
||||
input_type, ksize_height, ksize_width,
|
||||
stride_height, stride_width>
|
||||
<<<grid, threads>>>(
|
||||
input_p, filter_p, batch_size, output_channels, output_height,
|
||||
output_width, input_channels, input_height, input_width,
|
||||
padding_height, padding_width, dilate_height,
|
||||
dilate_width, output_p);
|
||||
"""
|
||||
)
|
||||
return output
|
||||
|
||||
def grad(self, grad):
|
||||
x, weight = self.save_vars
|
||||
Kh, Kw = self.Khw
|
||||
return jt.code([x.shape, weight.shape], [x.dtype, weight.dtype], [x, weight, grad],
|
||||
cuda_header = f"#include <{jt.compile_extern.cub_home}/cub/cub.cuh>"+"""
|
||||
template <typename T>
|
||||
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
|
||||
typedef cub::WarpReduce<T> WarpReduce;
|
||||
typename WarpReduce::TempStorage temp_storage;
|
||||
value = WarpReduce(temp_storage).Sum(value);
|
||||
if (cub::LaneId() == 0)
|
||||
atomicAdd(sum, value);
|
||||
}
|
||||
|
||||
// CUDA kernel to compute the depthwise convolution backprop w.r.t input.
|
||||
template <typename T,
|
||||
|
@ -153,11 +212,12 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
|
|||
const int dilate_width, T* filter_grad_data) {
|
||||
T s = 0;
|
||||
|
||||
int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
|
||||
int gbid = (((blockIdx.z * blockDim.z + threadIdx.z) * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
|
||||
|
||||
for (int image_w = threadIdx.x; image_w < output_width;
|
||||
image_w += blockDim.x) {
|
||||
for (int bid = 0; bid < num; bid++) {
|
||||
//for (int bid = threadIdx.z; bid < num; bid+=blockDim.z) {
|
||||
for (int image_h = threadIdx.y; image_h < output_height;
|
||||
image_h += blockDim.y) {
|
||||
int kernel_id = blockIdx.z;
|
||||
|
@ -183,56 +243,16 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
|
|||
}
|
||||
}
|
||||
}
|
||||
atomicAdd(&filter_grad_data[gbid], s);
|
||||
CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
|
||||
}
|
||||
""",
|
||||
|
||||
cuda_src=f"""
|
||||
@alias(input, in0)
|
||||
@alias(filter, in1)
|
||||
@alias(output, out0)
|
||||
|
||||
const int batch_size = input_shape0;
|
||||
const int input_channels = input_shape1;
|
||||
const int input_height = input_shape2;
|
||||
const int input_width = input_shape3;
|
||||
const int output_channels = output_shape1;
|
||||
const int output_height = output_shape2;
|
||||
const int output_width = output_shape3;
|
||||
const int ksize_height = {Kh};
|
||||
const int ksize_width = {Kw};
|
||||
const int stride_height = {stride[0]};
|
||||
const int stride_width = {stride[1]};
|
||||
const int padding_height = {padding[0]};
|
||||
const int padding_width = {padding[1]};
|
||||
const int dilate_height = {dilation[0]};
|
||||
const int dilate_width = {dilation[1]};
|
||||
|
||||
int thread = 512;
|
||||
if (output_width > 1024 && output_width <= 2048)
|
||||
thread = (output_width - 1) / 2 + 1;
|
||||
else if (output_width > 512 && output_width <= 1024)
|
||||
thread = output_width;
|
||||
int blocks = std::min(std::max(thread / output_width, 1), output_height);
|
||||
dim3 threads(std::min(output_width, thread), blocks, 1);
|
||||
dim3 grid(output_channels, batch_size, 1);
|
||||
KernelDepthwiseConv<
|
||||
input_type, ksize_height, ksize_width,
|
||||
stride_height, stride_width>
|
||||
<<<grid, threads>>>(
|
||||
input_p, filter_p, batch_size, output_channels, output_height,
|
||||
output_width, input_channels, input_height, input_width,
|
||||
padding_height, padding_width, dilate_height,
|
||||
dilate_width, output_p);
|
||||
""",
|
||||
|
||||
cuda_grad_src=[
|
||||
f"""
|
||||
// source for backward to data
|
||||
@alias(input, in0)
|
||||
@alias(filter, in1)
|
||||
@alias(input_grad, out)
|
||||
@alias(output_grad, in2)
|
||||
@alias(input_grad, out0)
|
||||
@alias(filter_grad, out1)
|
||||
|
||||
const int batch_size = input_shape0;
|
||||
const int input_channels = input_shape1;
|
||||
|
@ -243,12 +263,12 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
|
|||
const int output_width = output_grad_shape3;
|
||||
const int ksize_height = {Kh};
|
||||
const int ksize_width = {Kw};
|
||||
const int stride_height = {stride[0]};
|
||||
const int stride_width = {stride[1]};
|
||||
const int padding_height = {padding[0]};
|
||||
const int padding_width = {padding[1]};
|
||||
const int dilate_height = {dilation[0]};
|
||||
const int dilate_width = {dilation[1]};
|
||||
const int stride_height = {self.stride[0]};
|
||||
const int stride_width = {self.stride[1]};
|
||||
const int padding_height = {self.padding[0]};
|
||||
const int padding_width = {self.padding[1]};
|
||||
const int dilate_height = {self.dilation[0]};
|
||||
const int dilate_width = {self.dilation[1]};
|
||||
|
||||
int thread = 512;
|
||||
if (input_width > 1024 && input_width <= 2048)
|
||||
|
@ -266,28 +286,8 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
|
|||
output_channels, output_height, output_width, input_channels,
|
||||
input_height, input_width, padding_height,
|
||||
padding_width, dilate_height, dilate_width, input_grad_p);
|
||||
""",
|
||||
f"""
|
||||
// source for backward to filter
|
||||
@alias(input, in0)
|
||||
@alias(filter_grad, out)
|
||||
@alias(output_grad, in2)
|
||||
|
||||
const int batch_size = input_shape0;
|
||||
const int input_channels = input_shape1;
|
||||
const int input_height = input_shape2;
|
||||
const int input_width = input_shape3;
|
||||
const int output_channels = output_grad_shape1;
|
||||
const int output_height = output_grad_shape2;
|
||||
const int output_width = output_grad_shape3;
|
||||
const int ksize_height = {Kh};
|
||||
const int ksize_width = {Kw};
|
||||
const int stride_height = {stride[0]};
|
||||
const int stride_width = {stride[1]};
|
||||
const int padding_height = {padding[0]};
|
||||
const int padding_width = {padding[1]};
|
||||
const int dilate_height = {dilation[0]};
|
||||
const int dilate_width = {dilation[1]};
|
||||
// source for backward to filter
|
||||
|
||||
int block_size = 512;
|
||||
if (output_width > 1024 && output_width <= 2048)
|
||||
|
@ -296,8 +296,10 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
|
|||
block_size = output_width;
|
||||
int crop_output_height =
|
||||
std::min(std::max(block_size / output_width, 1), output_height);
|
||||
dim3 grid(ksize_width, ksize_height, output_channels);
|
||||
dim3 threads(std::min(output_width, block_size), crop_output_height, 1);
|
||||
|
||||
grid = dim3(ksize_width, ksize_height, output_channels);
|
||||
threads = dim3(std::min(output_width, block_size), crop_output_height, 1);
|
||||
|
||||
|
||||
KernelDepthwiseConvFilterGrad<
|
||||
input_type><<<grid, threads, 0>>>(
|
||||
|
@ -307,6 +309,4 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
|
|||
stride_height, stride_width, padding_height, padding_width,
|
||||
dilate_height, dilate_width, filter_grad_p);
|
||||
"""
|
||||
]
|
||||
)
|
||||
return output
|
|
@ -153,7 +153,7 @@ def get_init_var_rand(shape, dtype):
|
|||
|
||||
def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
|
||||
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
|
||||
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)
|
||||
def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0)
|
||||
def sign(x):
|
||||
one = jt.ones(x.shape)
|
||||
x = jt.ternary(x>0, one, x)
|
||||
|
@ -473,7 +473,7 @@ ReLU6 = jt.make_module(relu6)
|
|||
Softmax = jt.make_module(softmax, 2)
|
||||
GELU = jt.make_module(gelu)
|
||||
|
||||
from jittor.depthwise_conv import depthwise_conv
|
||||
from jittor.depthwise_conv import DepthwiseConv
|
||||
|
||||
class Conv(Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
|
@ -485,7 +485,8 @@ class Conv(Module):
|
|||
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
|
||||
self.groups = groups
|
||||
self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels
|
||||
# self.is_depthwise_conv = False
|
||||
if self.is_depthwise_conv and jt.flags.use_cuda:
|
||||
self.depthwise_conv = DepthwiseConv(stride, padding, dilation)
|
||||
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
|
||||
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
|
||||
Kh, Kw = self.kernel_size
|
||||
|
@ -505,8 +506,12 @@ class Conv(Module):
|
|||
self.bias = None
|
||||
|
||||
def execute(self, x):
|
||||
if self.is_depthwise_conv:
|
||||
return depthwise_conv(x, self.weight, self.padding, self.dilation, self.stride)
|
||||
if self.is_depthwise_conv and jt.flags.use_cuda:
|
||||
y = self.depthwise_conv(x, self.weight)
|
||||
if self.bias is not None:
|
||||
b = self.bias.broadcast(y.shape, [0,2,3])
|
||||
y = y + b
|
||||
return y
|
||||
elif self.groups == 1:
|
||||
N,C,H,W = x.shape
|
||||
Kh, Kw = self.kernel_size
|
||||
|
@ -541,7 +546,6 @@ class Conv(Module):
|
|||
f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
|
||||
f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
|
||||
])
|
||||
xx.compile_options = {"G":G}
|
||||
# w: [oc, CpG, Kh, Kw]
|
||||
ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
|
||||
f'i1*{oc//G}+i2',
|
||||
|
@ -549,6 +553,7 @@ class Conv(Module):
|
|||
'i6',
|
||||
'i7'
|
||||
])
|
||||
ww.compile_options = xx.compile_options = {"G":G,"C":C}
|
||||
yy = xx*ww
|
||||
y = yy.reindex_reduce('add', [N, oc, oh, ow], [
|
||||
'i0',
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import jittor.models as jtmodels
|
||||
|
||||
def load_parameters(m1, m2):
|
||||
m1.save('temp.pk')
|
||||
m2.load('temp.pk')
|
||||
|
||||
def compare_parameters(m1, m2):
|
||||
ps1 = m1.parameters()
|
||||
ps2 = m2.parameters()
|
||||
for i in range(len(ps1)):
|
||||
x = ps1[i].data + 1e-8
|
||||
y = ps2[i].data + 1e-8
|
||||
relative_error = abs(x - y) / abs(y)
|
||||
diff = relative_error.mean()
|
||||
assert diff < 1e-4, (diff, 'backward', ps2[i].name())
|
||||
|
||||
class TestDepthwiseConv(unittest.TestCase):
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_data(self):
|
||||
test_img = np.random.random((64,3,224,224)).astype('float32')
|
||||
jittor_test_img = jt.array(test_img)
|
||||
lr = 100
|
||||
|
||||
jittor_model = jtmodels.__dict__['mobilenet_v2']()
|
||||
jittor_model2 = jtmodels.__dict__['mobilenet_v2']()
|
||||
# Set eval to avoid dropout layer & bn errors
|
||||
jittor_model.train()
|
||||
jittor_model.classifier[0].eval()
|
||||
for m in jittor_model.modules():
|
||||
if isinstance(m, jt.nn.BatchNorm):
|
||||
m.eval()
|
||||
|
||||
jittor_model2.train()
|
||||
jittor_model2.classifier[0].eval()
|
||||
for m in jittor_model2.modules():
|
||||
if isinstance(m, jt.nn.BatchNorm):
|
||||
m.eval()
|
||||
|
||||
load_parameters(jittor_model2, jittor_model)
|
||||
for m in jittor_model.modules():
|
||||
if isinstance(m, jt.nn.Conv):
|
||||
m.is_depthwise_conv = False
|
||||
cnt = 0
|
||||
for m in jittor_model2.modules():
|
||||
if isinstance(m, jt.nn.Conv):
|
||||
if (m.is_depthwise_conv):
|
||||
cnt += 1
|
||||
assert cnt == 17, (cnt, '!=', 17)
|
||||
jt_optimizer = jt.nn.SGD(jittor_model.parameters(), lr = lr)
|
||||
jt_optimizer2 = jt.nn.SGD(jittor_model2.parameters(), lr = lr)
|
||||
|
||||
jittor_result = jittor_model(jittor_test_img)
|
||||
loss = jittor_result.sum()
|
||||
jt_optimizer.step(loss)
|
||||
jt.sync_all(True)
|
||||
|
||||
jittor_result2 = jittor_model2(jittor_test_img)
|
||||
loss = jittor_result2.sum()
|
||||
jt_optimizer2.step(loss)
|
||||
jt.sync_all(True)
|
||||
compare_parameters(jittor_model, jittor_model2)
|
||||
|
||||
x = jittor_result2.data + 1e-8
|
||||
y = jittor_result.data + 1e-8
|
||||
relative_error = abs(x - y) / abs(y)
|
||||
diff = relative_error.mean()
|
||||
assert diff < 1e-4, (diff, 'forword')
|
||||
|
||||
jt.clean()
|
||||
jt.gc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue