From 895f6ecee295de8a87f1feafe521af7f5831eaf4 Mon Sep 17 00:00:00 2001 From: cxjyxx_me <498731903@qq.com> Date: Mon, 7 Dec 2020 22:43:11 +0800 Subject: [PATCH] depthwise_conv --- python/jittor/compile_extern.py | 1 + python/jittor/depthwise_conv.py | 280 +++++++++++----------- python/jittor/nn.py | 17 +- python/jittor/test/test_depthwise_conv.py | 82 +++++++ 4 files changed, 234 insertions(+), 146 deletions(-) create mode 100644 python/jittor/test/test_depthwise_conv.py diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 24f7690f..6e50d2ce 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -98,6 +98,7 @@ def install_cub(root_folder): return dirname def setup_cub(): + global cub_home from pathlib import Path cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub") cub_home = install_cub(cub_path) diff --git a/python/jittor/depthwise_conv.py b/python/jittor/depthwise_conv.py index e0552a75..72eb8f5d 100644 --- a/python/jittor/depthwise_conv.py +++ b/python/jittor/depthwise_conv.py @@ -1,80 +1,139 @@ - import jittor as jt from jittor import init from jittor import nn +from jittor import Function -def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_conv=0): - N,C,H,W = data.shape - o,i,Kh,Kw = weights.shape - assert(i == 1) - assert(o == C) - if not isinstance(stride, tuple): - stride = (stride, stride) +class DepthwiseConv(Function): + def __init__(self, stride=1, padding=0, dilation=1): + self.stride = stride if isinstance(stride, tuple) else (stride, stride) + self.padding = padding if isinstance(padding, tuple) else (padding, padding) + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) - oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 - ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 - filter_height, filter_width = Kh, Kw - output = jt.code( - [N, C, oh, ow], - data.dtype, - [data, weights], - cuda_header = """ - template - __global__ void KernelDepthwiseConv( - const T *const input_data, const T *const filter_data, const int batch_size, - const int output_channels, const int output_height, - const int output_width, const int input_channels, - const int input_height, const int input_width, - const int padding_height, const int padding_width, - const int dilate_height, const int dilate_width, T *const output_data) { - const int kWeghtSize = filter_height * filter_width; - T r_weight[kWeghtSize]; - const int batch = blockIdx.y; - const int c_out = blockIdx.x; - const T* weight = filter_data + c_out * filter_height * filter_width; - for (int i = 0; i < filter_height * filter_width; i++) r_weight[i] = weight[i]; + def execute(self, x, weight): + self.save_vars = x, weight + N,C,H,W = x.shape + o,i,Kh,Kw = weight.shape + assert(o == C) + oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 + ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 + filter_height, filter_width = Kh, Kw + self.Khw = Kh, Kw + output = jt.code( + [N, C, oh, ow], + x.dtype, + [x, weight], + cuda_header = """ + template + __global__ void KernelDepthwiseConv( + const T *const input_data, const T *const filter_data, const int batch_size, + const int output_channels, const int output_height, + const int output_width, const int input_channels, + const int input_height, const int input_width, + const int padding_height, const int padding_width, + const int dilate_height, const int dilate_width, T 
*const output_data) { + const int kWeghtSize = filter_height * filter_width; + T r_weight[kWeghtSize]; + const int batch = blockIdx.y; + const int c_out = blockIdx.x; + const T* weight = filter_data + c_out * filter_height * filter_width; + for (int i = 0; i < filter_height * filter_width; i++) r_weight[i] = weight[i]; - for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { - for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { - const int batch = blockIdx.y; - const int c_out = blockIdx.x; + for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { + for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { + const int batch = blockIdx.y; + const int c_out = blockIdx.x; - const int c_in = c_out; - T value = 0; - const int h_in_start = -padding_height + h_out * stride_height; - const int w_in_start = -padding_width + w_out * stride_width; - const int h_in_end = h_in_start + filter_height * dilate_height; - const int w_in_end = w_in_start + filter_width * dilate_width; + const int c_in = c_out; + T value = 0; + const int h_in_start = -padding_height + h_out * stride_height; + const int w_in_start = -padding_width + w_out * stride_width; + const int h_in_end = h_in_start + filter_height * dilate_height; + const int w_in_end = w_in_start + filter_width * dilate_width; - const int in_offset = - ((batch * input_channels + c_in) * input_height) * input_width; + const int in_offset = + ((batch * input_channels + c_in) * input_height) * input_width; - const int h_end = h_in_end < input_height ? h_in_end : input_height; - const int w_end = w_in_end < input_width ? w_in_end : input_width; - const int h_start = h_in_start > 0 ? h_in_start : 0; - const int w_start = w_in_start > 0 ? w_in_start : 0; + const int h_end = h_in_end < input_height ? h_in_end : input_height; + const int w_end = w_in_end < input_width ? w_in_end : input_width; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int w_start = w_in_start > 0 ? 
w_in_start : 0; - for (int h_in = h_in_start, h_f = 0; h_f < filter_height; - h_in += dilate_height, h_f++) { - for (int w_in = w_in_start, w_f = 0; w_f < filter_width; - w_in += dilate_width, w_f++) { - if (h_in >= 0 && h_in < input_height && w_in >= 0 && - w_in < input_width) { - const int offset = in_offset + h_in * input_width + w_in; - value += r_weight[h_f * filter_width + w_f] * input_data[offset]; + for (int h_in = h_in_start, h_f = 0; h_f < filter_height; + h_in += dilate_height, h_f++) { + for (int w_in = w_in_start, w_f = 0; w_f < filter_width; + w_in += dilate_width, w_f++) { + if (h_in >= 0 && h_in < input_height && w_in >= 0 && + w_in < input_width) { + const int offset = in_offset + h_in * input_width + w_in; + value += r_weight[h_f * filter_width + w_f] * input_data[offset]; + } } } + int index = + ((batch * gridDim.x + c_out) * output_height + h_out) * output_width + + w_out; + output_data[index] = value; } - int index = - ((batch * gridDim.x + c_out) * output_height + h_out) * output_width + - w_out; - output_data[index] = value; } } + """, + cuda_src=f""" + @alias(input, in0) + @alias(filter, in1) + @alias(output, out) + + const int batch_size = input_shape0; + const int input_channels = input_shape1; + const int input_height = input_shape2; + const int input_width = input_shape3; + const int output_channels = output_shape1; + const int output_height = output_shape2; + const int output_width = output_shape3; + const int ksize_height = {Kh}; + const int ksize_width = {Kw}; + const int stride_height = {self.stride[0]}; + const int stride_width = {self.stride[1]}; + const int padding_height = {self.padding[0]}; + const int padding_width = {self.padding[1]}; + const int dilate_height = {self.dilation[0]}; + const int dilate_width = {self.dilation[1]}; + + int thread = 512; + if (output_width > 1024 && output_width <= 2048) + thread = (output_width - 1) / 2 + 1; + else if (output_width > 512 && output_width <= 1024) + thread = output_width; + int blocks = std::min(std::max(thread / output_width, 1), output_height); + dim3 threads(std::min(output_width, thread), blocks, 1); + dim3 grid(output_channels, batch_size, 1); + KernelDepthwiseConv< + input_type, ksize_height, ksize_width, + stride_height, stride_width> + <<>>( + input_p, filter_p, batch_size, output_channels, output_height, + output_width, input_channels, input_height, input_width, + padding_height, padding_width, dilate_height, + dilate_width, output_p); + """ + ) + return output + + def grad(self, grad): + x, weight = self.save_vars + Kh, Kw = self.Khw + return jt.code([x.shape, weight.shape], [x.dtype, weight.dtype], [x, weight, grad], + cuda_header = f"#include <{jt.compile_extern.cub_home}/cub/cub.cuh>"+""" + template + __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { + typedef cub::WarpReduce WarpReduce; + typename WarpReduce::TempStorage temp_storage; + value = WarpReduce(temp_storage).Sum(value); + if (cub::LaneId() == 0) + atomicAdd(sum, value); } // CUDA kernel to compute the depthwise convolution backprop w.r.t input. 
@@ -153,11 +212,12 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co const int dilate_width, T* filter_grad_data) { T s = 0; - int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; + int gbid = (((blockIdx.z * blockDim.z + threadIdx.z) * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; for (int image_w = threadIdx.x; image_w < output_width; image_w += blockDim.x) { for (int bid = 0; bid < num; bid++) { + //for (int bid = threadIdx.z; bid < num; bid+=blockDim.z) { for (int image_h = threadIdx.y; image_h < output_height; image_h += blockDim.y) { int kernel_id = blockIdx.z; @@ -183,56 +243,16 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co } } } - atomicAdd(&filter_grad_data[gbid], s); + CudaAtomicAddWithWarp(&filter_grad_data[gbid], s); } - """, - + """, cuda_src=f""" - @alias(input, in0) - @alias(filter, in1) - @alias(output, out0) - - const int batch_size = input_shape0; - const int input_channels = input_shape1; - const int input_height = input_shape2; - const int input_width = input_shape3; - const int output_channels = output_shape1; - const int output_height = output_shape2; - const int output_width = output_shape3; - const int ksize_height = {Kh}; - const int ksize_width = {Kw}; - const int stride_height = {stride[0]}; - const int stride_width = {stride[1]}; - const int padding_height = {padding[0]}; - const int padding_width = {padding[1]}; - const int dilate_height = {dilation[0]}; - const int dilate_width = {dilation[1]}; - - int thread = 512; - if (output_width > 1024 && output_width <= 2048) - thread = (output_width - 1) / 2 + 1; - else if (output_width > 512 && output_width <= 1024) - thread = output_width; - int blocks = std::min(std::max(thread / output_width, 1), output_height); - dim3 threads(std::min(output_width, thread), blocks, 1); - dim3 grid(output_channels, batch_size, 1); - KernelDepthwiseConv< - input_type, ksize_height, ksize_width, - stride_height, stride_width> - <<>>( - input_p, filter_p, batch_size, output_channels, output_height, - output_width, input_channels, input_height, input_width, - padding_height, padding_width, dilate_height, - dilate_width, output_p); - """, - - cuda_grad_src=[ - f""" // source for backward to data @alias(input, in0) @alias(filter, in1) - @alias(input_grad, out) @alias(output_grad, in2) + @alias(input_grad, out0) + @alias(filter_grad, out1) const int batch_size = input_shape0; const int input_channels = input_shape1; @@ -243,12 +263,12 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co const int output_width = output_grad_shape3; const int ksize_height = {Kh}; const int ksize_width = {Kw}; - const int stride_height = {stride[0]}; - const int stride_width = {stride[1]}; - const int padding_height = {padding[0]}; - const int padding_width = {padding[1]}; - const int dilate_height = {dilation[0]}; - const int dilate_width = {dilation[1]}; + const int stride_height = {self.stride[0]}; + const int stride_width = {self.stride[1]}; + const int padding_height = {self.padding[0]}; + const int padding_width = {self.padding[1]}; + const int dilate_height = {self.dilation[0]}; + const int dilate_width = {self.dilation[1]}; int thread = 512; if (input_width > 1024 && input_width <= 2048) @@ -266,29 +286,9 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co output_channels, output_height, output_width, input_channels, input_height, input_width, padding_height, padding_width, dilate_height, 
dilate_width, input_grad_p); - """, - f""" + // source for backward to filter - @alias(input, in0) - @alias(filter_grad, out) - @alias(output_grad, in2) - - const int batch_size = input_shape0; - const int input_channels = input_shape1; - const int input_height = input_shape2; - const int input_width = input_shape3; - const int output_channels = output_grad_shape1; - const int output_height = output_grad_shape2; - const int output_width = output_grad_shape3; - const int ksize_height = {Kh}; - const int ksize_width = {Kw}; - const int stride_height = {stride[0]}; - const int stride_width = {stride[1]}; - const int padding_height = {padding[0]}; - const int padding_width = {padding[1]}; - const int dilate_height = {dilation[0]}; - const int dilate_width = {dilation[1]}; - + int block_size = 512; if (output_width > 1024 && output_width <= 2048) block_size = (output_width - 1) / 2 + 1; @@ -296,8 +296,10 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co block_size = output_width; int crop_output_height = std::min(std::max(block_size / output_width, 1), output_height); - dim3 grid(ksize_width, ksize_height, output_channels); - dim3 threads(std::min(output_width, block_size), crop_output_height, 1); + + grid = dim3(ksize_width, ksize_height, output_channels); + threads = dim3(std::min(output_width, block_size), crop_output_height, 1); + KernelDepthwiseConvFilterGrad< input_type><<>>( @@ -307,6 +309,4 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co stride_height, stride_width, padding_height, padding_width, dilate_height, dilate_width, filter_grad_p); """ - ] - ) - return output \ No newline at end of file + ) \ No newline at end of file diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 78e1baaf..041385b1 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -153,7 +153,7 @@ def get_init_var_rand(shape, dtype): def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x)) def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale) -def relu6(x): return jt.minimum(jt.maximum(x, 0), 6) +def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0) def sign(x): one = jt.ones(x.shape) x = jt.ternary(x>0, one, x) @@ -473,7 +473,7 @@ ReLU6 = jt.make_module(relu6) Softmax = jt.make_module(softmax, 2) GELU = jt.make_module(gelu) -from jittor.depthwise_conv import depthwise_conv +from jittor.depthwise_conv import DepthwiseConv class Conv(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): @@ -485,7 +485,8 @@ class Conv(Module): self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) self.groups = groups self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels - # self.is_depthwise_conv = False + if self.is_depthwise_conv and jt.flags.use_cuda: + self.depthwise_conv = DepthwiseConv(stride, padding, dilation) assert in_channels % groups == 0, 'in_channels must be divisible by groups' assert out_channels % groups == 0, 'out_channels must be divisible by groups' Kh, Kw = self.kernel_size @@ -505,8 +506,12 @@ class Conv(Module): self.bias = None def execute(self, x): - if self.is_depthwise_conv: - return depthwise_conv(x, self.weight, self.padding, self.dilation, self.stride) + if self.is_depthwise_conv and jt.flags.use_cuda: + y = self.depthwise_conv(x, self.weight) + if self.bias is not None: + b = self.bias.broadcast(y.shape, [0,2,3]) + y = y + b + return y elif self.groups == 1: 
N,C,H,W = x.shape Kh, Kw = self.kernel_size @@ -541,7 +546,6 @@ class Conv(Module): f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid ]) - xx.compile_options = {"G":G} # w: [oc, CpG, Kh, Kw] ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [ f'i1*{oc//G}+i2', @@ -549,6 +553,7 @@ class Conv(Module): 'i6', 'i7' ]) + ww.compile_options = xx.compile_options = {"G":G,"C":C} yy = xx*ww y = yy.reindex_reduce('add', [N, oc, oh, ow], [ 'i0', diff --git a/python/jittor/test/test_depthwise_conv.py b/python/jittor/test/test_depthwise_conv.py new file mode 100644 index 00000000..c6af5f1e --- /dev/null +++ b/python/jittor/test/test_depthwise_conv.py @@ -0,0 +1,82 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: Dun Liang . All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.models as jtmodels + +def load_parameters(m1, m2): + m1.save('temp.pk') + m2.load('temp.pk') + +def compare_parameters(m1, m2): + ps1 = m1.parameters() + ps2 = m2.parameters() + for i in range(len(ps1)): + x = ps1[i].data + 1e-8 + y = ps2[i].data + 1e-8 + relative_error = abs(x - y) / abs(y) + diff = relative_error.mean() + assert diff < 1e-4, (diff, 'backward', ps2[i].name()) + +class TestDepthwiseConv(unittest.TestCase): + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_data(self): + test_img = np.random.random((64,3,224,224)).astype('float32') + jittor_test_img = jt.array(test_img) + lr = 100 + + jittor_model = jtmodels.__dict__['mobilenet_v2']() + jittor_model2 = jtmodels.__dict__['mobilenet_v2']() + # Set eval to avoid dropout layer & bn errors + jittor_model.train() + jittor_model.classifier[0].eval() + for m in jittor_model.modules(): + if isinstance(m, jt.nn.BatchNorm): + m.eval() + + jittor_model2.train() + jittor_model2.classifier[0].eval() + for m in jittor_model2.modules(): + if isinstance(m, jt.nn.BatchNorm): + m.eval() + + load_parameters(jittor_model2, jittor_model) + for m in jittor_model.modules(): + if isinstance(m, jt.nn.Conv): + m.is_depthwise_conv = False + cnt = 0 + for m in jittor_model2.modules(): + if isinstance(m, jt.nn.Conv): + if (m.is_depthwise_conv): + cnt += 1 + assert cnt == 17, (cnt, '!=', 17) + jt_optimizer = jt.nn.SGD(jittor_model.parameters(), lr = lr) + jt_optimizer2 = jt.nn.SGD(jittor_model2.parameters(), lr = lr) + + jittor_result = jittor_model(jittor_test_img) + loss = jittor_result.sum() + jt_optimizer.step(loss) + jt.sync_all(True) + + jittor_result2 = jittor_model2(jittor_test_img) + loss = jittor_result2.sum() + jt_optimizer2.step(loss) + jt.sync_all(True) + compare_parameters(jittor_model, jittor_model2) + + x = jittor_result2.data + 1e-8 + y = jittor_result.data + 1e-8 + relative_error = abs(x - y) / abs(y) + diff = relative_error.mean() + assert diff < 1e-4, (diff, 'forword') + + jt.clean() + jt.gc() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file
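
Usage sketch (an assumption for illustration, not part of the patch): with this change, nn.Conv routes through the new DepthwiseConv Function whenever groups == in_channels == out_channels and jt.flags.use_cuda is set at construction time; the bias, if any, is broadcast and added on top of the kernel output. A minimal way to exercise both the forward and backward CUDA kernels, assuming a CUDA-enabled jittor build:

import jittor as jt
from jittor import nn
from jittor.depthwise_conv import DepthwiseConv

# Assumption: a CUDA build of jittor; the depthwise path is only taken on CUDA,
# and the flag must be on before the layer is constructed.
jt.flags.use_cuda = 1

# groups == in_channels == out_channels marks the layer as depthwise,
# so nn.Conv dispatches to the hand-written KernelDepthwiseConv kernel.
conv = nn.Conv(32, 32, kernel_size=3, stride=1, padding=1, groups=32, bias=True)

x = jt.random((4, 32, 56, 56))
y = conv(x)                                # forward: KernelDepthwiseConv
loss = y.sum()
gw, gx = jt.grad(loss, [conv.weight, x])   # backward: filter/input grad kernels
print(y.shape, gw.shape, gx.shape)

# The Function can also be called directly (bias is not handled here):
dw = DepthwiseConv(stride=1, padding=1, dilation=1)
y2 = dw(x, conv.weight)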